Inspired by @Osama_Saad_Farouk’s sharing, I would like to share my notes on the Adam optimizer. Hopefully it is a good way for some learners to approach this technique.
This note is based on Kingma D. & Ba J. (2014). Adam: A Method for Stochastic Optimization.
We should know about the Adam optimizer
We have seen Adam in Course 2 Week 1’s assignment, when we compiled our very first neural network model. Adam plays the same role as the gradient descent algorithm that we have learnt, because Adam also determines how model parameters are updated as training progresses. Indeed, Adam is so commonly used that it can be the default optimizer for your baseline, so we should know about it!
How does Adam differ from the gradient descent that we have learnt?
Throughout MLS Course 1, we have been talking about this gradient descent algorithm:
w := w - \alpha \times \frac{\partial{J}}{\partial{w}}
which is the most basic version of gradient descent. I am sure you have sensed that there is more than one way of doing gradient descent - and yes, the Adam optimizer is one of the many variants based on gradient descent, and it has the following form:
w := w - \alpha \times SNR
where SNR stands for a pseudo Signal-to-Noise Ratio, a number that is recalculated as training progresses - so yes, the SNR value keeps changing.
What is SNR, our pseudo Signal to Noise Ratio?
We know we pass training samples to the model for training, and we pass them not just once but multiple times. On top of that, in practice, we split our samples into batches and pass one batch at a time. Therefore, if we have 128,000 samples and we plan to pass all the samples to the model 100 times at a batch size of 128, we expect \frac{128000}{128}\times{100}=100,000 updates (or 100,000 gradient descent steps). Training a model with these mini-batches is called mini-batch training.
The thing is, not all 100,000 steps are equally informative (significant). We should expect the additional information (gradient, or Signal) provided by later batches to be gradually smaller than that of earlier batches, because the model has been improving and converging (aka the gradient has been descending). Therefore, we say the Signal provided by the batches decreases over time (aka training epochs), and that is the Signal term in our SNR.
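To double-check the arithmetic above, here is a minimal Python sketch. The function name `count_updates` is my own, and rounding up for a final partial batch is my assumption (it makes no difference for these numbers, since 128,000 divides evenly by 128).

```python
import math

def count_updates(num_samples: int, batch_size: int, epochs: int) -> int:
    """Number of gradient descent steps in mini-batch training."""
    # Assumption: a final, smaller batch still triggers one update, hence ceil.
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return steps_per_epoch * epochs

total_updates = count_updates(num_samples=128_000, batch_size=128, epochs=100)
print(total_updates)  # 100000
```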
As for the Noise term, it naturally refers to the noise of the signal, which is the variation of the signals. In the context of model training, Noise is the fluctuation due to the differences between the samples from batch to batch.
Knowing that Signal and Noise both naturally exist and are mixed together, ideally we expect our training signal, over the course of training, to look like:
so that we have weaker and weaker Signal over the epochs, but since Noise is intrinsic to the samples, it remains at more or less the same strength throughout. Since Signal is decreasing but Noise is not, the noise becomes more apparent at later epochs, or we say the pseudo Signal-to-Noise Ratio (SNR) is decreasing.
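The picture described above can be imitated with a toy simulation. This is only a sketch under made-up assumptions: the Signal decays by 1% per step (`DECAY`), the Noise is Gaussian with a constant standard deviation (`NOISE_STD`), and all function names are hypothetical. Real training curves will not be this tidy.

```python
import random

random.seed(0)  # fixed seed so the sketch is repeatable

NOISE_STD = 0.1  # assumed: noise strength stays constant during training
DECAY = 0.99     # assumed: the noise-free gradient shrinks by 1% per step

def true_signal(step: int) -> float:
    """Hypothetical noise-free gradient magnitude at a given step."""
    return DECAY ** step

def observed_gradient(step: int) -> float:
    """What a mini-batch would report: Signal plus batch-to-batch Noise."""
    return true_signal(step) + random.gauss(0.0, NOISE_STD)

def signal_to_noise(step: int) -> float:
    """Ratio of the decaying Signal to the constant Noise level."""
    return true_signal(step) / NOISE_STD

for step in (0, 200, 400):
    print(step, round(observed_gradient(step), 3), round(signal_to_noise(step), 3))
```

In this toy setup the ratio starts at 10 and drops below 1 by step 400, which is the "Noise becomes more apparent" stage.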
Why do we want SNR?
Now we know what the SNR is about, but why do we need it? The answer is simple: we introduce SNR as a proportionality term in the Adam version of gradient descent so that the update is stronger when the Signal is stronger, and the update shrinks when the Noise becomes dominant.
How do we measure SNR?
Let me get straight to it:
SNR = \frac{\hat{m}}{\sqrt{\hat{v}} + \epsilon}
where:
- \hat{m} is the bias-corrected exponential-moving-average of gradients, representing Signal
- \sqrt{\hat{v}} is the square root of the bias-corrected exponential-moving-average of squared gradients (the uncentered variance), representing Noise
- \epsilon is a small number (e.g. 10^{-8}) for numerical stability (to lower bound the denominator to be always \ge \epsilon)
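Here is a small Python sketch of the SNR formula on two hand-picked cases. The function name `pseudo_snr` is my own, and the input values are illustrative rather than taken from any real training run.

```python
import math

def pseudo_snr(m_hat: float, v_hat: float, eps: float = 1e-8) -> float:
    """SNR = m_hat / (sqrt(v_hat) + eps), as in the formula above."""
    return m_hat / (math.sqrt(v_hat) + eps)

# Consistent gradients: every gradient is 0.5, so the average gradient is 0.5
# and the average squared gradient is 0.25. The ratio is very close to 1.
print(pseudo_snr(m_hat=0.5, v_hat=0.25))

# Pure noise: gradients alternate between +0.5 and -0.5, so they average
# to 0 while their squares still average to 0.25.
print(pseudo_snr(m_hat=0.0, v_hat=0.25))  # 0.0
```

So when gradients agree with each other the update is about one full \alpha, and when they cancel out the update vanishes.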
Essentially, we have 4 components:
Components | Maths | Purpose |
---|---|---|
gradients | g_t = \frac{\partial{J}}{\partial{w}} | for measuring Signal |
squared gradients | g_t^2 | a pseudo way for measuring Noise |
(exponential) moving-average | m_t = \beta_1 \times m_{t-1} + (1-\beta_1) \times g_t and v_t = \beta_2 \times v_{t-1} + (1-\beta_2) \times g_t^2 | to cumulatively calculate the average of Signal (also called the first moment) and Noise (also called the second raw moment) |
bias-correction | \hat{m}_t = \frac{m_t}{1-\beta_1^t} and \hat{v}_t = \frac{v_t}{1-\beta_2^t} | We want to take the average of the Signal and Noise to be more statistically sound. In other words, we want to stabilize them with their previous measurements, and such stability can also be interpreted as momentum because new measurements can’t change the moving averages dramatically. However, we can ONLY calculate the averages cumulatively, because model training is an iterative process and we do not have all the Signal and Noise values beforehand. Therefore, at the beginning of the accumulation, when we have only a small number of Signal or Noise samples, the averages (which start from 0) are biased towards 0, so we want to correct this bias. With the form in the Maths column, the correction is only significant in the early training stage (when t is small). As t becomes large, the denominator approaches 1 and the bias-correction has practically no effect |
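
To see what the bias-correction does, here is a minimal sketch that feeds a constant gradient of 1.0 into a moving average that starts from 0. The function name `ema_with_bias_correction` and the constant input are my own choices for illustration.

```python
def ema_with_bias_correction(values, beta):
    """Return (t, raw average, bias-corrected average) for each step t."""
    avg = 0.0  # the moving average starts from zero
    history = []
    for t, value in enumerate(values, start=1):
        avg = beta * avg + (1 - beta) * value  # exponential moving average
        corrected = avg / (1 - beta ** t)      # bias-correction
        history.append((t, avg, corrected))
    return history

for t, raw, corrected in ema_with_bias_correction([1.0] * 5, beta=0.9):
    print(t, round(raw, 4), round(corrected, 4))
```

The raw average starts at about 0.1 and only slowly climbs towards 1.0, while the corrected average sits at 1.0 from the very first step.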
Putting it all together, I hope you will be able to understand the Adam algorithm written in the following way by its inventors:
Source: Kingma D. & Ba J. (2014). Adam: A Method for Stochastic Optimization.
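For readers who prefer code, below is my own minimal single-parameter sketch that follows the four components from the table above. It is not the authors' code and not how any library implements Adam. The function name `adam_minimize`, the toy cost J(w) = (w - 3)^2, and the choices \alpha = 0.1 and 500 steps are all assumptions made for illustration.

```python
import math

def adam_minimize(grad_fn, w0, alpha=0.1, beta1=0.9, beta2=0.999,
                  eps=1e-8, steps=500):
    """Run Adam updates on a single scalar parameter w."""
    w, m, v = w0, 0.0, 0.0  # both moving averages start from zero
    for t in range(1, steps + 1):
        g = grad_fn(w)                       # gradient (Signal)
        m = beta1 * m + (1 - beta1) * g      # moving average of gradients
        v = beta2 * v + (1 - beta2) * g * g  # moving average of squared gradients
        m_hat = m / (1 - beta1 ** t)         # bias-correction
        v_hat = v / (1 - beta2 ** t)
        w -= alpha * m_hat / (math.sqrt(v_hat) + eps)  # w := w - alpha * SNR
    return w

# Toy cost J(w) = (w - 3)^2, whose gradient is 2 * (w - 3); the minimum is at w = 3.
w_final = adam_minimize(lambda w: 2 * (w - 3.0), w0=0.0)
print(w_final)
```

The printed value should land close to 3. Also note that the very first step moves w by roughly \alpha no matter how large the gradient is, because the first SNR is g / |g|.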
Conclusion
Unlike the gradient descent algorithm we learnt in Course 1, which has a fixed learning rate (\alpha), Adam is an adaptive, gradient-descent-based optimization algorithm that effectively adjusts the learning rate (\alpha \times SNR) over the course of training using the pseudo Signal-to-Noise Ratio SNR. Adam is very commonly used for mini-batch training, where Noise is more significant (because of the small sample size in each batch). We want Adam to keep Noise from “shaking” our model.
Lastly, just like \alpha in the basic gradient descent algorithm, the \alpha in Adam is also a learning rate that needs to be tuned case by case. \beta_1 and \beta_2, the hyper-parameters that control the moving averages and bias corrections, default to 0.9 and 0.999 respectively, though they are also subject to tuning. The larger a \beta value, the longer-term its moving average is.
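To put a rough number on "longer-term": a common rule of thumb (an approximation, not something stated in the paper) is that an exponential moving average with parameter \beta looks back over about 1 / (1 - \beta) steps. The function name `effective_window` below is my own.

```python
def effective_window(beta: float) -> float:
    """Rule-of-thumb number of recent steps an exponential moving average covers."""
    return 1.0 / (1.0 - beta)

for name, beta in (("beta_1", 0.9), ("beta_2", 0.999)):
    print(name, beta, round(effective_window(beta)))
```

With the defaults, the Signal average covers roughly the last 10 steps, while the Noise average covers roughly the last 1000 steps.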
Cheers,
Raymond