It looks like there are (at least) two problems:
- You missed the fact that the labels and logits need to be transposed. Here is a thread with a checklist for this function. Here’s a thread which explains why the transpose is required.
- There is an extra dimension on one of your tensors. I assume it's the labels tensor, since it's the first operand. Note that the `new_y_train` value is generated by calling your earlier `one_hot_matrix` function, so this may indicate a problem with that function that was somehow not caught by the test cases.
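To make the transpose point concrete, here is a minimal sketch of how the loss computation typically looks, assuming the inputs arrive as (classes, examples) and that the Keras cross-entropy function expects (examples, classes); the function name and exact reduction follow the assignment's convention, so treat the details as illustrative:

```python
import tensorflow as tf

def compute_total_loss(logits, labels):
    """Sketch: sum of the per-example categorical cross-entropy.

    Assumes logits and labels come in as (classes, examples), so both
    must be transposed before being handed to Keras, which expects
    (examples, classes).
    """
    total_loss = tf.reduce_sum(
        tf.keras.losses.categorical_crossentropy(
            tf.transpose(labels),   # -> (examples, classes)
            tf.transpose(logits),   # -> (examples, classes)
            from_logits=True,       # raw scores, not softmax outputs
        )
    )
    return total_loss
```

With confident logits that match the one-hot labels, the summed loss comes out near zero, which is a quick sanity check that the transposes are in the right place.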
I added print statements to my `compute_total_loss` code to show the shapes of the inputs before any processing of them (e.g. the transpose), and here's what I see:
```
labels [[0. 1.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [1. 0.]]
before logits.shape (6, 2)
before labels.shape (6, 2)
```
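As a side note, a common way `one_hot_matrix` ends up producing an extra dimension is omitting the final reshape. This sketch assumes the usual `tf.one_hot`-based implementation (the exact assignment code may differ):

```python
import tensorflow as tf

depth = 2  # number of classes

# Applied to a (1,)-shaped label, tf.one_hot yields shape (1, 2) --
# the extra leading dimension that then leaks into new_y_train.
one_hot_bad = tf.one_hot(tf.constant([1]), depth)
print(one_hot_bad.shape)   # (1, 2)

# The usual fix is to reshape down to a flat (depth,) vector:
one_hot_good = tf.reshape(tf.one_hot(tf.constant([1]), depth),
                          shape=[depth])
print(one_hot_good.shape)  # (2,)
```

Printing the shape of `one_hot_matrix`'s return value, as above, should show immediately whether that function is the source of the extra dimension.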