How does trax handle Y variable?

Hi @sachin_B_S

Actually, the model outputs are not the V1 and V2 matrices (those are outputs of data_generator). The model outputs are the predicted similarities between V1_1 and V2_1, V1_2 and V2_2, and so on. As a concrete example (from # UNQ_C1 “Expected output”), the model tries its best to make input V1_1:

[  30   87   78  134 2132 1981   28   78  594   21    1    1    1    1    1    1]

as similar as possible to V2_1:

[  30  156   78  134 2132 9508   21    1    1    1    1    1    1    1    1    1]

but as different as possible from V2_2:

[  30  156   78 3541 1460  131   56  253   21    1    1    1    1    1    1    1]
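If it helps, here is a small NumPy sketch of what “similar” means for a batch of pairs. It assumes each question ends up as a vector and that vectors are compared with cosine similarity (as far as I recall, that is what the assignment does). The 3-dimensional vectors are made up, and `normalize` / `similarity_matrix` are just illustrative helper names. Row i of v1 and row i of v2 play the role of a duplicate pair such as V1_1 and V2_1, so the diagonal of the score matrix should be high and everything off the diagonal low.

```python
import numpy as np

def normalize(x):
    """L2-normalize each row so that dot products become cosine similarities."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def similarity_matrix(v1, v2):
    """scores[i, j] = cosine similarity between vector v1[i] and vector v2[j]."""
    return normalize(v1) @ normalize(v2).T

# Made-up 3-dimensional vectors for a batch of two question pairs.
# Row i of v1 and row i of v2 are duplicates (like V1_1 and V2_1).
v1 = np.array([[0.9, 0.1, 0.0],
               [0.0, 0.2, 0.9]])
v2 = np.array([[1.0, 0.0, 0.1],
               [0.1, 0.1, 1.0]])

scores = similarity_matrix(v1, v2)
print(np.round(scores, 3))
# Diagonal entries (duplicates) are high; off-diagonal entries (non-duplicates) are low.
```

Pairs like V1_1 and V2_2 are exactly the off-diagonal entries, which is why the model is pushed to keep them small.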

Your second question:

How does trax differentiate between inputs and ground truth when passing values to loss functions? For example, CategoricalCrossEntropy needs the ground truth, but the custom triplet loss does not require it.

We designed TripletLossFn in such a way that trax can tell how well the model is performing: TripletLossFn outputs a single number, where a big number means “bad” and a small number means “good”. Trax then adjusts the model weights to make that number smaller.
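As far as I understand trax's design, it never labels anything as “ground truth”. The model and the loss layer are chained one after the other (like `tl.Serial`): the tuple yielded by the data generator is placed on a stack, the model takes as many items as it needs from the top and puts its outputs back, and the loss layer then takes its own inputs from the top of whatever is left. The pure-Python sketch below only imitates that data flow. It is not trax code, and `run_serial`, `Layer` and the toy layers are hypothetical names.

```python
from typing import Callable, List, Tuple

# A hypothetical "layer": how many stack items it consumes, and the function applied to them.
Layer = Tuple[int, Callable[..., tuple]]

def run_serial(layers: List[Layer], batch: tuple) -> list:
    """Imitate stack-based data flow: each layer pops n_in items from the
    top of the stack and pushes its outputs back on top."""
    stack = list(batch)  # stack[0] is the top of the stack
    for n_in, fn in layers:
        args, stack = stack[:n_in], stack[n_in:]
        stack = list(fn(*args)) + stack
    return stack

# Case 1: classifier. The generator yields (x, y); the model consumes only x,
# so y is still on the stack when the loss layer runs.
classifier = (1, lambda x: (x * 2,))                       # model: x -> prediction
squared_error = (2, lambda pred, y: ((pred - y) ** 2,))    # loss: (prediction, target) -> number
print(run_serial([classifier, squared_error], (3, 5)))     # [1]

# Case 2: Siamese network. The generator yields (q1, q2) and no labels; the model
# consumes both and outputs (v1, v2), which the loss layer then consumes.
siamese = (2, lambda q1, q2: (q1 + 1, q2 + 1))             # model: (q1, q2) -> (v1, v2)
toy_pair_loss = (2, lambda v1, v2: (abs(v1 - v2),))        # loss: (v1, v2) -> number
print(run_serial([siamese, toy_pair_loss], (4, 7)))        # [3]
```

So a loss “needs ground truth” only in the sense that the generator yields an extra element the model does not consume. For the Siamese network nothing is left over, and the triplet loss works purely on the two model outputs.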

You might find this answer helpful to understand the details of TripletLossFn calculations (and how it decides if model outputs are good or not).
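For a quick feel of the “one number: big = bad, small = good” idea, here is a minimal NumPy sketch of a triplet loss over a batch of duplicate pairs. It is a simplified variant (it only looks at the hardest negative per row, and the margin of 0.25 is an arbitrary choice), not the assignment's TripletLossFn; `triplet_loss` and `margin` are hypothetical names.

```python
import numpy as np

def triplet_loss(v1, v2, margin=0.25):
    """Simplified triplet loss for a batch of duplicate pairs (v1[i], v2[i]).

    Returns a single number: 0 when every duplicate pair is more similar than
    its most similar non-duplicate by at least `margin`, larger otherwise.
    Assumes a batch of at least two pairs.
    """
    # L2-normalize so that dot products are cosine similarities.
    v1 = v1 / np.linalg.norm(v1, axis=1, keepdims=True)
    v2 = v2 / np.linalg.norm(v2, axis=1, keepdims=True)
    scores = v1 @ v2.T                       # scores[i, j] = sim(v1[i], v2[j])
    positive = np.diag(scores)               # duplicates sit on the diagonal
    # Push the diagonal below -1 so it is never picked as a negative.
    off_diag = scores - 2.0 * np.eye(len(scores))
    hardest_negative = off_diag.max(axis=1)  # most similar non-duplicate per row
    per_pair = np.maximum(0.0, margin - positive + hardest_negative)
    return per_pair.mean()

good_v1 = np.array([[1.0, 0.0], [0.0, 1.0]])
good_v2 = np.array([[0.9, 0.1], [0.1, 0.9]])
bad_v2 = good_v2[::-1]  # swap rows so the "duplicates" are mismatched

print(triplet_loss(good_v1, good_v2))  # 0.0: duplicates already win by more than the margin
print(triplet_loss(good_v1, bad_v2))   # about 1.13: mismatched pairs are penalized
```

Training only ever sees that scalar; the optimizer nudges the weights in whatever direction makes it smaller.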

Cheers.