My knowledge of Keras is modest and rusty, so I'm not sure, but I think you are right. Anyway, don't take my word for it.
In this case, trax's default behaviour is more like PyTorch's (check the "Outputs" section):
What is kept after each step (word/token) is a single output, h; the other hidden state, c (the long-term memory), is dropped.
For me, it's easier to understand from the trax code than from the documentation. If you follow the code carefully, you can see what the LSTM layer does:
```python
return cb.Serial(
    cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
    cb.Select([0], n_in=2),  # Drop RNN state.
    name=f'LSTM_{n_units}', sublayers_to_print=[])
```
`cb.Scan` applies the `LSTMCell` step by step along the sequence axis, and `cb.Select([0], n_in=2)` keeps only the first of its two outputs, dropping the state. That first output is the `new_h` that the `LSTMCell` `forward` method produces.
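To make the Scan/Select mechanics concrete, here is a minimal pure-Python/NumPy sketch of the same idea. The gate math inside `lstm_cell` is faked (it is not the trax implementation); the point is only the output structure: the cell returns `(new_h, state)`, the scan collects one `h` per token, and the final state is dropped, just as `cb.Select([0], n_in=2)` does.

```python
import numpy as np

def lstm_cell(x, state, W):
    """Toy stand-in for LSTMCell.forward: returns (new_h, new_state).

    The real trax cell computes four gates; here the arithmetic is
    simplified, purely to show the (output, state) structure.
    """
    h, c = state
    new_c = 0.5 * c + 0.5 * np.tanh(x @ W + h)
    new_h = np.tanh(new_c)
    return new_h, (new_h, new_c)

def scan_then_select(xs, init_state, W):
    """Mimic cb.Scan(LSTMCell(...)) followed by cb.Select([0], n_in=2)."""
    state = init_state
    hs = []
    for x in xs:                      # Scan: apply the cell step by step
        h, state = lstm_cell(x, state, W)
        hs.append(h)
    return np.stack(hs)               # Select([0]): keep the h's, drop state

n_units, seq_len = 4, 3
W = np.zeros((n_units, n_units))
xs = np.zeros((seq_len, n_units))
init = (np.zeros(n_units), np.zeros(n_units))
out = scan_then_select(xs, init, W)
print(out.shape)  # (3, 4): one h per token; the final (h, c) is gone
```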
Also, you might want to check my attempt at explaining how LSTM matrix calculations are done here.
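For reference, a single LSTM step with the standard gate equations can be sketched in NumPy like this. The variable names and the stacking order of the four gate weight matrices are my own choices for illustration, not taken from trax:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, b):
    """One LSTM step: W stacks the weights for the input (i),
    forget (f), candidate (g) and output (o) gates."""
    n = h_prev.shape[0]
    z = np.concatenate([x, h_prev]) @ W + b   # one big matmul for all gates
    i = sigmoid(z[0*n:1*n])                   # input gate
    f = sigmoid(z[1*n:2*n])                   # forget gate
    g = np.tanh(z[2*n:3*n])                   # candidate cell update
    o = sigmoid(z[3*n:4*n])                   # output gate
    c_new = f * c_prev + i * g                # long-term memory (dropped by Select)
    h_new = o * np.tanh(c_new)                # the output that Scan keeps
    return h_new, c_new

n_in, n_units = 3, 4
rng = np.random.default_rng(0)
W = rng.normal(size=(n_in + n_units, 4 * n_units))
b = np.zeros(4 * n_units)
h, c = lstm_step(rng.normal(size=n_in),
                 np.zeros(n_units), np.zeros(n_units), W, b)
print(h.shape, c.shape)  # (4,) (4,)
```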