How does TrainTask consume the training tuple?

Hi team,

I noticed that after pre-processing the training stream so that it yields tuples of the form (joint, joint, mask), we simply pass the stream to TrainTask and don't have to handle the tuples ourselves anymore.

I tried to work out how TrainTask consumes these tuples, in particular how it applies the mask to the model output before using it in training, but could not find any further information.

Is there documentation on this, or could you point me to the relevant part of the Trax code?

Many thanks!

This page of the Trax documentation might help:

https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#module-trax.supervised.training
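Roughly, as I understand the mechanism (please check it against the page above): the training loop runs your model and the task's loss layer in series, following Trax's data-stack convention. The model consumes only the first element of each tuple, while the remaining two elements (targets and mask) pass through untouched to the loss layer, for example `tl.WeightedCategoryCrossEntropy()`, which uses the mask as per-position weights. The pure-Python sketch below mimics that flow. `serial_step`, `weighted_cross_entropy` and the lookup-table "model" are hypothetical stand-ins, not Trax APIs.

```python
import math

# Hypothetical stand-ins for illustration only; in Trax these are layers.

def weighted_cross_entropy(logits, targets, weights):
    """Masked cross-entropy: each position's loss is multiplied by its weight
    (the mask), then divided by the total weight (assumes weights sum > 0)."""
    total, weight_sum = 0.0, 0.0
    for row, target, w in zip(logits, targets, weights):
        log_z = math.log(sum(math.exp(x) for x in row))  # log-softmax normaliser
        total += -(row[target] - log_z) * w
        weight_sum += w
    return total / weight_sum

def serial_step(model_fn, loss_fn, batch):
    """Mimic running the model and then the loss in series on one tuple."""
    inputs, targets, weights = batch
    outputs = model_fn(inputs)                  # the model sees only the first element
    return loss_fn(outputs, targets, weights)   # the loss sees (outputs, targets, weights)

# Hypothetical "model": a lookup table from token id to logits over 2 classes.
table = {0: [3.0, 0.0], 1: [0.0, 9.0]}
model_fn = lambda ids: [table[i] for i in ids]

batch = ([1, 0], [0, 0], [0.0, 1.0])  # (inputs, targets, mask)
loss = serial_step(model_fn, weighted_cross_entropy, batch)
# Position 0 is badly wrong but has weight 0, so only position 1 contributes:
# loss equals math.log(1 + math.exp(-3))
print(loss)
```

Positions where the mask is 0 add nothing to the loss, so they produce no gradient. That is why the (joint, joint, mask) stream can be handed to TrainTask without any further handling on your side.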