I am passing the first 8 examples in the unit test but failing the 9th. The failure happens when preds is a 4-dimensional tensor while the one-hot target is 3-dimensional. This originally gave me a shape-mismatch error, which I bypassed like this:
```python
if preds.shape != one_hot_target.shape:
    # Drop the extra size-1 axis so preds has the same shape as the one-hot target
    preds = tf.squeeze(preds, axis=1)
print(f"preds shape after squeeze: {preds.shape}\npreds after squeeze:\n {preds}")
```
But my result still disagrees with the expected value. More generally, I do not understand why the preds tensor would suddenly have an extra dimension in Example 9 compared with Examples 1-8.
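A minimal NumPy sketch of how these two shapes interact. The arrays are zero placeholders that only have the Example 9 shapes (they are not the test data), and NumPy stands in for TensorFlow here on the assumption that both apply the same broadcasting rules to elementwise multiplication.

```python
import numpy as np

# Placeholder arrays with the Example 9 shapes; the values do not matter here
preds = np.zeros((2, 1, 7, 3))        # batch, extra size-1 axis, sequence length, num_classes
one_hot_target = np.zeros((2, 7, 3))  # batch, sequence length, num_classes

# Trailing axes are aligned: (2, 1, 7, 3) vs (1, 2, 7, 3) -> (2, 2, 7, 3)
product = preds * one_hot_target
print(product.shape)  # (2, 2, 7, 3)

# Squeezing the size-1 axis makes the shapes match element for element
squeezed = np.squeeze(preds, axis=1)
print((squeezed * one_hot_target).shape)  # (2, 7, 3)
```

Without the squeeze, the product is valid under broadcasting, but it pairs each of the 2 prediction sequences with each of the 2 target sequences rather than matching them one to one.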
Below is the output of my debug print statements, followed by the error message:
```
num_classes: 3
preds shape: (2, 1, 7, 3)
preds: [[[[0.1 0.5 0.4 ]
   [0.05 0.9 0.05]
   [0.2 0.3 0.5 ]
   [0.1 0.2 0.7 ]
   [0.2 0.8 0.1 ]
   [0.4 0.4 0.2 ]
   [0.5 0. 0.5 ]]]
 [[[0.1 0.5 0.4 ]
   [0.2 0.8 0.1 ]
   [0.4 0.4 0.2 ]
   [0.5 0. 0.5 ]
   [0.05 0.9 0.05]
   [0.2 0.3 0.5 ]
   [0.1 0.2 0.7 ]]]]
targets shape: (2, 7)
target: [[1 2 0 2 0 2 0]
 [2 1 1 2 2 0 0]]
1-hot target shape: (2, 7, 3)
1-hot target: [[[0. 1. 0.]
  [0. 0. 1.]
  [1. 0. 0.]
  [0. 0. 1.]
  [1. 0. 0.]
  [0. 0. 1.]
  [1. 0. 0.]]
 [[0. 0. 1.]
  [0. 1. 0.]
  [0. 1. 0.]
  [0. 0. 1.]
  [0. 0. 1.]
  [1. 0. 0.]
  [1. 0. 0.]]]
preds shape after squeeze: (2, 7, 3)
preds after squeeze:
 [[[0.1 0.5 0.4 ]
  [0.05 0.9 0.05]
  [0.2 0.3 0.5 ]
  [0.1 0.2 0.7 ]
  [0.2 0.8 0.1 ]
  [0.4 0.4 0.2 ]
  [0.5 0. 0.5 ]]
 [[0.1 0.5 0.4 ]
  [0.2 0.8 0.1 ]
  [0.4 0.4 0.2 ]
  [0.5 0. 0.5 ]
  [0.05 0.9 0.05]
  [0.2 0.3 0.5 ]
  [0.1 0.2 0.7 ]]]
log_p: [[0.5 0.05 0.2 0.7 0.2 0.2 0.5 ]
 [0.4 0.8 0.4 0.5 0.05 0.2 0.1 ]]
non_pad: [[0. 1. 1. 1. 1. 1. 1.]
 [1. 0. 0. 1. 1. 1. 1.]]
log_p*non_pad: [[0. 0.05 0.2 0.7 0.2 0.2 0.5 ]
 [0.4 0. 0. 0.5 0.05 0.2 0.1 ]]
num: [1.85 1.25]
denom: [[6.]
 [5.]]
log_ppx_a: [[0.30833334 0.20833333]
 [0.37 0.25 ]]
log_p shape: (2, 7)
log_p: [[0. 0.05 0.2 0.7 0.2 0.2 0.5 ]
 [0.4 0. 0. 0.5 0.05 0.2 0.1 ]]
log_ppx mean: 0.28416666885217035
```
```
AssertionError                            Traceback (most recent call last)
Cell In[56], line 2
      1 #UNIT TESTS
----> 2 w1_unittest.test_test_model(log_perplexity)

File /tf/w1_unittest.py:308, in test_test_model(target)
    306     assert np.isnan(output), f"Fail in {testi['name']}. Expected {expected} but got {output}"
    307 else:
--> 308     assert np.allclose(output, expected), f"Fail in {testi['name']}. Expected {expected} but got {output}"
    310 print("\n\033[92mAll test passed!")

AssertionError: Fail in Example 9. Batch of 2. Expected -0.31333333427707355 but got -0.28416666885217035
```
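To compare the numbers directly, here is a NumPy reproduction of Example 9 using the values from the printout above. `PADDING_ID = 1` is an assumption inferred from the printed non_pad mask, `mean_ratio` is a hypothetical helper, and the grader's own implementation is not shown in this post, so the sketch only compares three ways of combining the same arrays.

```python
import numpy as np

# Values copied from the Example 9 debug printout
preds = np.array([
    [[[0.1, 0.5, 0.4], [0.05, 0.9, 0.05], [0.2, 0.3, 0.5], [0.1, 0.2, 0.7],
      [0.2, 0.8, 0.1], [0.4, 0.4, 0.2], [0.5, 0.0, 0.5]]],
    [[[0.1, 0.5, 0.4], [0.2, 0.8, 0.1], [0.4, 0.4, 0.2], [0.5, 0.0, 0.5],
      [0.05, 0.9, 0.05], [0.2, 0.3, 0.5], [0.1, 0.2, 0.7]]],
])  # shape (2, 1, 7, 3)
target = np.array([[1, 2, 0, 2, 0, 2, 0],
                   [2, 1, 1, 2, 2, 0, 0]])  # shape (2, 7)

PADDING_ID = 1  # assumption: matches the printed non_pad (0 where target == 1)
one_hot = np.eye(3)[target]                     # (2, 7, 3), same values as tf.one_hot(target, 3)
non_pad = (target != PADDING_ID).astype(float)  # (2, 7)

def mean_ratio(log_p, denom):
    """Hypothetical helper: masked sum over the last axis, divided by denom, then averaged."""
    return np.mean(np.sum(log_p * non_pad, axis=-1) / denom)

squeezed = np.squeeze(preds, axis=1)            # (2, 7, 3)
log_p_sq = np.sum(squeezed * one_hot, axis=-1)  # (2, 7)
log_p_bc = np.sum(preds * one_hot, axis=-1)     # (2, 2, 7) via broadcasting

# 1) squeeze + keepdims denominator of shape (2, 1): num / denom becomes (2, 2)
a = mean_ratio(log_p_sq, np.sum(non_pad, axis=-1, keepdims=True))
# 2) squeeze + flat denominator of shape (2,): one ratio per example
b = mean_ratio(log_p_sq, np.sum(non_pad, axis=-1))
# 3) no squeeze: preds and one_hot broadcast against each other
c = mean_ratio(log_p_bc, np.sum(non_pad, axis=-1))

print(f"{a:.5f} {b:.5f} {c:.5f}")  # 0.28417 0.27917 0.31333
```

In this reproduction, variant 1 gives the 0.28417 from the error message (the (2, 1) denominator is what turns num into the (2, 2) log_ppx_a shown in the printout), and only variant 3 lands on the expected 0.31333.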