I could not figure out the math behind multiplying each weight matrix W[l] by sqrt(2/n).
- How does multiplying by this value keep the gradients from exploding or vanishing?
- If I multiply by sqrt(2/n), how does the variance become 2/n? I tried to calculate it but couldn't arrive at this number. Please let me know what I am missing.
Thanks!
For point 1), note that the more layers you have, the more factors appear in the "chain rule" products that form the gradients. The gradients for the earlier layers of the network contain factors from all of the subsequent layers. If you multiply many numbers with absolute value < 1, the product tends toward 0 the more you multiply (vanishing). If you multiply numbers with absolute value > 1, the absolute value of the product keeps growing (exploding). So the idea of the factor is to scale the random weights according to n, the number of inputs to each layer (n = n[l-1]), so that the variance of each layer's outputs stays roughly constant from one layer to the next instead of steadily shrinking or growing.
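A minimal NumPy sketch of this effect, assuming ReLU activations, layers of equal width, and no bias terms (the helper `final_activation_std` and all the sizes are hypothetical choices for illustration). It pushes random inputs through 20 layers whose weights are standard-normal values multiplied by a per-layer factor, then reports the spread of the last layer's activations. With the sqrt(2/n) factor (commonly called He initialization; the 2 compensates for ReLU zeroing out roughly half of its inputs) the spread stays on the order of 1, while a fixed small factor makes it collapse toward 0 and a factor of 1 makes it blow up.

```python
import numpy as np

def final_activation_std(scale_fn, n_units=500, n_layers=20, seed=0):
    """Push random inputs through n_layers ReLU layers (no biases) whose weights
    are standard-normal draws multiplied by scale_fn(n_prev), where n_prev is the
    number of inputs to the layer. Return the std of the last layer's activations."""
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((n_units, 100))  # 100 example input vectors, one per column
    for _ in range(n_layers):
        n_prev = a.shape[0]  # number of inputs feeding this layer
        W = rng.standard_normal((n_units, n_prev)) * scale_fn(n_prev)
        a = np.maximum(0, W @ a)  # ReLU activation
    return a.std()

he = final_activation_std(lambda n: np.sqrt(2 / n))  # scaled by sqrt(2/n)
small = final_activation_std(lambda n: 0.01)         # fixed small factor
large = final_activation_std(lambda n: 1.0)          # plain randn, no scaling

print(f"sqrt(2/n): {he:.3g}   0.01: {small:.3g}   1.0: {large:.3g}")
```

Only the forward activations are measured here because they are easy to inspect; backpropagation multiplies by the same weight matrices, so the gradients shrink or grow in roughly the same way.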
For point 2), note that we start with a Normal (Gaussian) distribution with $\mu = 0$ and $\sigma = 1$, right? So what happens to $\sigma$ if you multiply every element drawn from that distribution by a constant factor? Also remember that the variance is the square of the standard deviation.
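A minimal worked version of the hint, assuming each entry of W[l] is drawn independently from a standard normal distribution and n is the number of inputs to the layer: multiplying a random variable by a constant $c$ multiplies its standard deviation by $|c|$ and its variance by $c^2$, so

$$\mathrm{Var}(cX) = c^2\,\mathrm{Var}(X) \quad\Rightarrow\quad \mathrm{Var}\!\left(\sqrt{2/n}\,X\right) = \frac{2}{n}\cdot 1 = \frac{2}{n}, \qquad \sigma = \sqrt{2/n}.$$

The NumPy check below draws standard-normal weights, applies the factor, and compares the sample variance with 2/n. The helper name `scaled_weight_variance` and the 1000×1000 sample size are illustrative choices.

```python
import numpy as np

def scaled_weight_variance(n, shape=(1000, 1000), seed=0):
    """Draw weights from a standard normal (mu = 0, sigma = 1), multiply every
    entry by sqrt(2/n), and return the sample variance of the result."""
    rng = np.random.default_rng(seed)
    W = rng.standard_normal(shape) * np.sqrt(2 / n)
    return W.var()

for n in (10, 100, 500):
    # The sample variance should land very close to the theoretical 2/n.
    print(n, scaled_weight_variance(n), 2 / n)
```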