How does trax handle the Y variable?

Is it all right if we don't pass an explicit output label from the data generator to the TrainTask?

In C3-W1, W2 and W3, the data had an explicit output variable, which was used when computing the loss.

But in C3-W4, even though we have an explicit Y variable (duplicate or not), we don't use it to calculate the loss. Here the V1 and V2 matrices are model outputs, NOT ground truth.

How does trax differentiate between inputs and ground truth when passing values to loss functions? For example, CategoricalCrossEntropy needs ground truth, but the custom triplet loss does not.

I also tried having the data generator yield a constant number n (for example, 5) along with q1 and q2, i.e. (q1, q2, n), and defined the loss like this:

```python
def loss(v1, v2, n, margin):
    # do some stuff with v1, v2, n and margin
    ...
```

But I saw that n was not automatically passed from the data generator to the loss function.
So I want to ask: which items from the data generator and which items from the model output are passed to the loss function in TrainTask?
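A possible explanation for the missing n: as far as I understand trax's `tl.Fn`, it gives the layer one input per positional parameter of the wrapped function and refuses parameters with default values, which is presumably why the assignment binds `margin` with `functools.partial`. The snippet below mimics only that counting rule with Python's standard `inspect` module; `count_positional_inputs`, `loss` and `loss_with_default` are hypothetical names for illustration, not trax code.

```python
import functools
import inspect


def count_positional_inputs(f):
    """Hypothetical helper (not part of trax): counts inputs the way a
    tl.Fn-style layer would, one per positional parameter, and rejects
    positional parameters that have default values."""
    spec = inspect.getfullargspec(f)
    if spec.defaults is not None:
        raise ValueError("positional parameters with default values are not allowed")
    if spec.varargs is not None or spec.varkw is not None:
        raise ValueError("*args and **kwargs are not allowed")
    return len(spec.args)


def loss(v1, v2, n, margin):
    # Placeholder body; only the signature matters for counting inputs.
    return 0.0


def loss_with_default(v1, v2, n, margin=0.25):
    # Same signature, but margin has a default value.
    return 0.0


# margin bound by keyword: v1, v2 and n stay positional, so 3 inputs
print(count_positional_inputs(functools.partial(loss, margin=0.25)))  # 3
# nothing bound: margin would also be taken from the data stack, so 4 inputs
print(count_positional_inputs(loss))  # 4
# a default value on a positional parameter is rejected
try:
    count_positional_inputs(loss_with_default)
except ValueError as err:
    print("rejected:", err)
```

If that rule holds, `tl.Fn('Loss', functools.partial(loss, margin=0.25))` would read three items from the stack (v1, v2 and n) after a two-input model, so when n does not arrive, the loss signature and the way `margin` is bound are worth checking first.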

Hi @sachin_B_S

Actually, the V1 and V2 matrices here are the outputs of data_generator (batches of tokenized questions). The model encodes each question into a vector, and what matters is the similarity between those encodings: V1_1 with V2_1, V1_2 with V2_2, and so on. In other words, as a concrete example (from the # UNQ_C1 "Expected output"), the model tries its best to make input V1_1:

```
[  30   87   78  134 2132 1981   28   78  594   21    1    1    1    1    1    1]
```

to be as similar as possible to V2_1:

```
[  30  156   78  134 2132 9508   21    1    1    1    1    1    1    1    1    1]
```

but as different as possible from V2_2:

```
[  30  156   78 3541 1460  131   56  253   21    1    1    1    1    1    1    1]
```
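The "ground truth" in this setup is carried by position in the batch rather than by a separate Y array: row i of the first batch is a duplicate of row i of the second batch, and every other row acts as a negative. A minimal pure-Python sketch of that idea, using made-up 2-D vectors in place of real encoder outputs (the vectors and helper names are illustrative assumptions, not the assignment's code):

```python
import math


def normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def similarity_matrix(batch1, batch2):
    """scores[i][j] = cosine similarity of question i in batch1 and question j in batch2."""
    b1 = [normalize(v) for v in batch1]
    b2 = [normalize(v) for v in batch2]
    return [[sum(a * b for a, b in zip(u, w)) for w in b2] for u in b1]


# Made-up 2-D "encodings" standing in for the model's outputs on 3 question pairs.
v1 = [[1.0, 0.1], [0.0, 1.0], [0.7, 0.7]]
v2 = [[0.9, 0.2], [0.1, 1.0], [0.6, 0.8]]

scores = similarity_matrix(v1, v2)
# The implicit "labels": (i, i) is a duplicate pair, (i, j) with j != i is a negative.
labels = [[1 if i == j else 0 for j in range(len(v2))] for i in range(len(v1))]

for i, row in enumerate(scores):
    best = max(range(len(row)), key=row.__getitem__)
    print(f"question {i} in batch1 is closest to question {best} in batch2")
```

With these toy vectors each question is closest to its own pair, which is exactly the pattern training tries to produce on real encodings.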

Your second question:

> How does trax differentiate between inputs and ground truth when passing values to loss functions? For example, CategoricalCrossEntropy needs ground truth, but the custom triplet loss does not.

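We designed TripletLossFn so that trax can tell how "good" the model is: TripletLossFn outputs a single number, where a big number means "bad" and a small number means "good", and trax then adjusts the model weights to reduce it. Trax itself does not distinguish inputs from ground truth: the loss layer receives the model's outputs followed by whatever items of the generator's tuple the model did not consume. A cross-entropy loss therefore gets (predictions, targets), while TripletLossFn only needs the two model outputs, because the "labels" are implied by which rows of the two batches are paired.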
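As I understand trax's convention, the loss layer runs after the model on a shared data stack: the model pops as many items from the generator's tuple as it has inputs and pushes its outputs back on top, and the loss reads its own inputs from the top of the resulting stack. The following is a hypothetical pure-Python simulation of that routing; `run_model_then_loss`, `classifier`, `siamese` and the string "tensors" are illustrative stand-ins, not trax code, and real trax layers carry more machinery (weights, state, shape checks).

```python
def run_model_then_loss(batch, model, model_n_in, loss, loss_n_in):
    """Simplified, hypothetical model of trax's data-stack convention.

    The model pops `model_n_in` items off the stack (the generator's tuple)
    and pushes its outputs back on top; the loss then reads `loss_n_in`
    items from the top. Leftover items are simply ignored here.
    """
    stack = list(batch)
    outputs = model(*stack[:model_n_in])
    if not isinstance(outputs, tuple):  # single output -> 1-tuple
        outputs = (outputs,)
    stack = list(outputs) + stack[model_n_in:]
    return loss(*stack[:loss_n_in])


def show_inputs(*args):
    """Stand-in 'loss' that just reports which items it received."""
    return args


def classifier(x):
    return f"pred({x})"


def siamese(q1, q2):
    return (f"enc({q1})", f"enc({q2})")


# (x, y): the model consumes x, the loss sees (prediction, y)
print(run_model_then_loss(("x", "y"), classifier, 1, show_inputs, 2))
# ('pred(x)', 'y')

# (q1, q2): the model consumes both, the loss sees only (v1, v2)
print(run_model_then_loss(("q1", "q2"), siamese, 2, show_inputs, 2))
# ('enc(q1)', 'enc(q2)')

# (q1, q2, n): a 3-input loss also receives the leftover n
print(run_model_then_loss(("q1", "q2", 5), siamese, 2, show_inputs, 3))
# ('enc(q1)', 'enc(q2)', 5)
```

The design choice this illustrates is that "targets" are not a special concept: they are just whatever the model leaves on the stack for the loss.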

You might find this answer helpful for understanding the details of the TripletLossFn calculations (and how it decides whether the model outputs are good or not).
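To make "single number, big means bad" concrete, here is a minimal pure-Python sketch of a batch triplet loss, assuming the encodings are already L2-normalized and using margin=0.25 purely as an example value. It follows the common formulation with a closest (hardest) negative plus a mean negative per row; it is not the assignment's TripletLossFn, and details such as averaging versus summing over the batch may differ.

```python
def dot(u, w):
    return sum(a * b for a, b in zip(u, w))


def triplet_loss(v1, v2, margin):
    """Batch triplet loss sketch on L2-normalized vectors (needs >= 2 pairs).

    Row i of v1 and row i of v2 form the positive pair; every v2[j] with
    j != i is a negative for v1[i]. Each row contributes two hinge terms:
    one with the closest (hardest) negative and one with the mean negative.
    """
    n = len(v1)
    scores = [[dot(a, b) for b in v2] for a in v1]
    total = 0.0
    for i in range(n):
        positive = scores[i][i]
        negatives = [scores[i][j] for j in range(n) if j != i]
        closest_negative = max(negatives)
        mean_negative = sum(negatives) / len(negatives)
        total += max(0.0, margin - positive + closest_negative)
        total += max(0.0, margin - positive + mean_negative)
    return total / n  # average over the batch: one number, lower is better


matched = [[1.0, 0.0], [0.0, 1.0]]
swapped = [[0.0, 1.0], [1.0, 0.0]]
print(triplet_loss(matched, matched, margin=0.25))  # 0.0: each question matches its duplicate
print(triplet_loss(matched, swapped, margin=0.25))  # 2.5: duplicates look like negatives
```

Note that no label array is passed anywhere: the diagonal of `scores` plays the role of the ground truth.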

Cheers.