EXPERT QUESTION - How to buffer repeated layer calls?

In this week's graded assignment, input_encoder_fn is called repeatedly during generation, once for every output token, even though the input sequence never changes. So to reduce the overhead, I tried wrapping the layer in a buffer so that it would only execute once. JAX later informed me, via an error message, that this is illegal, presumably (duh) because it can't backpropagate through state that was stashed away during a previous pass. (No good deed goes unpunished.) Does anyone have another approach to eliminating the redundant forward passes over the exact same input token string for every output token?

Here’s what I tried that didn’t work:

import jax.numpy as jnp
from trax import layers as tl

class Buffer_layer(tl.Serial):
    def init_weights_and_state(self, *args, **kwargs):
        # Remember the last input/output pair so repeat calls can be skipped.
        self.__last_inputs = None
        self.__last_outputs = None
        tl.Serial.init_weights_and_state(self, *args, **kwargs)

    def forward(self, *args, **kwargs):
        # Recompute only when the input actually changed. (Mutating Python
        # attributes from inside forward is what JAX ends up rejecting.)
        if self.__last_inputs is None or not jnp.array_equal(self.__last_inputs, args[0]):
            self.__last_inputs = args[0]
            self.__last_outputs = tl.Serial.forward(self, *args, **kwargs)
        return self.__last_outputs
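For concreteness, what I'd like to end up with is something like the sketch below, where the encoder runs exactly once per input sentence and the decoding loop just reuses the cached activations. (Rough sketch only; greedy_decode, encoder, and decoder_step are made-up names, not the assignment's actual API.)

import jax.numpy as jnp

def greedy_decode(encoder, decoder_step, input_tokens, max_len, eos_id=1):
    # Encode the input exactly once, outside the per-token loop, in
    # eager (non-jitted) Python, so no cache has to survive a trace.
    encoded = encoder(input_tokens)

    generated = [0]  # start-of-sequence id
    for _ in range(max_len):
        # Every step reuses the cached encoder activations.
        logits = decoder_step(jnp.array([generated]), encoded)
        next_id = int(jnp.argmax(logits[0, -1]))
        generated.append(next_id)
        if next_id == eos_id:
            break
    return generated[1:]

The point is just that encoder runs once per sentence; everything inside the loop only consumes the cached encoded value.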

Hi Ken_Otwell,

This is certainly a fun puzzle, but it lies outside the scope of the course setup. The way I see it, this assignment is a step towards the Transformer model, where things are done differently. But I hope you found a solution to your puzzle. :)
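For what it's worth, the Transformer-style pattern is to make the cached encoder output an explicit value that flows through the computation, rather than hidden Python state, and JAX is then happy to trace through it. A toy sketch, with made-up stand-ins for the real encoder and decoder (nothing here is the course's actual model):

import jax
import jax.numpy as jnp

# Toy stand-ins so the sketch runs; not the course's model.
embed = jnp.eye(16, 8)                          # 16-token vocab, d_model 8

def encode(tokens):
    return embed[tokens]                        # (seq_len, 8)

def decode_step(encoded, prev_id):
    context = encoded.mean(axis=0)              # "attend" to the cached encoding
    return embed @ (context + embed[prev_id])   # logits over the vocab

def generate(tokens, n_steps):
    encoded = encode(tokens)                    # the one and only encoder pass

    def step(prev_id, _):
        # `encoded` is a closed-over value, not hidden object state,
        # so jit/grad can trace straight through every decode step.
        next_id = jnp.argmax(decode_step(encoded, prev_id))
        return next_id, next_id

    _, out = jax.lax.scan(step, jnp.array(0), None, length=n_steps)
    return out

print(generate(jnp.array([3, 7, 5]), n_steps=4))

The design point is that jax.lax.scan closes over encoded as an ordinary traced value, so nothing is mutated between steps.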