On the last line:
# Compute masked accuracy (quotient between the total matches and the total valid values, i.e., the amount of non-masked values)
we do not specify the axis parameter. In other words, we count all the correctly predicted labels (by summing matches_true_pred, which already excludes padded tokens) and divide by the number of valid, non-padded elements (by summing mask, which is 1 for real tokens and 0 for padding). Because no axis is given, each sum reduces over every dimension, so both are scalars (not vectors).
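To make that concrete, here is a small sketch of the same computation. It uses NumPy as a stand-in for whatever framework the assignment uses, and the arrays y_true and y_pred, as well as the choice of -1 as the padding label, are made-up assumptions for illustration; only matches_true_pred and mask are names from the original code.

```python
import numpy as np

# Hypothetical batch: 2 sequences of length 4.
# Assumption: the label -1 marks a padded position.
y_true = np.array([[1, 2, 0, -1],
                   [3, 1, -1, -1]])
y_pred = np.array([[1, 0, 0, 2],
                   [3, 1, 1, 0]])

# 1.0 for real tokens, 0.0 for padding
mask = (y_true != -1).astype(np.float32)

# 1.0 where the prediction matches the label, zeroed out at padded positions
matches_true_pred = (y_true == y_pred).astype(np.float32) * mask

# No axis argument: each sum collapses every dimension into one scalar
masked_acc = np.sum(matches_true_pred) / np.sum(mask)

# For contrast, with axis=-1 we would get one value per sequence (a vector)
per_sequence = np.sum(matches_true_pred, axis=-1) / np.sum(mask, axis=-1)

print(masked_acc)          # a single number: 4 matches / 5 valid tokens
print(per_sequence.shape)  # (2,)
```

Here there are 4 matches among 5 non-padded tokens, so the masked accuracy is 4/5. Leaving out the axis is what gives a single number for the whole batch instead of one number per sequence.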
Cheers