I added these print statements after every relevant line:
print(f'self_mha_output:{self_mha_output}')
print(f'skip_x_attention:{skip_x_attention}')
print(f'ffn_output:{ffn_output}')
print(f'ffn_output after dropout:{ffn_output}')
print(f'encoder_layer_out:{encoder_layer_out}')
Here is the first set of output:
self_mha_output:[[[ 0.2629684 0.5438655 -0.47695604 0.43180236]
[ 0.27214473 0.5516315 -0.47251672 0.44105405]
[ 0.2637157 0.5352751 -0.46818826 0.44008902]]]
skip_x_attention:[[[ 0.7840514 -0.9639456 -1.0145587 1.1944535 ]
[-1.2134784 1.0835364 -0.7550787 0.885021 ]
[ 0.76012594 -0.20960009 -1.545446 0.9949202 ]]]
ffn_output:[[[-0.40299335 -0.26304182 0.01199517 0.77515805]
[ 0.11928089 0.02366283 0.21244505 0.6133719 ]
[-0.47993705 -0.35966852 0.11620045 0.9476139 ]]]
ffn_output after dropout:[[[-0.44777042 -0.2922687 0.01332797 0.86128676]
[ 0.13253433 0.02629204 0.23605007 0.6815244 ]
[-0.5332634 -0.3996317 0. 1.0529044 ]]]
encoder_layer_out:[[[ 0.23017097 -0.9810039 -0.78707564 1.5379086 ]
[-1.2280797 0.76477575 -0.7169284 1.1802323 ]
[ 0.14880148 -0.4831803 -1.1908401 1.5252188 ]]]
Try adding the same statements to your code and see where your values stop matching mine. The first stage that differs is the place to look. Make sure you add each statement in the correct place, right after the line that produces the value it prints.
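If comparing the numbers by eye gets tedious, here is a rough sketch of how the check could be automated with NumPy. The helper name `first_mismatch`, the `stages` dictionary, and the tolerance are my own assumptions for illustration, not part of the assignment. Only the first reference array is filled in; the others can be pasted in from the output above in the same way.

```python
import numpy as np

# Reference values copied from the self_mha_output printed above.
expected_self_mha_output = np.array([[
    [0.2629684, 0.5438655, -0.47695604, 0.43180236],
    [0.27214473, 0.5516315, -0.47251672, 0.44105405],
    [0.2637157, 0.5352751, -0.46818826, 0.44008902],
]])


def first_mismatch(stages, atol=1e-5):
    """Return the name of the first stage that differs from its reference.

    `stages` maps a stage name to an (actual, expected) pair, listed in the
    order the stages are computed. Returns None if every stage matches.
    """
    for name, (actual, expected) in stages.items():
        actual = np.asarray(actual)
        expected = np.asarray(expected)
        if actual.shape != expected.shape:
            return name
        if not np.allclose(actual, expected, atol=atol):
            return name
    return None


# Placeholder: replace this with the value your own layer produces.
my_self_mha_output = expected_self_mha_output.copy()

stages = {
    "self_mha_output": (my_self_mha_output, expected_self_mha_output),
    # Add skip_x_attention, ffn_output, and so on in the same way.
}
print(first_mismatch(stages))
```

The stages are checked in the order they are listed, so the name that comes back points at the earliest line that needs attention.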
Best,
Saif.