Number of LSTM units in Trax

First thing to note: LSTM units do not depend on sentence length. That should be fully understood before talking about Embedding and LSTM dimensions; neither directly depends on sequence length, because the LSTM processes one token embedding at a time. See the simple sketch below of how an LSTM processes a sequence after embedding.
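
Here is that example as a minimal sketch in plain NumPy (all sizes are made up for illustration). The LSTM's weight shapes are determined only by the embedding size and the number of units; the same cell is applied token by token, so sequence length never appears in the weights:

```python
import numpy as np

d_emb, n_units, seq_len = 8, 16, 5   # seq_len is NOT a weight dimension

rng = np.random.default_rng(0)
# One stacked weight matrix: maps [h_prev, x_t] -> 4 gates (i, f, g, o).
W = rng.normal(size=(n_units + d_emb, 4 * n_units))
b = np.zeros(4 * n_units)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev):
    """One LSTM step: consumes a single token embedding x_t."""
    z = np.concatenate([h_prev, x_t]) @ W + b
    i, f, g, o = np.split(z, 4)
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    return h, c

# Run the same cell over a sequence of any length.
embeddings = rng.normal(size=(seq_len, d_emb))   # output of an Embedding layer
h, c = np.zeros(n_units), np.zeros(n_units)
for x_t in embeddings:                            # one token at a time
    h, c = lstm_step(x_t, h, c)

print(h.shape)  # (16,) -- hidden state size is n_units, regardless of seq_len
```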

On the correlation between Embedding dimensionality and LSTM dimensionality: it is not uncommon to have equal Embedding and LSTM dimension sizes, but it is not a must. I think it is the trax library's design choice not to implement support for unequal sizes, because of a lack of demand from the ML community or something else; I don't know, and nobody has commented on that since.
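
For illustration, here is a minimal sketch of such a model in trax (the vocab size, d_model, and output size are hypothetical). Note that the embedding's d_feature and the LSTM's n_units are the same number, which, per the point above, is what trax expects:

```python
import trax.layers as tl

d_model = 512  # hypothetical shared size for embedding and LSTM

model = tl.Serial(
    tl.Embedding(vocab_size=30_000, d_feature=d_model),  # token ids -> d_model vectors
    tl.LSTM(n_units=d_model),                            # hidden size tied to d_model
    tl.Dense(10),                                        # e.g. a 10-class prediction head
    tl.LogSoftmax(),
)
```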

What always matters most is the model's performance, and that depends on many things (especially data quality). Some of the things you can do (but not limited to) are:

  • You can increase or decrease the vocabulary size (e.g. by replacing rare words with “UNK”), but that should not directly influence your Embedding size. Usually your vocab size will be 10,000 - 100,000.
  • You can increase or decrease the Embedding size, compressing each word’s expressiveness (lexis → meaning) into a certain number of dimensions: each word’s meaning is reduced to a vector of that length. Embedding quality increases with higher dimensionality, but at some point the gains diminish, so you need trial and error to find the right size for your task. The right Embedding size also depends on the token type (words vs. subwords vs. characters), but usually it is 256 - 1,024.
  • You can increase or decrease the LSTM size, which shrinks or expands the model’s ability to carry compressed information from step to step (word to word in a sentence). Note also that the number of layers affects how hierarchically the information propagates, and directionality (uni- vs. bi-directional) is a factor too. This size should mostly depend on what comes after this layer: is it an encoder, is it followed by a Dense layer for categorical predictions (and how big is that layer?), or something else? See the sizing sketch after this list.
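
To make the cost side concrete, here is a rough parameter-count sketch (all sizes hypothetical) for the two layers discussed above. It shows how the three knobs translate into model size, and that with a large vocabulary the embedding table quickly dominates the LSTM:

```python
def count_params(vocab_size, d_emb, n_units):
    emb_params = vocab_size * d_emb                       # one d_emb vector per token
    # 4 gates, each with weights for [input, hidden] plus a bias.
    lstm_params = 4 * ((d_emb + n_units) * n_units + n_units)
    return emb_params, lstm_params

for vocab_size, d_emb, n_units in [(10_000, 256, 256), (100_000, 1_024, 1_024)]:
    emb, lstm = count_params(vocab_size, d_emb, n_units)
    print(f"vocab={vocab_size:>7}  d_emb={d_emb:>5}  n_units={n_units:>5} "
          f"-> embedding {emb:,} params, LSTM {lstm:,} params")
```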

Usually it comes down to the problem you are trying to solve (profit) and to costs (hardware / time / electricity): smaller models tend to be cheaper and faster, but also tend to lack accuracy.