The TA says:
“One workaround is to use the encoder hidden states for each word instead of trying to smash it all into one big vector. But this model would have flaws with memory and contexts.” Around 4:50
I hope someone could explain this in a little more detail; I do not understand why this model would have memory and context flaws. Thanks!
This is an excellent question!
It’s not obvious what Younes meant by these words, but the way I interpret them is this: if you feed every hidden state of the RNN encoder to the decoder, the gradient calculations become very problematic.
For example, when feeding only h4 to the decoder, the whole “It’s time for tea” sequence is represented with only one vector (somewhat similarly to humans - we do not store a brain “state” for every word when we translate a sentence, only the overall “blob”; the reference to humans is probably imprecise, but my intention is to illustrate the point - a single point of reference against which we generate the translation). Similarly, in traditional seq2seq we use only the last hidden state (or a sum, or some other operation - but a single representation). In this architecture, the h4 values are influenced by h1, h2, h3 (and x1, x2, x3, x4), and this “influence” is propagated according to the loss function.
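To make the “single representation” point concrete, here is a minimal NumPy sketch (not the course’s code - just an illustration with made-up token ids and dimensions) of a vanilla RNN encoder that reads the sentence and hands the decoder only the final hidden state h4:

```python
import numpy as np

# Hypothetical toy dimensions, purely for illustration.
vocab_size, embed_dim, hidden_dim = 10, 8, 16
rng = np.random.default_rng(0)

W_e = rng.normal(size=(vocab_size, embed_dim))          # embedding table
W_xh = rng.normal(size=(embed_dim, hidden_dim)) * 0.1   # input -> hidden
W_hh = rng.normal(size=(hidden_dim, hidden_dim)) * 0.1  # hidden -> hidden

def encode(token_ids):
    """Vanilla RNN encoder: returns every hidden state h1..hN."""
    h = np.zeros(hidden_dim)
    states = []
    for t in token_ids:                    # x1, x2, x3, x4 ...
        x = W_e[t]
        h = np.tanh(x @ W_xh + h @ W_hh)   # h_t depends on h_{t-1}
        states.append(h)
    return states

states = encode([1, 4, 2, 7])              # "It's time for tea" (toy ids)
context = states[-1]                       # traditional seq2seq: only h4 goes to the decoder
print(context.shape)                       # (16,) - one vector for the whole sentence
```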
That is the traditional architecture. Now, when you want to account for every hidden state at every point in time, you add a lot of complexity. The diagram does not show how the decoder would handle every encoder hidden state, but we could assume that only h3 would be responsible for generating the decoder’s “du”. In this setup, instead of accumulating gradients once for the whole translation, you have to keep a snapshot of every hidden state at the time each decoder output is generated and recalculate gradients from that point - this would be problematic, especially for long sequences, since the resources needed grow roughly quadratically rather than linearly with sequence length.
Or in concrete terms: when generating the translation, you previously only had to backpropagate through h4, h3, h2, h1 once for the whole translation (1 loss = 1 pass). But now you would have to backpropagate through h4, h3, h2, h1 for the “the” word, through h3, h2, h1 for the “du” word (no h4 in that loss calculation), etc. The loss calculations become complex - you have to calculate a loss for every generated token separately.
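A rough way to see the blow-up: with one shared loss you backpropagate through the chain of N hidden states once, while giving every decoder step its own loss over its own prefix of states costs about N + (N-1) + … + 1 = N(N+1)/2 backward steps. A throwaway sketch of that count (purely illustrative arithmetic, not a real training loop):

```python
def backprop_steps(n):
    # One shared loss: walk back through h_n ... h_1 a single time.
    shared = n
    # One loss per generated token over its own prefix: n + (n-1) + ... + 1
    per_token = n * (n + 1) // 2
    return shared, per_token

for n in (4, 50, 500):
    print(n, backprop_steps(n))   # e.g. 4 -> (4, 10), 500 -> (500, 125250)
```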
As for the “contexts”: this setup might be OK if the tokens match up (like in the example above, where every translated word matches its counterpart), but often the translated tokens do not align, and sometimes two words translate into a single word, etc. So the decoder RNN’s job (for example, with its own h16 and the encoder’s h16 as input) is still challenging - to somehow pick out that the meaning of the word to be translated was actually in input x7 of the encoder. In other words, the problem would remain - preserving the representation or “meaning” of the word/words to be translated.
Of course, there could be other setups for handling every encoder hidden state, but the problems would remain - accounting for every state separately is very compute intensive, and figuring out the right “context” for every word of the translation would also require a huge dataset (a couple of examples would not suffice). Whereas if your model tries to compress the whole meaning as accurately as possible, you can get away with less redundant compute and a smaller dataset (since language is already a compressed representation of “reality”).
Cheers
Hi Arvy,
Thank you for your explanation. My rough understanding, based on what you are saying, is that keeping the states in this form would complicate the calculation on the order of N factorial for an N-long sequence:
N terms for the last word (h1, …, hN) x (N-1) for the previous word x (N-2), etc.
But in the case of attention, how is that simpler? It seems even more complicated… unless some alphas are explicitly 0 and those h-terms can be dropped completely.
For the context part I kind of understand your point, although again I do not completely see how attention fundamentally helps. One could just say that adding more (attention) variables alpha improves the context representation - of course, more parameters, better fit.
DS
What do you mean by “alpha” in your post?
In any case, in some sense (conceptually), it is not simpler for attention. The number of attention scores over a sequence rises quadratically (if the context length is 10 tokens, you have a 10x10 attention matrix; if 100, then 100x100; this is an oversimplification for illustration). Transformers’ context length is an important limiting factor computation- and memory-wise (one of the reasons why only big companies can train “big” LLMs).
But attention in transformers offers an advantage - matrix multiplications, which can be computed in parallel. In other words, instead of processing the context one token at a time, you can dot-multiply the whole context with itself (for self-attention), or the whole German sequence with the whole English sequence (for cross-attention). You cannot do that in an RNN (well… except for the very recent Mamba architecture, explaining which would require a course in itself), since when training an RNN to predict the next token you have to compute all the previous tokens sequentially, while the transformer architecture allows you to train on predicting any token in a sequence (given the context).
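As an illustration of both points - the N x N score matrix (the quadratic cost) and the fact that the whole sequence is handled by a few matrix multiplications instead of a token-by-token loop - here is a minimal NumPy sketch of scaled dot-product self-attention (names and dimensions are made up; no masking, batching, or multiple heads):

```python
import numpy as np

def self_attention(X, W_q, W_k, W_v):
    """Scaled dot-product self-attention over a whole sequence at once."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])                    # (N, N) - quadratic in sequence length
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)    # softmax over each row
    return weights @ V                                         # every position updated in parallel

rng = np.random.default_rng(0)
N, d = 10, 16                                # 10 tokens -> a 10x10 attention matrix
X = rng.normal(size=(N, d))                  # token embeddings (plus positions, in a real model)
W_q = rng.normal(size=(d, d))
W_k = rng.normal(size=(d, d))
W_v = rng.normal(size=(d, d))
out = self_attention(X, W_q, W_k, W_v)
print(out.shape)                             # (10, 16)
```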
Cheers
Thanks, Arvy. I think the concepts are slowly sinking in. Attention adds somewhat to the complexity and adds more parameters, and in a nice, natural way it allows the model to improve (since it allows emphasizing particular words in a sequence when predicting the translation), while being structured in such a way that the implementation is fully vectorized and efficient.
But then can I say that maybe the computation is faster, but the memory requirement is as high as, or higher than, the one shown in the “Use all the encoder hidden states?” slide?
Thank you
DS
Yes @Dennis_Sinitsky I think you understand the concepts correctly.
- Prior to the invention of “attention”, the decoder used only the last state of the encoder at every step as the compressed representation of the sequence (which is OK for short sequences, but a problem for long ones).
- “Attention” was invented to let the decoder receive a different compressed representation of the sequence at each step - it adds complexity, but it also adds more “channels” for information to propagate. (In your diagram, this would be a box that receives the 4 arrows from h1, h2, h3, h4, scores them against the tokens the decoder has already translated (say h1, h2), and produces a single representation, different from the previous step’s, for translating token number 3 - see the rough sketch after this list.)
- Later, transformers got rid of RNNs altogether by introducing positional encoding, relying only on the embeddings and scoring them with attention.
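To make that “box” concrete, here is a minimal NumPy sketch of one attention step over the encoder states (dot-product scoring is used here purely for simplicity - the original Bahdanau attention scores with a small feed-forward network; all names and dimensions are made up for illustration):

```python
import numpy as np

def attention_context(encoder_states, decoder_state):
    """Weighted sum of encoder states, re-weighted at every decoder step."""
    scores = encoder_states @ decoder_state   # one score per h1..h4
    weights = np.exp(scores - scores.max())
    weights = weights / weights.sum()         # the "alphas" (softmax over the scores)
    return weights @ encoder_states           # a different context vector each step

rng = np.random.default_rng(0)
H = rng.normal(size=(4, 16))                  # h1..h4 from the encoder
s = rng.normal(size=16)                       # decoder state after translating token 2
context_for_step_3 = attention_context(H, s)
print(context_for_step_3.shape)               # (16,)
```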
Regarding the memory - it very much depends on the architecture. If you just add the attention layer and keep everything else identical - yes, that would require more memory. But attention allows for a more efficient flow of information, so you might get away with smaller RNNs (which, in the LSTM case for example, are compute and memory hungry). So in the end it depends.
And as I said, the diagram above does not show what happens with every state - it could be the same memory requirement (since attention takes in those arrows too), it could be more (if there is a more complex use of each state than attention), or it could be less (if someone comes up with a more efficient use of each state).
Cheers
Thank you, Arvy. Now I am doing the Transformer week; the concepts are sinking in even more.
As I understand it, modern AI is all about transformers, whether for vision (with CNNs) or for language (with positional encoding). RNNs, including LSTMs, are becoming a thing of the past. Is that so?
Yes, pretty much, you could say that. But there are new developments every day.
The Mamba architecture that I mentioned (which is not technically an RNN but is related to one - see Appendix B.3 of the paper I linked) seems promising.
If you’re interested, you can also learn more about it in this video in broad terms or in this video in more concrete terms after finishing the course.
There are other promising techniques which pop up now and then. Time will tell if they dethrone the transformers.
Cheers
Thanks, Arvy! So much information to pack into my small brain!