Not understanding how Reversible Layers in the Reformer save memory

I follow the equations for Reversible Layers in the Reformer model, and I kind of get what the goal is, but I’m not following how this saves memory.

For example, from the final programming assignment, it states:

As you can see, it requires that x and ya be saved so it can be used during backpropagation. We want to avoid this to conserve memory and this is where reversible residual connections come in. They are shown in the middle and rightmost diagrams above. The key idea is that we will start with two copies of the input to the model and at each layer we will only update one of them. The activations that we don’t update are the ones that will be used to compute the residuals.

< Bunch of equations omitted >

With this configuration, we’re now able to run the network fully in reverse. You’ll notice that during the backward pass, x2 and x1 can be recomputed based solely on the values of y2 and y1. No need to save it during the forward pass.

So it seems to me that, by trying to avoid saving variables x and ya, we make a copy of variable x, splitting it into x1 and x2, and have to save results y2 and y1 instead. Didn’t we just cancel out the memory savings by having to save two new variables?

In other words, since these variables are all of the same dimensionality, isn’t it the same amount of memory to save double the outputs instead? It seems we’ve just replaced x / ya with y1 / y2… so what am I missing? Thanks.

If you save the variables, there is no need to calculate them again, but keeping them around takes up a lot of memory during the computation!

Hi @Joel_Wigton

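That is a good question, and the answer lies in the FeedForward layer: its dimensionality is several times higher than d_model (d_model = 1024, d_ff = 4096), so the intermediate activations inside each layer are much larger than y1 and y2. With reversible layers, only the y1 and y2 of the last layer have to be kept, and the inputs and intermediates of every earlier layer are recomputed during the backward pass instead of being stored. In other words, we save memory at the expense of compute (about 1.5-2× more expensive on compute, depending on the implementation).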
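
To make the trade-off concrete, here is a rough back-of-the-envelope count of how many activation values stay in memory in each case. It is only a sketch: d_model and d_ff are the values quoted above, while seq_len, n_layers and the exact set of tensors counted per layer are assumptions chosen for illustration, not numbers from the assignment.

```python
# Rough count of activation values kept around for backpropagation.
# d_model and d_ff come from the post above; seq_len and n_layers are
# hypothetical values picked only for illustration.


def standard_stored(seq_len, d_model, d_ff, n_layers):
    # Assumption: every layer keeps its input x, the attention output ya,
    # and the d_ff-wide hidden activation of the feed-forward sublayer.
    per_layer = seq_len * (2 * d_model + d_ff)
    return n_layers * per_layer


def reversible_stored(seq_len, d_model, d_ff, n_layers):
    # Only the last layer's outputs (y1, y2) persist after the forward pass.
    kept = seq_len * 2 * d_model
    # During the backward pass one layer at a time is recomputed, so its
    # intermediates are alive only temporarily.
    transient = seq_len * (2 * d_model + d_ff)
    # n_layers is deliberately unused: the count does not grow with depth.
    return kept + transient


seq_len, d_model, d_ff, n_layers = 4096, 1024, 4096, 12
std = standard_stored(seq_len, d_model, d_ff, n_layers)
rev = reversible_stored(seq_len, d_model, d_ff, n_layers)
print(f"standard:   {std:,} values")
print(f"reversible: {rev:,} values")
print(f"ratio:      {std / rev:.1f}x")
```

The point is that the standard count grows with the number of layers, while the reversible count does not, so the gap widens as the model gets deeper.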

You should read the Reformer paper; it is very approachable and explains the ideas (both the parameter reuse and the chunking for memory savings) very well.
Also, you could check out The Reversible Residual Network paper for more details.
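
If it helps, here is a minimal sketch of the reversibility itself, assuming the usual formulation from those two papers: y1 = x1 + F(x2), y2 = x2 + G(y1). The functions f and g below are hypothetical toy stand-ins for the attention and feed-forward sublayers, and the lists of floats stand in for activation tensors; this is not the assignment's code.

```python
import math


def f(v):
    # Hypothetical stand-in for the attention sublayer.
    return [math.tanh(a) for a in v]


def g(v):
    # Hypothetical stand-in for the feed-forward sublayer.
    return [0.5 * math.sin(a) for a in v]


def add(a, b):
    return [p + q for p, q in zip(a, b)]


def sub(a, b):
    return [p - q for p, q in zip(a, b)]


def forward(x1, x2):
    y1 = add(x1, f(x2))  # y1 = x1 + F(x2)
    y2 = add(x2, g(y1))  # y2 = x2 + G(y1)
    return y1, y2


def reverse(y1, y2):
    x2 = sub(y2, g(y1))  # undo the second update first
    x1 = sub(y1, f(x2))  # then the first one, using the recovered x2
    return x1, x2


x = [0.1, -0.4, 0.7]
h1, h2 = list(x), list(x)  # start with two copies of the input
n_layers = 6

# Forward pass: each layer's inputs are overwritten, nothing is stored.
for _ in range(n_layers):
    h1, h2 = forward(h1, h2)

# Backward direction: walk back from the final outputs only.
r1, r2 = h1, h2
for _ in range(n_layers):
    r1, r2 = reverse(r1, r2)

max_err = max(abs(a - b) for a, b in zip(r1 + r2, x + x))
print("max reconstruction error:", max_err)
```

The inputs come back up to floating-point rounding, even though only the last layer's y1 and y2 were kept. A real implementation would also recompute the sublayer outputs while walking backwards so it can form gradients, which is where the extra compute comes from.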



Thanks, I will give them both a read!
