Number of LSTM units in Trax

No, as I said, it is very important to understand that the sequence length does not influence the RNN (including LSTM) dimensions, at least not directly. The sequence length only determines how many times the same weights are applied. In other words, it doesn't matter whether the longest sentence in your dataset is 1 000 words or 30.
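A minimal NumPy sketch can make this concrete. It assumes a plain (Elman-style) RNN with no bias term, random weights, and the illustrative sizes from this post (embedding size 5, hidden size 4). `rnn_forward` is a hypothetical helper, not a Trax function. The same two weight matrices process a 30-step and a 1 000-step sequence; only the number of loop iterations changes.

```python
import numpy as np

def rnn_forward(xs, W_xh, W_hh):
    """Run a simple RNN (bias omitted) over embedded inputs xs of shape (seq_len, emb_dim).

    The weight shapes depend only on emb_dim and hidden_dim, never on seq_len.
    Returns the final hidden state, shape (1, hidden_dim).
    """
    h = np.zeros((1, W_hh.shape[0]))      # h_0: initial hidden state
    for x in xs:                          # one iteration per word, same weights every time
        x = x.reshape(1, -1)              # (1, emb_dim)
        h = np.tanh(x @ W_xh + h @ W_hh)  # (1, hidden_dim)
    return h

rng = np.random.default_rng(0)
emb_dim, hidden_dim = 5, 4                # illustrative sizes only
W_xh = rng.normal(size=(emb_dim, hidden_dim))
W_hh = rng.normal(size=(hidden_dim, hidden_dim))

short_seq = rng.normal(size=(30, emb_dim))
long_seq = rng.normal(size=(1000, emb_dim))

h_short = rnn_forward(short_seq, W_xh, W_hh)
h_long = rnn_forward(long_seq, W_xh, W_hh)
print(h_short.shape, h_long.shape)        # (1, 4) (1, 4)
```

Longer sequences cost more computation (more loop iterations), but they do not add parameters.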

It is easier to explain with a simple RNN, because an LSTM has more weight matrices and activations, and it also carries an additional state (the cell state), which makes the explanation more complicated. A simple RNN illustrates the main idea better, and an LSTM works in a similar way.
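As a rough illustration of "more weight matrices", here is a parameter count sketch. It assumes a standard LSTM with four blocks (input, forget and output gates plus the candidate cell update), each with its own input weights, recurrent weights and one bias vector. Frameworks differ in details (some keep two bias vectors per block), so treat the exact numbers as an assumption. The helper names are hypothetical.

```python
def simple_rnn_params(emb_dim, hidden_dim):
    # W_xh (emb_dim x hidden_dim) + W_hh (hidden_dim x hidden_dim) + bias (hidden_dim)
    return emb_dim * hidden_dim + hidden_dim * hidden_dim + hidden_dim

def lstm_params(emb_dim, hidden_dim):
    # Assumed standard LSTM: four blocks, each shaped like one simple RNN layer
    return 4 * simple_rnn_params(emb_dim, hidden_dim)

for seq_len in (30, 1000):
    # seq_len is deliberately unused: the counts depend only on emb_dim and hidden_dim
    print(seq_len, simple_rnn_params(5, 4), lstm_params(5, 4))
# 30 40 160
# 1000 40 160
```

The LSTM has roughly four times as many weights as the simple RNN, but in both cases the count is the same for every sequence length.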

Note that here, for the sake of illustration (to go from top to bottom and to save space), the dimensions of the vectors and matrices differ from what they would be in the real world. Leaving this aside, the most important thing to see is that all the weights are the same for each step (word). (The RNN weights are in a yellow frame.)

For example, I highlighted the second step:

  1. after embedding, the input x_2 has shape (1, 5) and is multiplied by W_{xh} (shape (5, 4)), which produces shape (1, 4);
  2. the previous hidden state h_1 has shape (1, 4) and is multiplied by W_{hh} (shape (4, 4)), which produces shape (1, 4);
  3. both products are summed, which keeps the shape (1, 4);
  4. tanh is applied to the sum, which also keeps the shape (1, 4).

The result is h_2, the hidden state passed on to the next step.
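The four steps of the highlighted second step can be traced in NumPy. This is a sketch with random placeholder values (including h_1, which would normally come from the first step) and no bias term, matching the sizes in this post. It only checks that the shapes behave as described.

```python
import numpy as np

rng = np.random.default_rng(1)
W_xh = rng.normal(size=(5, 4))   # shared input-to-hidden weights
W_hh = rng.normal(size=(4, 4))   # shared hidden-to-hidden weights

x_2 = rng.normal(size=(1, 5))    # embedded second word
h_1 = rng.normal(size=(1, 4))    # hidden state from the first step (placeholder values)

a = x_2 @ W_xh                   # 1. (1, 5) @ (5, 4) -> (1, 4)
b = h_1 @ W_hh                   # 2. (1, 4) @ (4, 4) -> (1, 4)
s = a + b                        # 3. element-wise sum, still (1, 4)
h_2 = np.tanh(s)                 # 4. tanh keeps the shape, (1, 4)

for name, arr in [("x_2 @ W_xh", a), ("h_1 @ W_hh", b), ("sum", s), ("h_2", h_2)]:
    print(name, arr.shape)
```

Every other step reuses exactly the same W_xh and W_hh, just with a different input word and previous hidden state.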