This unit test is a bit contrived (most probably accidentally, though it could also be intentional).
Let me walk through the last test case, the one with shape (2, 1, 7, 3), that you're having problems with. (This is not a typical input shape in NLP, but it is what the test case uses.)
So, in the unit test case the preds are:
<tf.Tensor: shape=(2, 1, 7, 3), dtype=float32, numpy=
array([[[[0.1 , 0.5 , 0.4 ],
[0.05, 0.9 , 0.05],
[0.2 , 0.3 , 0.5 ],
[0.1 , 0.2 , 0.7 ],
[0.2 , 0.8 , 0.1 ],
[0.4 , 0.4 , 0.2 ],
[0.5 , 0. , 0.5 ]]],
[[[0.1 , 0.5 , 0.4 ],
[0.2 , 0.8 , 0.1 ],
[0.4 , 0.4 , 0.2 ],
[0.5 , 0. , 0.5 ],
[0.05, 0.9 , 0.05],
[0.2 , 0.3 , 0.5 ],
[0.1 , 0.2 , 0.7 ]]]], dtype=float32)>
the targets are:
<tf.Tensor: shape=(2, 7), dtype=int32, numpy=
array([[1, 2, 0, 2, 0, 2, 0],
[2, 1, 1, 2, 2, 0, 0]], dtype=int32)>
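If you want to reproduce the steps below yourself, here is a minimal sketch that builds these two tensors (the names preds and targets match the discussion; everything else is just my own scaffolding):

import tensorflow as tf

preds = tf.constant(
    [[[[0.1 , 0.5 , 0.4 ],
       [0.05, 0.9 , 0.05],
       [0.2 , 0.3 , 0.5 ],
       [0.1 , 0.2 , 0.7 ],
       [0.2 , 0.8 , 0.1 ],
       [0.4 , 0.4 , 0.2 ],
       [0.5 , 0.  , 0.5 ]]],
     [[[0.1 , 0.5 , 0.4 ],
       [0.2 , 0.8 , 0.1 ],
       [0.4 , 0.4 , 0.2 ],
       [0.5 , 0.  , 0.5 ],
       [0.05, 0.9 , 0.05],
       [0.2 , 0.3 , 0.5 ],
       [0.1 , 0.2 , 0.7 ]]]], dtype=tf.float32)  # shape (2, 1, 7, 3)

targets = tf.constant(
    [[1, 2, 0, 2, 0, 2, 0],
     [2, 1, 1, 2, 2, 0, 0]], dtype=tf.int32)  # shape (2, 7)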
when you one_hot encode the targets, you get:
<tf.Tensor: shape=(2, 7, 3), dtype=float32, numpy=
array([[[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.],
[0., 0., 1.],
[1., 0., 0.],
[0., 0., 1.],
[1., 0., 0.]],
[[0., 0., 1.],
[0., 1., 0.],
[0., 1., 0.],
[0., 0., 1.],
[0., 0., 1.],
[1., 0., 0.],
[1., 0., 0.]]], dtype=float32)>
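A sketch of that step, assuming tf.one_hot and the preds/targets defined above (depth is 3 because preds has 3 values in its last axis):

one_hot_targets = tf.one_hot(targets, depth=preds.shape[-1])  # shape (2, 7, 3)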
when you multiply that with preds, broadcasting comes into play (the (2, 7, 3) one-hot tensor is lined up against the (2, 1, 7, 3) preds), so you get a result of shape (2, 2, 7, 3):
<tf.Tensor: shape=(2, 2, 7, 3), dtype=float32, numpy=
array([[[[0. , 0.5 , 0. ],
[0. , 0. , 0.05],
[0.2 , 0. , 0. ],
[0. , 0. , 0.7 ],
[0.2 , 0. , 0. ],
[0. , 0. , 0.2 ],
[0.5 , 0. , 0. ]],
[[0. , 0. , 0.4 ],
[0. , 0.9 , 0. ],
[0. , 0.3 , 0. ],
[0. , 0. , 0.7 ],
[0. , 0. , 0.1 ],
[0.4 , 0. , 0. ],
[0.5 , 0. , 0. ]]],
[[[0. , 0.5 , 0. ],
[0. , 0. , 0.1 ],
[0.4 , 0. , 0. ],
[0. , 0. , 0.5 ],
[0.05, 0. , 0. ],
[0. , 0. , 0.5 ],
[0.1 , 0. , 0. ]],
[[0. , 0. , 0.4 ],
[0. , 0.8 , 0. ],
[0. , 0.4 , 0. ],
[0. , 0. , 0.5 ],
[0. , 0. , 0.05],
[0.2 , 0. , 0. ],
[0.1 , 0. , 0. ]]]], dtype=float32)>
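In code this whole step is a single element-wise multiplication (products is just my name for the intermediate result):

products = preds * one_hot_targets  # broadcast to shape (2, 2, 7, 3)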
and when you sum over the last axis, you get:
array([[[0.5 , 0.05, 0.2 , 0.7 , 0.2 , 0.2 , 0.5 ],
[0.4 , 0.9 , 0.3 , 0.7 , 0.1 , 0.4 , 0.5 ]],
[[0.5 , 0.1 , 0.4 , 0.5 , 0.05, 0.5 , 0.1 ],
[0.4 , 0.8 , 0.4 , 0.5 , 0.05, 0.2 , 0.1 ]]], dtype=float32)
# log_p.shape
# (2, 2, 7)
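Summing over the last axis just picks out, for every position, the predicted value of the target class. (The variable is called log_p presumably because in the actual assignment preds would hold log probabilities; in this contrived test they are just arbitrary numbers.)

log_p = tf.reduce_sum(products, axis=-1)  # shape (2, 2, 7)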
Just to continue the exercise a bit further (it might help you or others understand what is being asked of you):
# Identify non-padding elements in the target
array([[0., 1., 1., 1., 1., 1., 1.],
[1., 0., 0., 1., 1., 1., 1.]])
# non_pad.shape
# (2, 7)
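Note that, to reproduce this exact mask, positions where the target equals 1 have to be treated as padding (the mask is 0 exactly where the target is 1), which is another contrived aspect of this test. A sketch with that assumption made explicit:

PADDING_ID = 1  # assumption: the mask above is 0 exactly where targets == 1
non_pad = 1.0 - tf.cast(tf.equal(targets, PADDING_ID), tf.float32)  # shape (2, 7)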
then (in this test case, broadcasting comes into play again):
# Apply non-padding mask to log probabilities to exclude padding
array([[[0. , 0.05 , 0.2 , 0.69999999, 0.2 ,
0.2 , 0.5 ],
[0.40000001, 0. , 0. , 0.69999999, 0.1 ,
0.40000001, 0.5 ]],
[[0. , 0.1 , 0.40000001, 0.5 , 0.05 ,
0.5 , 0.1 ],
[0.40000001, 0. , 0. , 0.5 , 0.05 ,
0.2 , 0.1 ]]])
# log_p.shape
# (2, 2, 7)
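Here the (2, 7) mask is broadcast against the (2, 2, 7) log_p:

log_p = log_p * non_pad  # padding positions zeroed out, shape still (2, 2, 7)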
then:
# Calculate the log perplexity by taking the sum of log probabilities and dividing by the sum of non-padding elements
# numerator:
array([[1.85 , 2.1 ],
[1.65000001, 1.25000001]])
# .shape
# (2, 2)
# denominator:
array([6., 5.])
# .shape
# (2,)
# log_ppx
array([[0.30833333, 0.42 ],
[0.275 , 0.25 ]])
# .shape
# (2, 2)
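In code, both sums run over the last axis, and the (2, 2) / (2,) division broadcasts once more:

numerator = tf.reduce_sum(log_p, axis=-1)      # shape (2, 2)
denominator = tf.reduce_sum(non_pad, axis=-1)  # shape (2,)
log_ppx = numerator / denominator              # shape (2, 2)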
lastly:
# Compute the mean of log perplexity
# log_ppx
0.31333333427707355
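In code, that final number is just the mean over all four entries (log_ppx_mean is my own name for it):

log_ppx_mean = tf.reduce_mean(log_ppx)  # ~0.31333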
Again, I reiterate that this unit test is contrived and the number of dimensions should really have been 3 (batch being the first), but this might help you get the idea.
Cheers