C4_w2 UNQ_C8 Train Loop

I looked at all the threads and tried everything, but I'm still having an issue and getting the following error.
I cannot figure it out. All my previous code passed the unit tests. I made sure that compute_attention_heads_closure and compute_attention_output_closure have the right parameters passed to them, and I also made sure to use Dropout(rate=dropout, mode=mode), but I still cannot figure this out. Any help will be appreciated.

---------------------------------------------------------------------------
LayerError                                Traceback (most recent call last)
<ipython-input-163-d1e9ffd2f1b1> in <module>
      1 # UNIT TEST
      2 # test training_loop
----> 3 w2_tests.test_training_loop(training_loop, TransformerLM)

~/work/w2_tests.py in test_training_loop(target, TransformerLM)
    813         os.remove("~/model/model.pkl.gz")
    814 
--> 815     output_loop = target(TransformerLM, my_gen(), my_gen())
    816 
    817     try:

<ipython-input-162-72ff593512a9> in training_loop(TransformerLM, train_gen, eval_gen, output_dir)
     40                          train_task,
     41                          eval_tasks=[eval_task],
---> 42                          output_dir=output_dir)
     43 
     44     return loop

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    246     if not use_memory_efficient_trainer:
    247       if _is_uninitialized(self._model):
--> 248         self._model.init(self._batch_signature)
    249       self._eval_model.rng = self.new_rng()
    250       if _is_uninitialized(self._eval_model):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
    309       name, trace = self._name, _short_traceback(skip=3)
    310       raise LayerError(name, 'init', self._caller,
--> 311                        input_signature, trace) from None
    312 
    313   def init_from_file(self, file_name, weights_only=False, input_signature=None):

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/<ipython-input-159-f25c82633338>, line 63
  layer input shapes: (ShapeDtype{shape:(1, 1250), dtype:int64}, ShapeDtype{shape:(1, 1250), dtype:int64}, ShapeDtype{shape:(1, 1250), dtype:int64})

  File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
    sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/<ipython-input-157-d7d31ba3dd45>, line 54
  layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

  File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
    sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Branch (in init):
  layer created in file [...]/<ipython-input-157-d7d31ba3dd45>, line 54
  layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

I solved it. The problem was in my dot-product function: I had exchanged the indices when transposing.
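For anyone hitting the same LayerError: a shape mismatch inside attention often surfaces only at `model.init`, exactly as in the traceback above. Below is a minimal NumPy sketch of scaled dot-product attention (not the exact graded function, and the name `dot_product_attention` is just illustrative) showing the transpose done on the correct axes; swapping the wrong pair of indices there is the kind of bug that triggers this error.

```python
import numpy as np

def dot_product_attention(query, key, value):
    """Scaled dot-product attention (illustrative sketch).

    query, key, value: arrays of shape (..., seq_len, d_head).
    Returns an array of the same shape as query.
    """
    d_head = query.shape[-1]
    # Transpose the keys on their LAST two axes. Swapping a different
    # pair of axes changes the shape and breaks the matmul downstream,
    # which shows up later as a LayerError during init.
    scores = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(d_head)
    # Numerically stable softmax over the last axis.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.matmul(weights, value)
```

A quick sanity check: with all-zero queries and keys the attention weights are uniform, so each output row is the mean of the value rows.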