What does `trax.data.AddLossWeights` do?
From the docs and code repository:
- If the stream consists of pairs `(inputs, targets)`, a loss mask is added,
created as a tensor of ones with the same shape as `targets`.
- If `id_to_mask` is not `None`, and the stream (after the previous point)
has triples `(inputs, targets, weights)`, the weights are multiplied by a
0/1 mask that is 0 wherever the target equals `id_to_mask` and 1 otherwise.
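The two rules above can be sketched in plain NumPy. This is a minimal illustration of the described behavior, not trax's actual source: `add_loss_weights_sketch` is a hypothetical name, and it assumes `targets` (and `weights`, if present) are NumPy arrays.

```python
import numpy as np

def add_loss_weights_sketch(stream, id_to_mask=None):
    # Hypothetical helper mirroring the documented behavior; not trax's code.
    for example in stream:
        inputs, targets = example[0], example[1]
        if len(example) == 2:
            # Pair (inputs, targets): start with all-ones weights shaped like targets.
            weights = np.ones_like(targets, dtype=np.float32)
        else:
            # Triple (inputs, targets, weights): reuse the given weights.
            weights = np.asarray(example[2], dtype=np.float32)
        if id_to_mask is not None:
            # 0 where the target equals the padding id, 1 elsewhere.
            weights = weights * (targets != id_to_mask).astype(np.float32)
        yield inputs, targets, weights

# One batch whose last position is padding (id 0).
batch = (np.array([[5, 6, 7, 0]]), np.array([[6, 7, 8, 0]]))
_, _, w = next(add_loss_weights_sketch(iter([batch]), id_to_mask=0))
print(w)  # [[1. 1. 1. 0.]]
```

Without `id_to_mask`, the same batch would get weights of all ones.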
In simple words: positions whose target is the padding id (`id_to_mask`) get a loss weight of 0, so the model is not trained to predict padding.
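To see why a zero weight means "not learned", consider a weighted average of per-token losses. This is a simplified illustration that assumes the loss is normalized by the sum of weights. The helper `weighted_mean_loss` is hypothetical, and trax's own loss layers may differ in detail.

```python
import numpy as np

def weighted_mean_loss(per_token_loss, weights):
    # Padded positions have weight 0, so they add nothing to the numerator
    # or the denominator and therefore produce no gradient.
    return float(np.sum(per_token_loss * weights) / np.sum(weights))

per_token_loss = np.array([[0.5, 1.0, 1.5, 9.0]])  # last position is padding
weights = np.array([[1.0, 1.0, 1.0, 0.0]])
print(weighted_mean_loss(per_token_loss, weights))  # 1.0
```

The large loss on the padded position (9.0) has no effect on the result.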