W4 Triplet loss reduce sum axis

I can get the triplet loss programming assignment to work by calling reduce_sum with axis = -1 when calculating the positive distance.

When I tried reduce_sum with axis = 1 instead, I expected it to be equivalent to axis = -1: the input tensor to reduce_sum is 2-D with shape (3,128), so axis = 1 should sum across the columns of each row (horizontally). However, I got this error:
“InvalidArgumentError: Invalid reduction dimension (1 for input with 1 dimension(s) [Op:Sum]”
I wonder why reduce_sum with axis = 1 does not work in this case.

Setting axis = 0 also did not work, obviously, because it sums down each column (across the samples) and produces 128 values instead of 3 (m, the number of samples).
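A minimal sketch of the 2-D behaviour described above, using NumPy's `np.sum` as a stand-in for `tf.reduce_sum` (both count negative axis indices back from the last dimension). The `(3, 128)` array mimics the per-sample squared differences in unit test 1; the variable names are hypothetical.

```python
import numpy as np

# Stand-in for tf.square(anchor - positive) in unit test 1: 3 samples, 128 features each.
rng = np.random.default_rng(1)
sq_diff = rng.normal(size=(3, 128)) ** 2

# For a 2-D input, axis=1 and axis=-1 name the same (last) dimension,
# so both sum across the 128 features and leave one value per sample.
per_sample_1 = np.sum(sq_diff, axis=1)
per_sample_neg1 = np.sum(sq_diff, axis=-1)
print(per_sample_1.shape, per_sample_neg1.shape)  # (3,) (3,)

# axis=0 sums down the columns instead, collapsing the 3 samples
# and leaving one value per feature.
per_feature = np.sum(sq_diff, axis=0)
print(per_feature.shape)  # (128,)
```

So on a genuinely 2-D input the two spellings agree; the error must come from an input with fewer dimensions.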


Well, check the dimensions of the object you are feeding to reduce_sum. They must be different from what you expect: from the error message, it sounds like it only has one dimension. If that is true, then axis = -1 should be equivalent to axis = 0, right?
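To make the hint concrete, here is a NumPy sketch (again standing in for `tf.reduce_sum`) with a 1-D input like the `[1., 1.]` vectors in unit test 2. With one dimension, `axis=-1` and `axis=0` coincide, while `axis=1` is out of range and raises, much like the `InvalidArgumentError` TensorFlow reports. The example values are made up for illustration.

```python
import numpy as np

# A 1-D input, like anchor - positive when each embedding has only 2 elements.
diff_1d = np.array([1.0, 1.0]) - np.array([0.0, 2.0])
sq_1d = np.square(diff_1d)  # array([1., 1.])

# With only one dimension, the last axis (-1) is axis 0.
total_neg1 = np.sum(sq_1d, axis=-1)
total_0 = np.sum(sq_1d, axis=0)
print(total_neg1, total_0)  # 2.0 2.0

# axis=1 does not exist for a 1-D array: NumPy raises AxisError
# (a subclass of ValueError); TensorFlow raises InvalidArgumentError.
try:
    np.sum(sq_1d, axis=1)
    axis1_failed = False
except ValueError as err:
    axis1_failed = True
    print("axis=1 failed:", err)
```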

[Unit Test 1]
```python
import tensorflow as tf

# triplet_loss is the function implemented in the assignment
tf.random.set_seed(1)
y_true = (None, None, None)  # It is not used
y_pred = (tf.keras.backend.random_normal([3, 128], mean=6, stddev=0.1, seed=1),
          tf.keras.backend.random_normal([3, 128], mean=1, stddev=1, seed=1),
          tf.keras.backend.random_normal([3, 128], mean=3, stddev=4, seed=1))
loss = triplet_loss(y_true, y_pred)
```

[Unit Test 2]
```python
y_pred_perfect = ([1., 1.], [1., 1.], [1., 1.,])
loss = triplet_loss(y_true, y_pred_perfect, 5)
```

Thanks to Paul for the hint. I found why axis=1 didn’t work. For the (m,128) case in unit test 1, axis=1 is fine. However, unit test 2 in the assignment uses lists as input, NOT tensors, so axis=1 did not work. And if you try to print the shape of each variable, you will hit this error:

AttributeError: ‘list’ object has no attribute ‘shape’
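A quick sketch of why printing `.shape` fails and what the data looks like once converted. The unit-test tuple holds plain Python lists, so `y_pred_perfect[0]` has no `.shape`. Converting with `np.asarray` (TensorFlow ops such as `tf.subtract` convert list inputs to tensors in a similar way) shows that each embedding is 1-D, so it is the number of dimensions that leaves no axis 1 to reduce over.

```python
import numpy as np

y_pred_perfect = ([1., 1.], [1., 1.], [1., 1.,])
anchor, positive, negative = y_pred_perfect

# Plain Python lists have no .shape attribute.
has_shape = hasattr(anchor, "shape")
print(type(anchor).__name__, has_shape)  # list False

# Once converted to an array, each embedding is 1-D with 2 elements,
# so axis 1 does not exist.
anchor_arr = np.asarray(anchor)
print(anchor_arr.shape, anchor_arr.ndim)  # (2,) 1
```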

Still, shouldn't we assume that in a real application the input will be a tensor? In that case, wouldn't axis=1 be the correct answer?


The real point of my hint was that the input tensors have different numbers of dimensions in the test cases they give us, so using axis = 1 to mean the last axis does not work. It’s the equivalent of “hard-coding”. If you use axis = -1 instead, it selects the last axis regardless of how many dimensions there are.
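A minimal sketch of the rank-agnostic approach: a hypothetical helper `squared_distance` (not part of the assignment's code) that always reduces over the last axis with `axis=-1`, so the same line handles the `(3, 128)` batch from unit test 1 and the 1-D `[1., 1.]` vectors from unit test 2. It uses NumPy in place of TensorFlow, and the random batch only imitates the test data.

```python
import numpy as np

def squared_distance(a, b):
    """Sum of squared differences over the last axis (hypothetical helper).

    axis=-1 always refers to the last dimension, so this works for a batch of
    shape (m, 128) as well as for a single 1-D embedding.
    """
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    return np.sum(np.square(diff), axis=-1)

# Unit-test-1-like batch: 3 samples of 128 features -> one distance per sample.
rng = np.random.default_rng(1)
batch_a = rng.normal(loc=6, scale=0.1, size=(3, 128))
batch_p = rng.normal(loc=1, scale=1.0, size=(3, 128))
batch_dist = squared_distance(batch_a, batch_p)
print(batch_dist.shape)  # (3,)

# Unit-test-2-like 1-D inputs: a single scalar distance.
single_dist = squared_distance([1., 1.], [1., 1.])
print(single_dist)  # 0.0
```

Hard-coding `axis=1` in the helper would work for the batch but raise on the 1-D inputs, which is exactly the failure described in this thread.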
