Error when running the training_loop (UNQ_C8)

When running the training loop

# Should take around 1.5 minutes
!rm -f ~/model/model.pkl.gz
loop = training_loop(TransformerLM, train_batch_stream, eval_batch_stream)

I keep getting the following error.

LayerError                                Traceback (most recent call last)
<ipython-input-42-bf4e14290aae> in <module>
      1 # Should take around 1.5 minutes
      2 get_ipython().system('rm -f ~/model/model.pkl.gz')
----> 3 loop = training_loop(TransformerLM, train_batch_stream, eval_batch_stream)

<ipython-input-40-c6ec63e508e6> in training_loop(TransformerLM, train_gen, eval_gen, output_dir)
     40                          train_task,
     41                          eval_tasks=[eval_task],
---> 42                          output_dir=output_dir)
     44     return loop

/opt/conda/lib/python3.7/site-packages/trax/supervised/ in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    246     if not use_memory_efficient_trainer:
    247       if _is_uninitialized(self._model):
--> 248         self._model.init(self._batch_signature)
    249       self._eval_model.rng = self.new_rng()
    250       if _is_uninitialized(self._eval_model):

/opt/conda/lib/python3.7/site-packages/trax/layers/ in init(self, input_signature, rng, use_cache)
    309       name, trace = self._name, _short_traceback(skip=3)
    310       raise LayerError(name, 'init', self._caller,
--> 311                        input_signature, trace) from None
    313   def init_from_file(self, file_name, weights_only=False, input_signature=None):

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/<ipython-input-37-6265e024f03d>, line 63
  layer input shapes: (ShapeDtype{shape:(2, 1024), dtype:int64}, ShapeDtype{shape:(2, 1024), dtype:int64}, ShapeDtype{shape:(2, 1024), dtype:int64})

  File [...]/trax/layers/, line 105, in init_weights_and_state
    sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/<ipython-input-34-a2ac93ee377e>, line 54
  layer input shapes: ShapeDtype{shape:(2, 1024, 4), dtype:float32}

  File [...]/trax/layers/, line 106, in init_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer Add (in _forward_abstract):
  layer created in file [...]/<ipython-input-34-a2ac93ee377e>, line 54
  layer input shapes: (ShapeDtype{shape:(2, 1024, 4), dtype:float32}, ShapeDtype{shape:(4, 1024, 4), dtype:float32})

  File [...]/jax/interpreters/, line 404, in abstract_eval_fun
    _, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals, transform_name)

  File [...]/jax/interpreters/, line 1178, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

  File [...]/jax/interpreters/, line 1188, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)

  File [...]/site-packages/jax/, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer Add (in pure_fn):
  layer created in file [...]/<ipython-input-34-a2ac93ee377e>, line 54
  layer input shapes: (ShapeDtype{shape:(2, 1024, 4), dtype:float32}, ShapeDtype{shape:(4, 1024, 4), dtype:float32})

  File [...]/trax/layers/, line 743, in forward
    raw_output = self._forward_fn(inputs)

  File [...]/trax/layers/, line 784, in _forward
    return f(*xs)

  File [...]/trax/layers/, line 843, in <lambda>
    return Fn('Add', lambda x0, x1: x0 + x1)

  File [...]/site-packages/jax/, line 505, in __add__
    def __add__(self, other): return self.aval._add(self, other)

  File [...]/_src/numpy/, line 5666, in deferring_binary_op
    return binary_op(self, other)

  File [...]/_src/numpy/, line 427, in fn
    return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)

  File [...]/_src/lax/, line 340, in add
    return add_p.bind(x, y)

  File [...]/site-packages/jax/, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/, line 1049, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)

  File [...]/_src/lax/, line 2115, in standard_abstract_eval
    return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

  File [...]/_src/lax/, line 2211, in _broadcasting_shape_rule
    raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))

TypeError: add got incompatible shapes for broadcasting: (2, 1024, 4), (4, 1024, 4).

My code has passed all previous unit tests.

Could someone help me? Thanks!
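For anyone hitting the same trace: the final TypeError is the key clue. A residual Add layer is receiving two tensors whose leading (batch) dimensions disagree (2 vs. 4), which in this assignment usually means the attention heads were split or merged back incorrectly. The failure reproduces with plain NumPy (smaller shapes used for brevity; NumPy raises ValueError where JAX raises TypeError):

```python
import numpy as np

# Tensors shaped like the ones in the trace, scaled down:
# the residual branch has batch 2, the attention output batch 4.
residual = np.ones((2, 8, 4))
attn_out = np.ones((4, 8, 4))

try:
    residual + attn_out  # the same elementwise add the Add layer performs
except ValueError as e:
    print(e)  # operands could not be broadcast together
```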



Hi @sanderling,

Can you share your lab ID with me? In the assignment, when you click the top-right “Help” button, a panel will open and your lab ID will be shown at the bottom.

I shall take a look.

When you reply back, kindly tag me in the post so that I’m notified.


Did you get an answer? I have a similar problem.

I got a similar problem. My mistake was that I was calling compute_attention_heads_closure itself instead of (Code snippet removed) inside tl.Fn(‘AttnOutput’, …) in # UNQ_C5.
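For anyone unsure what that mistake looks like: the closure function is a factory that returns the function tl.Fn should wrap, so passing the factory itself hands Trax a callable with the wrong signature. A minimal sketch of the pattern (names are illustrative, not the assignment’s actual code):

```python
def make_adder_closure(offset):
    """Factory: returns the closure that should actually be used."""
    def adder(x):
        return x + offset
    return adder

# Wrong: this is the factory itself -- it expects `offset`, not data.
wrong = make_adder_closure

# Right: call the factory once; the returned closure expects data.
right = make_adder_closure(10)

print(right(5))  # 15
```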

Same problem here. I got everything correct.
My lab ID is filjrqwr

I discovered that I didn’t use (Code snippet removed) (calculated with //) in the causal attention, but put d_feature/n_heads directly into the call.

Now it works.


After a long time of checking, what fixed mine was this line in # UNQ_C1: inside the jnp.where, instead of np.full_like(), it should be (Code snippet removed). (I can’t tell now whether this was an error in the provided code or whether I accidentally deleted the ‘j’ in the given code.)
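For context, that line builds the causal mask inside dot-product attention: positions that must not be attended to are filled with a large negative number before the softmax. A sketch of the pattern with NumPy standing in for jnp (inside traced Trax/JAX code you must use jnp so JAX can trace the op, which is exactly why the stray np broke things):

```python
import numpy as np

dots = np.zeros((4, 4))                      # attention scores (illustrative)
mask = np.tril(np.ones((4, 4), dtype=bool))  # causal: attend only to the past

# Keep scores where the mask is True; push the rest toward -infinity
# so the subsequent softmax assigns them ~zero probability.
masked = np.where(mask, dots, np.full_like(dots, -1e9))
print(masked[0])  # row 0 keeps only position 0; the rest are -1e9
```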

I had the same bug… Check all your Trax call definitions. My problem was calling:
tl.Dropout(dropout, mode) in several places, but mode is the third positional argument, so updating to (Code snippet removed) in those places fixed it for me.
Good luck!
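The pitfall is that tl.Dropout’s signature is (approximately) Dropout(rate=0.0, shared_axes=None, mode='train'), so the positional call tl.Dropout(dropout, mode) silently lands mode in the shared_axes slot. A mock of that signature shows the mix-up (the mock is illustrative, not Trax’s actual code):

```python
def mock_dropout(rate=0.0, shared_axes=None, mode='train'):
    """Mimics the argument order described above (illustrative only)."""
    return {'rate': rate, 'shared_axes': shared_axes, 'mode': mode}

wrong = mock_dropout(0.1, 'train')            # 'train' lands in shared_axes!
right = mock_dropout(rate=0.1, mode='train')  # keywords make it unambiguous

print(wrong['shared_axes'])  # 'train' -- the silent bug
print(wrong['mode'])         # 'train' only because that is the default
print(right['shared_axes'])  # None, as intended
```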

Same here. I hadn’t used the precalculated d_feature // n_heads when computing attention in both places in UNQ_C5. Looking at the parameters that both compute_attention_X functions take helped me figure it out.

Turns out I made the same mistake. It would be helpful if the course’s unit tests could be written to explicitly catch these sorts of issues earlier in the assignment.

I tried all of your solutions and it still does not work. Can anyone help?


Hi oscar-defelice,

If you send me your notebook as an attachment in a PM I can have a look.

Same as oscar: I have checked the previously mentioned errors, but I was already doing things the correct way (I think), and I cannot find where the error lies, as I passed all the previous tests. (Moreover, the grader thinks the code is good, but it won’t run.)


Hi Kezrael,

If you send me your notebook as an attachment in a PM I can have a look.

Hi Kezrael,

There’s a problem with your TransformerLM function. You cannot pass a tl.Relu layer as an argument to a tl.Dense layer.

The “Note: activation already set by ff_activation” comment may have confused you. From what I gather from the logic, no activation is needed after the final dense layer as its output passes directly into a log softmax layer.
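To see why a ReLU there is harmful rather than merely redundant: the final Dense layer produces logits, and log-softmax consumes them directly; clamping negative logits to zero first changes the predicted distribution. A NumPy illustration (my own sketch, not the assignment code):

```python
import numpy as np

def log_softmax(z):
    z = z - z.max()                      # shift for numerical stability
    return z - np.log(np.exp(z).sum())

logits = np.array([2.0, -1.0, 0.5])      # raw outputs of a final Dense layer
good = log_softmax(logits)               # what LogSoftmax should receive
bad = log_softmax(np.maximum(logits, 0)) # an extra ReLU zeroes the negatives

print(np.exp(good))  # three distinct class probabilities
print(np.exp(bad))   # the negative logit's information is destroyed
```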


Thanks! Yes, that statement confused me and I thought I needed to specify activation as such.

I’m also having this issue.
