If I may ask one more question: regarding trax’s LSTM implementation, the layer actually has several inputs/outputs (c, a, y hat, t), right? Does it simply “know” how to connect and forward stacked LSTMs? What if we wanted to inspect or somehow use these various inputs/outputs from the middle or final layers?
Actually, the decoder uses n_decoder_layers LSTM layers in the NMTAttn() function (UNQ_C4, step 7). pre_attention_decoder_fn(), on the other hand, uses a single LSTM layer. It’s a design choice, but you might see why it makes sense not to use multiple layers in the pre-attention decoder and then multiple layers again in NMTAttn().
Short answer - yes. Each LSTM cell gets the same number (shape) of inputs and outputs, which is why you need to initialize the hidden states for the very first step (trax also helps you here by creating the shapes and values of the initial hidden states for you).
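To make that concrete, here is a minimal sketch (not the assignment code; the vocab size, d_model and batch shape are just placeholders): each tl.LSTM(d_model) takes one activation tensor of shape (batch, seq_len, d_model) and returns one of the same shape, so stacked layers can be chained directly inside tl.Serial, while the cell/hidden states live in layer state and get created for you by init().

```python
import numpy as np
import trax
from trax import layers as tl

d_model = 32
n_decoder_layers = 2  # placeholder, just to mirror the "stack of LSTMs" idea

model = tl.Serial(
    tl.Embedding(256, d_model),                            # toy vocab of 256
    [tl.LSTM(d_model) for _ in range(n_decoder_layers)],   # stacked LSTMs, same in/out shape
)

tokens = np.zeros((4, 10), dtype=np.int32)   # (batch, seq_len)
model.init(trax.shapes.signature(tokens))    # builds weights and initial hidden states
out = model(tokens)                          # activations from the top LSTM
print(out.shape)                             # (4, 10, 32)
```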
There are a number of ways you could do that, and they vary in simplicity. If you just want to “check” them, you could simply debug the code and follow the trace.
You could also inherit from the layer class and write your own forward function, maybe dumping the values to a file, along the lines of the sketch below.
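A hedged sketch of that idea (the Inspect layer and its placement are illustrative, not part of trax or the assignment): a pass-through layer that records whatever flows through it, dropped between two stacked LSTMs. The side effect works fine in eager execution; under jit the captured values would be tracers rather than concrete arrays.

```python
import numpy as np
from trax import layers as tl

class Inspect(tl.Layer):
    """Identity layer that saves its input so you can look at it later."""
    def __init__(self, name='Inspect'):
        super().__init__(n_in=1, n_out=1, name=name)
        self.captured = []

    def forward(self, x):
        # Side effect only; the activation passes through unchanged.
        self.captured.append(np.asarray(x))
        return x

inspect = Inspect()
model = tl.Serial(
    tl.LSTM(32),
    inspect,        # captures the output of the first LSTM
    tl.LSTM(32),
)
# After init() and running the model, inspect.captured holds the intermediate
# activations (or write them to a file inside forward() instead).
```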
There are also more sophisticated ways of doing it, which would probably require a whole course of their own.