Backpropagation Through Time and Vanishing Gradient (RNN)

Hi,

My question is a little bit complicated, so please bear with me for a moment. Suppose that we have an RNN that runs over an input sequence of three time steps (many-to-many).

The formula for the gradient \partial L / \partial W_{s} is:

\dfrac{\partial L}{\partial W_{s}} = \dfrac{\partial L^{<3>}}{\partial y^{<3>}} \dfrac{\partial y^{<3>}}{\partial s^{<3>}} \dfrac{\partial s^{<3>}}{\partial W_s} + \dfrac{\partial L^{<3>}}{\partial y^{<3>}} \dfrac{\partial y^{<3>}}{\partial s^{<3>}} \dfrac{\partial s^{<3>}}{\partial s^{<2>}} \dfrac{\partial s^{<2>}}{\partial W_s} + \dfrac{\partial L^{<3>}}{\partial y^{<3>}} \dfrac{\partial y^{<3>}}{\partial s^{<3>}} \dfrac{\partial s^{<3>}}{\partial s^{<2>}} \dfrac{\partial s^{<2>}}{\partial s^{<1>}} \dfrac{\partial s^{<1>}}{\partial W_s} \\ + \dfrac{\partial L^{<2>}}{\partial y^{<2>}} \dfrac{\partial y^{<2>}}{\partial s^{<2>}} \dfrac{\partial s^{<2>}}{\partial W_s} + \dfrac{\partial L^{<2>}}{\partial y^{<2>}} \dfrac{\partial y^{<2>}}{\partial s^{<2>}} \dfrac{\partial s^{<2>}}{\partial s^{<1>}} \dfrac{\partial s^{<1>}}{\partial W_s} \\ + \dfrac{\partial L^{<1>}}{\partial y^{<1>}} \dfrac{\partial y^{<1>}}{\partial s^{<1>}} \dfrac{\partial s^{<1>}}{\partial W_s}

A similar formula applies when calculating \partial L / \partial W_x:

\dfrac{\partial L}{\partial W_{x}} = \dfrac{\partial L^{<3>}}{\partial y^{<3>}} \dfrac{\partial y^{<3>}}{\partial s^{<3>}} \dfrac{\partial s^{<3>}}{\partial W_x} + \dfrac{\partial L^{<3>}}{\partial y^{<3>}} \dfrac{\partial y^{<3>}}{\partial s^{<3>}} \dfrac{\partial s^{<3>}}{\partial s^{<2>}} \dfrac{\partial s^{<2>}}{\partial W_x} + \dfrac{\partial L^{<3>}}{\partial y^{<3>}} \dfrac{\partial y^{<3>}}{\partial s^{<3>}} \dfrac{\partial s^{<3>}}{\partial s^{<2>}} \dfrac{\partial s^{<2>}}{\partial s^{<1>}} \dfrac{\partial s^{<1>}}{\partial W_x} \\ + \dfrac{\partial L^{<2>}}{\partial y^{<2>}} \dfrac{\partial y^{<2>}}{\partial s^{<2>}} \dfrac{\partial s^{<2>}}{\partial W_x} + \dfrac{\partial L^{<2>}}{\partial y^{<2>}} \dfrac{\partial y^{<2>}}{\partial s^{<2>}} \dfrac{\partial s^{<2>}}{\partial s^{<1>}} \dfrac{\partial s^{<1>}}{\partial W_x} \\ + \dfrac{\partial L^{<1>}}{\partial y^{<1>}} \dfrac{\partial y^{<1>}}{\partial s^{<1>}} \dfrac{\partial s^{<1>}}{\partial W_x}
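
To make these sums concrete, here is a small numerical sketch (the tanh hidden activation, the linear readout y^{<t>} = W_y s^{<t>}, and the squared-error loss are my assumptions for illustration). Backpropagation through time accumulates exactly the chained terms written above, and a finite-difference check confirms one entry of \partial L / \partial W_s:

```python
# Sketch: 3-step tanh RNN, assuming s<t> = tanh(Ws s<t-1> + Wx x<t>),
# y<t> = Wy s<t>, and loss L = sum_t 0.5 * ||y<t> - target<t>||^2.
import numpy as np

rng = np.random.default_rng(0)
T, n_x, n_s, n_y = 3, 2, 4, 2
Ws = rng.normal(scale=0.5, size=(n_s, n_s))
Wx = rng.normal(scale=0.5, size=(n_s, n_x))
Wy = rng.normal(scale=0.5, size=(n_y, n_s))
xs = [rng.normal(size=n_x) for _ in range(T)]
targets = [rng.normal(size=n_y) for _ in range(T)]

def forward(Ws, Wx):
    s, states, loss = np.zeros(n_s), [np.zeros(n_s)], 0.0
    for t in range(T):
        s = np.tanh(Ws @ s + Wx @ xs[t])
        states.append(s)
        loss += 0.5 * np.sum((Wy @ s - targets[t]) ** 2)
    return loss, states

# Backprop through time: dL/ds<t> gets a contribution from the loss at step t
# plus the gradient flowing back from step t+1; dWs and dWx sum the per-step
# contributions, which is the same sum of chained terms as in the formulas above.
loss, states = forward(Ws, Wx)
dWs, dWx = np.zeros_like(Ws), np.zeros_like(Wx)
da_next = np.zeros(n_s)                  # dL/da<t+1>, i.e. dL/ds<t+1> * (1 - s<t+1>^2)
for t in reversed(range(T)):
    s, s_prev = states[t + 1], states[t]
    dL_ds = Wy.T @ (Wy @ s - targets[t]) + Ws.T @ da_next
    da = dL_ds * (1.0 - s ** 2)          # back through tanh
    dWs += np.outer(da, s_prev)
    dWx += np.outer(da, xs[t])
    da_next = da

# Finite-difference check on one entry of Ws.
eps = 1e-6
Ws_p = Ws.copy(); Ws_p[0, 0] += eps
Ws_m = Ws.copy(); Ws_m[0, 0] -= eps
numeric = (forward(Ws_p, Wx)[0] - forward(Ws_m, Wx)[0]) / (2 * eps)
print(dWs[0, 0], numeric)                # the two numbers should match closely
```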

When the vanishing gradient problem is taught, it is said that the effect of the gradient becomes negligible for earlier time steps. That sounds reasonable when we are computing the third term in the first line of each equation above. To be clear, I will write those terms again below.

\dfrac{\partial L^{<3>}}{\partial y^{<3>}} \dfrac{\partial y^{<3>}}{\partial s^{<3>}} \dfrac{\partial s^{<3>}}{\partial s^{<2>}} \dfrac{\partial s^{<2>}}{\partial s^{<1>}} \dfrac{\partial s^{<1>}}{\partial W_s} \\ \dfrac{\partial L^{<3>}}{\partial y^{<3>}} \dfrac{\partial y^{<3>}}{\partial s^{<3>}} \dfrac{\partial s^{<3>}}{\partial s^{<2>}} \dfrac{\partial s^{<2>}}{\partial s^{<1>}} \dfrac{\partial s^{<1>}}{\partial W_x}
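
For reference, if the hidden state is computed as s^{<t>} = \tanh(W_s s^{<t-1>} + W_x x^{<t>}) (the tanh activation is my assumption; it is the usual vanilla-RNN form), then the step-to-step Jacobian appearing in these chains is

\dfrac{\partial s^{<t+1>}}{\partial s^{<t>}} = \mathrm{diag}\left(1 - \left(s^{<t+1>}\right)^{2}\right) W_s

so every additional factor in the chain multiplies in another copy of W_s scaled by a tanh derivative that is at most one. That is why the long-chained terms above can shrink geometrically with the chain length.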

1- Taking into consideration that these gradients comprise not only long-chained derivatives but also one-step and two-step chained derivatives, how does the vanishing gradient affect the total gradients?
2- Even though the long-chained terms are close to zero, we still have the one-step and two-step chained derivatives. Those should not vanish, right? Since we take the sum of all these terms, how does the overall gradient vanish?
3- What about a many-to-one structure, or a different RNN structure?
4- What role does parameter sharing play in this particular problem?

Thanks!

PS: I got the image from geeksforgeeks. The model prediction for the second time step should be Y_2, not Y_3. Also, I chose an RNN model with three time steps for simplicity; my question can be extended to more time steps.

Hi @sahina,

You can see the paper on LSTM {LSTM paper}. It discusses these problems and the solutions to them.
When you have a longer chain, each of the terms in the chain rule multiplies in another W. So when W has an eigenvalue greater than 1, it can lead to exploding gradients. Adding terms would increase it even more!
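
Here is a quick numerical sketch of that point (an arbitrary matrix rather than trained weights, and ignoring the activation-derivative factors): the norm of the repeated product shrinks or blows up depending on the largest eigenvalue magnitude of W.

```python
# Repeatedly multiplying by the same W: the product's norm decays when the
# largest |eigenvalue| is below 1 and explodes when it is above 1.
import numpy as np

rng = np.random.default_rng(1)
W = rng.normal(size=(4, 4))
for radius in (0.5, 1.5):
    # rescale W so that its largest |eigenvalue| equals `radius`
    Wr = W * (radius / np.max(np.abs(np.linalg.eigvals(W))))
    prod = np.eye(4)
    for k in range(1, 21):
        prod = prod @ Wr
        if k in (1, 5, 10, 20):
            print(f"radius={radius}, chain length {k:2d}: ||product|| = {np.linalg.norm(prod):.3e}")
```
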
Hope this helps.

Hi @thearkamitra,

Thank you so much for your response. So let me rephrase it. The vanishing gradient issue lies in the \dfrac{\partial s^{<t+1>}}{\partial s^{<t>}} terms; I guess that is where the W term enters these chained gradients. If so, the vanishing gradient should still take effect when the chain is longer. However, the overall derivative consists not only of longer chains but also of shorter chains. Since we share the parameters across the whole sequence, don’t you think the contribution from the shorter chains can compensate and keep the gradient moving in the desired direction?
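
To make my point concrete, here is a minimal scalar sketch (the tanh activation, linear readout, and a squared-error loss at the last step only are my assumptions). It prints every term of \partial L^{<T>} / \partial W_s from the sum above, ordered by how many \partial s^{<j>} / \partial s^{<j-1>} factors the term contains; the long-chained terms decay geometrically while the short-chained ones keep a reasonable size.

```python
# Scalar toy RNN: s<t> = tanh(ws*s<t-1> + wx*x<t>), y<t> = wy*s<t>,
# loss only at the last step. Print each term of dL<T>/dws by chain length.
import numpy as np

rng = np.random.default_rng(2)
T, ws, wx, wy = 12, 0.4, 1.0, 1.0
xs = rng.normal(size=T)
target = 0.8                                   # arbitrary target for y<T>

s_vals = [0.5]                                 # arbitrary nonzero initial state
for t in range(T):
    s_vals.append(np.tanh(ws * s_vals[-1] + wx * xs[t]))

dLT_dsT = (wy * s_vals[T] - target) * wy       # dL<T>/dy<T> * dy<T>/ds<T>
chain = 1.0                                    # running product of ds<j>/ds<j-1> factors
for k in range(T, 0, -1):                      # the term that reaches back to step k
    local = (1 - s_vals[k] ** 2) * s_vals[k - 1]   # the "local" ds<k>/dws factor
    print(f"chain length {T - k:2d}: term = {dLT_dsT * chain * local:+.3e}")
    chain *= (1 - s_vals[k] ** 2) * ws             # multiply in ds<k>/ds<k-1>
```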

Hi @sahina,

I am not sure of the answer for vanishing gradients, but for exploding gradients, I hope you realize why this would be an issue!

Hi @sahina,

Think of an RNN for language modeling and Andrew’s example “The cat, which ate a lot of chocolate cookies, were full and left the dinner untouched.”, i.e. the current RNN produces \hat{y}^{<10>} that gives a higher probability to “were” than to “was”. Thus \mathcal{L}^{<10>}(\hat{y}^{<10>}, “was”) produces a high error, but all \mathcal{L}^{<t>}(...), t<10, are low. So \mathcal{L}^{<10>}(\hat{y}^{<10>}, “was”) must be the cause for the weights being updated in such a way that the RNN stores “singular” as information in the hidden states a^{<2>}, ..., a^{<10>} when it sees “cat”.

What I do not understand: Since each time step uses the same weights, why must backprop propagate the error \mathcal{L}^{<10>}(\hat{y}^{<10>}, “was”) back to t=2 for the relevant weight update? Why not update the weights directly for t=10?

@paulinpaloalto do you happen to have a good explanation? I would really like to understand why the gradient needs to be propagated back from t=10 (with the word “was” vs “were”) to t=2 (with the word “cat”) for the weights to be updated in such a way that the RNN is able to memorize singular vs plural in any time step.

This is just my interpretation, which is probably worth exactly what you paid for it, but I’d say that the point is not that it needs to be “propagated” from t = 10 back to t = 2, it’s that gradients get generated by the errors at every time step, right? And then we apply them (as you say) to the one shared set of weights. Of course as we discussed very recently on this other thread, the manner in which we are actually applying the gradients is arguably a bit sloppy. But it seems to work. “Close enough for jazz” apparently … :nerd_face:

The point about state being coordinated between two disparate timesteps is what LSTM is specifically designed to facilitate. Of course the weights for the various LSTM “gates” are included in what we are updating.

Thanks @paulinpaloalto, I have exactly that mental model (gradients get generated by the errors at every time step), which led me to wonder why long-term dependencies are a problem in RNNs.

If I understood correctly, long-term dependencies are just the “state being coordinated between two disparate timesteps” that you mention.

But since weights from t=2 and t=10 are shared, can the weights not be updated at t=10 for the RNN to learn to compute and memorize the feature “singular vs plural”?

In a fully connected NN, I imagine backprop needs to propagate the error back to the layer where the feature “singular vs plural” is computed, but in an RNN the weights are shared and the feature “singular vs plural” needs to be computable for each x^{<t>}. So for learning the feature “singular vs plural” in the RNN, why does the error need to propagate back to t=2, to see x^{<2>}=“cat”, where the prediction was already correct anyway?

I think I am having difficulties imagining how an RNN is extracting more and more abstract features from the input (as a deep neural network) while simultaneously reading more and more input from the input sequence – all with the same shared weights :-0

Yes, I think at some level there is that intuitive difficulty with all forms of Neural Networks: at a very fundamental level it all just seems like magic that an algorithm can “learn” patterns as complex as it apparently can. :grin:

The thing I lean on particularly in this instance of RNNs figuring out patterns that span many time steps is that the state they have to work with is pretty complex. And of course the number of total elements in that state and how it is divided among the base RNN state and the various LSTM gates are all choices that we have to make as the system designers, and (one assumes) it must be possible to make those choices incorrectly. Of course it contributes to that sense of magic that we don’t even have to tell it in any way how to apportion that state to different purposes. It just figures it out.

I don’t know if there is any work for RNNs that is analogous to the interesting work Prof Ng describes in Course 4 Week 4 in the lecture titled “What Are Deep ConvNets Learning?” I found that pretty illuminating in the case of ConvNets: even if it’s not that satisfying in terms of showing how those patterns are actually learned, you can see that it really did learn very specific things. Did you take C4 and remember that lecture? If not, it’s really worth a look. It also exists as a YouTube video. Let me know if you need a link.

Thanks, @paulinpaloalto, keeping in mind that the state can be very complex does help me to get some intuition. Furthermore, Hochreiter and Schmidhuber’s original LSTM paper from 1997 talks about weight conflicts in RNNs, which I think is related to my question about the simultaneous tasks that RNNs have to perform. This gives me further intuition about how gating reduces the problem with long-term dependencies.

Thanks also for the pointer – I haven’t done course 4 yet, but have put it on top of my pile of shame now :wink:

If you’d like to get a sense of what is covered in that lecture without taking all of Course 4, Prof Andrew has made it available on YouTube.