Why is there a residual connection around each attention layer, followed by a layer normalization step, in the decoder network?
The stated answer is: to speed up training and significantly reduce the overall processing time. But how, and why?
Can you share a screenshot of the residual connection you are referring to in the decoder network?
The significance comes more with the Transformer architecture, where it acts as a kind of reminder to the network of its original state.
@Anand_Kumar3 Hm, that's a good question.
I'm not sure I can answer you 100% correctly on this, but I can give you the insight as far as I see it.
I'm not sure if you took DLS? As best as I remember, we were first introduced to residual (aka "skip") connections in the context of very deep networks (100 layers, etc. -- I believe this was in the convolutions course).
In that context, without residual connections, these very deep nets ran the risk of "exploding gradients". These residual connections, by passing on earlier activations, have the effect of "tamping down" your gradients, returning them to a reasonable order of magnitude.
In much the same way, normalization does the same thing -- it brings the gradients "back into the fold", so to speak, so at least in absolute (though not relative) terms they end up closer to one another.
Why might this result in a speed-up? Well, if you think about it: though we are not working with a plain neural net at these early stages in the decoder, we are still working with gradients, trying to "step through", bit by bit, to a final set of weights and a solution.
But if we don't impose any measures to keep our gradients from getting wildly out of control, we'll require more and more training steps to bring them back in line with the desired weights for the final model.
By adding the skip and the norm, we keep them in check (avoid having them go way out of distribution too soon), which requires less retraining and thus faster overall training of the model.
Like I said, I'm not quite 100% sure this is correct... But that is how I like to think about it.
Hope this at least helps a little.
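For concreteness, here is a minimal NumPy sketch of the "Add & Norm" step around a sub-layer. The function names and the tensor shapes are just illustrative, not taken from the course code:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each position's feature vector to zero mean, unit variance.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def add_and_norm(x, sublayer_out):
    # Residual ("add") connection followed by layer normalization,
    # as wrapped around each attention / feed-forward sub-layer.
    return layer_norm(x + sublayer_out)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 8))       # (batch, seq_len, d_model)
attn = rng.normal(size=(2, 5, 8))    # stand-in for an attention output
y = add_and_norm(x, attn)
print(y.shape)                       # (2, 5, 8)
```

Note the output keeps the same shape as the input, which is what lets these blocks be stacked many times.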
He is referring to these @Deepti_Prasad:
This is the Transformer architecture; he mentioned the decoder, so I wanted him to confirm that.
@Deepti_Prasad well, this is the decoder portion of the Transformer model. Since his post is under NLP with Attention Models, I at least assumed that was the decoder he was talking about...
I like to confirm my doubt before responding to a learner in detail.
Just note the part I circled in red.
I know the Transformer architecture; I just want the learner to confirm whether his doubt is about the same thing. I don't like to hurry in answering when I have a doubt.
Sorry @Deepti_Prasad for the delayed response, and thanks @Nevermnd for helping here. Yes, those are the same residual connections pointed out by @Nevermnd in the red circle.
Thanks for the detailed explanation, but I'm still wondering how this helps keep the exploding gradient in check, given that we're normalizing anyway. And won't adding the previous state just make the result more similar to the initial state?
To me it looked like you changed from 0 to 50, then added the initial state to go back to somewhere around 27, then normalized again and finally used an even smaller number -- it looks like it tampers with efficiency? Please share your views here; I might be wrong too.
Also, I wanted to check: by input embedding and output embedding, do we mean the embedding of the source language and the right-shifted embedding of the target language?
A residual connection is a kind of reminder, in a deep neural network, of the initial state as well as the altered state, and it helps keep gradients from vanishing.
For example, suppose the input "I am hungry" is given as a sequence of tokens, and after the attention block, block 1 produces a different output than expected. As this output passes into block 2 (multi-head attention), the residual connection lets that block see both the original input "I am hungry" and the output from block 1.
The idea of including residual connections comes more from giving the network the independence to understand (or decode) any part of the sequence, while also keeping the network in check as to whether it is actually finding what it is looking for.
Regards
DP
@Anand_Kumar3 well, here's where we enter the "I'm not entirely sure" part:
I mean, when I learned about "skip" connections from Prof. Ng in DLS, it was in the context of very deep neural nets (and thus "exploding gradients"). It was presumed that normalization was still being done, only this turned out to be "not enough" -- thus your skip connection.
Also, personally, I'm not sure I would say a skip connection applies the "initial" state of the network -- simply an "earlier" one (to me, "initial" suggests something coming from the very beginning). A skip connection could really be in the middle of the network.
However, the thing is, multi-head attention is NOT A NEURAL NETWORK (though I am assuming a similar concept/principle is at work). I mean, we are still trying to resolve our Q, K, V matrices. I'm not sure "exploding gradients" is exactly the issue being resolved here by skip connections. In my mind it is something more like ensuring the attention doesn't start shifting too far away from its target, or... almost keeping Q, K, V from "separating" too much.
I also think the norm in the "add and norm" step is probably partly required because we are bringing in this earlier activation.
Figure, we're basically trying to "optimize" (? sorta? maybe? kinda?) to find the ideal set of weights for Q, K, V to solve our problem at inference. If we start to "drift away" from whatever that "optimal" set of weights would be, we have to train more and more, again and again, to pull those weights back in. Keeping things in range thus speeds up training and reduces processing time (or at least that is my thinking).
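To put toy numbers on the earlier "0 to 50 and back" worry: the add-then-normalize step rescales the sum, but it does not erase what the sub-layer computed. The numbers below are made up, purely for illustration:

```python
import numpy as np

def layer_norm(v, eps=1e-5):
    # Rescale a vector to zero mean and (roughly) unit variance.
    return (v - v.mean()) / np.sqrt(v.var() + eps)

x   = np.array([1.0, 2.0, 3.0, 4.0])      # toy earlier activation
f_x = np.array([50.0, -10.0, 20.0, 5.0])  # toy sub-layer output, large scale

s = x + f_x        # residual add: new information on top of the old
y = layer_norm(s)  # tamed magnitude, same relative pattern

print(s)           # [51. -8. 23.  9.]
print(y)
```

The ordering of the values in the sum survives normalization untouched; only the wild magnitude gets pulled back in, which is the "keeping gradients in check" intuition above.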
As to the I/O embeddings: yes, at least in the classical interpretation, you are correct. The right shift of your target language is your "teacher forcing" -- basically, by moving everything in the embedding one step, you force the model to predict the "next word".
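A tiny sketch of that right shift on the token ids (the `SOS`/`EOS` ids here are made-up placeholders, not from any real vocabulary):

```python
SOS, EOS = 1, 2                       # illustrative special-token ids
target = [7, 8, 9, EOS]               # target-language token ids

decoder_input  = [SOS] + target[:-1]  # shifted right: [1, 7, 8, 9]
decoder_labels = target               # at each step, predict the next token

for inp, lbl in zip(decoder_input, decoder_labels):
    print(inp, "->", lbl)
```

So at every position the decoder sees the ground-truth previous token and is trained to emit the one after it.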
*edit: I checked another source for reference and may have spoken incorrectly -- skip connections solve the vanishing gradient problem, not the exploding one.
Skip connections and residual connections are not entirely the same thing.
You really don't need to send a DM; we can have a healthy discussion here, so the learner can also learn from it.
Both skip and residual connections enable gradients to flow better, but skip connections directly merge features from different layers, while residual connections add the original input to the transformed output, which improves gradient propagation and training.
If you look at the computation, skip connections and residual connections are not the same.
A residual connection does element-wise addition, whereas a (concatenating) skip connection stacks features together.
Residual connections learn residual functions based on layer inputs, while skip connections skip some layers and feed the output of one layer directly to a later one.
The only similarity between the two is that both allow the input to bypass layers and contribute directly to the output.
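A quick NumPy illustration of that computational difference (the shapes are arbitrary, chosen just for demonstration):

```python
import numpy as np

x   = np.ones((4, 8))          # features from an earlier layer
f_x = np.full((4, 8), 2.0)     # transformed features

# Residual connection: element-wise addition, shape unchanged.
residual = x + f_x
print(residual.shape)          # (4, 8)

# Concatenating skip connection (e.g. U-Net style): features are
# stacked along the feature axis, so the width doubles.
skip = np.concatenate([x, f_x], axis=-1)
print(skip.shape)              # (4, 16)
```

The Transformer's "Add & Norm" uses the additive form, which is why the model dimension stays constant through every block.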
Regards
DP