Initializing Weights to Mitigate Vanishing/Exploding Gradients

So in the video “Weight Initialization for Deep Networks”, I’m having a hard time understanding how he went from ‘var(w_i) = 2/n’ to ‘w[l] = np.random.randn(shape) * np.sqrt(2/n[l-1])’. I understand that if we have a large number of units, then we want each of the weights to be smaller. But changing the variance doesn’t necessarily make the weights larger or smaller; it just makes them either tighter or farther away from each other, right? So I’m confused…


The point is that np.random.randn gives you a Gaussian (normal) distribution with $\mu = 0$ and $\sigma = 1$. So the factor that you multiply the output of randn by ends up being the standard deviation of the resulting distribution, right?
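A quick empirical check of that point, as a rough sketch: draw a large sample with np.random.randn, multiply by np.sqrt(2/n_prev), and look at the spread. The layer width n_prev = 400 and the sample size are arbitrary choices for illustration, and the printed values are only approximately equal to the targets because they come from a finite sample.

```python
import numpy as np

np.random.seed(0)
n_prev = 400                              # hypothetical number of units in layer l-1
samples = np.random.randn(1_000_000)      # standard normal: mean 0, std 1
scaled = samples * np.sqrt(2 / n_prev)    # the He-style scaling from the video

print(samples.std())   # close to 1
print(scaled.std())    # close to sqrt(2/400), about 0.0707
print(scaled.var())    # close to 2/400 = 0.005
```

The multiplier becomes the new standard deviation, and squaring it gives the target variance 2/n.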


Could you explain what is happening when we multiply the Gaussian by the standard deviation? I don’t understand.


Think about the graph of a normal (Gaussian) distribution with $\mu = 0$ and $\sigma = 1$: it’s the canonical “bell curve” symmetric about 0, right? So what happens if I multiply every value drawn from that distribution by some positive constant? It either compresses or expands the shape of the curve in the horizontal direction: it compresses if the constant is < 1 and expands if it’s > 1. So what does that do to the $\sigma$ value of the distribution? It multiplies it by the same constant. This is not some deep or surprising fact: it’s just the Distributive Law in action.
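Here is a small sketch of that scaling rule applied to an arbitrary sample: for any data x and positive constant k, the standard deviation of k * x is exactly k times the standard deviation of x, and the variance picks up a factor of k squared. The constants 0.5 and 3.0 are just example values.

```python
import numpy as np

np.random.seed(1)
x = np.random.randn(10_000)   # any data would do; the identity below is exact

for k in (0.5, 3.0):
    y = k * x                 # multiply every value by the same positive constant
    # std ratio equals k (compress if k < 1, expand if k > 1); variance ratio equals k**2
    print(k, y.std() / x.std(), y.var() / x.var())
```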

So the constant that we are scaling our distribution by will be inversely related to the size of our inputs. If n is larger, the constant is smaller, and will shrink our standard deviation. So how does this make the weights smaller? Couldn’t we have very large weights with a small standard deviation?

The weights are smaller only in the sense that they start out smaller when we use a very small multiplicative constant. That doesn’t mean they won’t grow larger, or even very large, later if that’s what back propagation needs to do to get to a better solution.

Also notice that the constant being used here has nothing to do with the magnitude of our inputs, right? It depends only on the dimension of our inputs, i.e. the number of units n[l-1] feeding into layer l. But maybe that’s what you meant?
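To make the “dimension, not magnitude” point concrete, here is a minimal sketch of He initialization over a list of layer sizes. The helper name initialize_parameters_he, the layer_dims list [1000, 500, 10], and the zero biases are assumptions for illustration, not necessarily the course’s exact code. Note that no input data appears anywhere: the scale is computed from the layer widths alone.

```python
import numpy as np

def initialize_parameters_he(layer_dims, seed=3):
    """Hypothetical helper: He-initialize weights for layer sizes [n0, n1, ..., nL]."""
    np.random.seed(seed)
    params = {}
    for l in range(1, len(layer_dims)):
        n_prev = layer_dims[l - 1]   # only the width of the previous layer enters the scale
        params["W" + str(l)] = np.random.randn(layer_dims[l], n_prev) * np.sqrt(2 / n_prev)
        params["b" + str(l)] = np.zeros((layer_dims[l], 1))
    return params

params = initialize_parameters_he([1000, 500, 10])
for l in (1, 2):
    W = params["W" + str(l)]
    # sample std of W next to the target sqrt(2 / n_prev)
    print(W.shape, W.std(), np.sqrt(2 / W.shape[1]))
```

The wider layer (1000 inputs) gets the smaller spread, regardless of what values those inputs will eventually take.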

The other point here is that these are not just any old random numbers, right? They are random numbers that form a Gaussian (Normal) Distribution.

Oh, you’re saying that because the distribution from randn has a mean of 0, when we decrease the standard deviation, we’re pushing all the weights closer to zero?

We don’t have to get as sophisticated as talking about standard deviations. If I multiply all the values I start with by a positive number < 1, that makes them smaller in absolute value, right?
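One way to connect “smaller in absolute value” back to the standard deviation, as a sketch: for a zero-mean normal with standard deviation sigma, the average absolute value is sigma * sqrt(2/pi), about 0.8 * sigma, so shrinking sigma shrinks the typical weight magnitude proportionally. The sigma values below are arbitrary examples, and the empirical averages only approximate the formula.

```python
import numpy as np

np.random.seed(2)
z = np.random.randn(1_000_000)        # mean 0, std 1

for sigma in (1.0, 0.1, 0.01):
    w = sigma * z                     # the same draws, each multiplied by sigma
    # empirical mean |w| next to the zero-mean-normal formula sigma * sqrt(2/pi)
    print(sigma, np.mean(np.abs(w)), sigma * np.sqrt(2 / np.pi))
```

Every individual value also shrinks by exactly the factor sigma, so the largest |w| drops by that factor too.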

That makes sense… I think I understand now that when we scale the absolute value of the weights up or down, we are simultaneously making the data either more spread out or tighter. But would this be the case if the data were not normally distributed? What if the mean was 99 and all the data points were nonzero?

If the mean of your data is 99, then you’re going to need a pretty small multiplier to make the values small. If you want the mean to be, say, 0.01, then let’s see. Hmmm:

99 * x = 0.01

So the solution would be … wait for it …

x = 0.01 / 99 = 0.00010101 …

So if you multiply your whole dataset by that value, the values will be much closer together, right?
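A rough sketch of that worked example, assuming some hypothetical data centred at 99 with a standard deviation of about 1 (99 + randn). Multiplying by k = 0.01 / 99 moves the mean to about 0.01, and the spread shrinks by the same factor k, so the values do end up much closer together.

```python
import numpy as np

np.random.seed(4)
data = 99 + np.random.randn(10_000)   # hypothetical data: mean about 99, std about 1
k = 0.01 / 99                         # about 0.00010101, the multiplier worked out above
scaled = k * data

print(data.mean(), scaled.mean())     # mean goes from about 99 to about 0.01
print(data.std(), scaled.std())       # std shrinks by the same factor k
```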

Ok, so the mean is scaled by the same factor k as the std when we multiply our distribution by a constant k. When the distribution has a mean of zero, as the standard normal from randn does, the mean stays at zero because 0 * k = 0.

So overall, we just want our weights to be scaled inversely to the number of inputs to the layer. Specifically, we want the variance to be 2/n (because someone figured out that’s what works). To satisfy this, we convert the variance to a standard deviation by taking the square root. Then we use that as the constant to scale our weights by, because when we scale a standard normal distribution by that constant, the std of the new distribution becomes that constant (1 * k = k). Is this accurate?

Yes, I think that states it nicely.
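To see why this matters for vanishing/exploding gradients, here is an illustrative sketch of the forward pass only: random inputs are pushed through a deep stack of ReLU layers and we look at how big the last layer’s activations are under three choices of scale. The width 256, depth 20, batch of 200 random examples, and zero biases are all arbitrary assumptions for the demo; He’s 2/n was derived with ReLU activations in mind, which is why ReLU is used here.

```python
import numpy as np

def final_activation_rms(scale, n=256, num_layers=20, m=200, seed=0):
    """Return the RMS of the activations after num_layers ReLU layers of width n.

    Every weight matrix is np.random.randn(n, n) * scale; biases are taken to be zero.
    """
    np.random.seed(seed)
    a = np.random.randn(n, m)                # m hypothetical input examples with n features
    for _ in range(num_layers):
        W = np.random.randn(n, n) * scale    # n_prev = n for every layer here
        a = np.maximum(0, W @ a)             # linear step followed by ReLU
    return np.sqrt(np.mean(a ** 2))

n = 256
he = final_activation_rms(np.sqrt(2 / n))    # He: variance 2/n
too_big = final_activation_rms(1.0)          # plain randn, std 1
too_small = final_activation_rms(0.01)       # a fixed "small" std

print(f"He: {he:.3g}   std=1: {too_big:.3g}   std=0.01: {too_small:.3g}")
```

With variance 2/n each ReLU layer roughly preserves the scale of its input, so the He run stays near 1, while std 1 makes the activations blow up by many orders of magnitude and std 0.01 makes them collapse toward zero. The exact numbers depend on the seed and the sizes chosen.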

Ok thank you for your help!

Maybe it’s worth also just putting a little emphasis on this part of your description:

The important “takeaway” is that there is no single “silver bullet” initialization method that is guaranteed to work best in all cases. This is just one method that has been found to work well in a lot of cases. Like many of the hyperparameters (design choices one must make), the choice of initialization method requires some experimentation. It’s reasonable to start with He Initialization, but it is not guaranteed to be the best choice; the only way to find out is by experimentation.
