Running into error with C4_W2_Assignment for Attention Models

In exercise 5, I get the error below when running the unit test for the training loop. I also executed some of the subsequent steps, where a pre-trained model is loaded, and got the same error there too. Please help.


LayerError Traceback (most recent call last)
in
1 # UNIT TEST
2 # test training_loop
----> 3 w2_tests.test_training_loop(training_loop, TransformerLM)

~/work/w2_tests.py in test_training_loop(target, TransformerLM)
813 os.remove("~/model/model.pkl.gz")
814
--> 815 output_loop = target(TransformerLM, my_gen(), my_gen())
816
817 try:

in training_loop(TransformerLM, train_gen, eval_gen, output_dir)
41 train_task,
42 eval_tasks=[eval_task],
---> 43 output_dir=output_dir)
44
45 return loop

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
246 if not use_memory_efficient_trainer:
247 if _is_uninitialized(self._model):
--> 248 self._model.init(self._batch_signature)
249 self._eval_model.rng = self.new_rng()
250 if _is_uninitialized(self._eval_model):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
309 name, trace = self._name, _short_traceback(skip=3)
310 raise LayerError(name, 'init', self._caller,
--> 311 input_signature, trace) from None
312
313 def init_from_file(self, file_name, weights_only=False, input_signature=None):

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 63
layer input shapes: (ShapeDtype{shape:(1, 1250), dtype:int64}, ShapeDtype{shape:(1, 1250), dtype:int64}, ShapeDtype{shape:(1, 1250), dtype:int64})

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 54
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Branch (in init):
layer created in file […]/, line 54
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Parallel (in init):
layer created in file […]/, line 54
layer input shapes: (ShapeDtype{shape:(1, 1250, 4), dtype:float32}, ShapeDtype{shape:(1, 1250, 4), dtype:float32})

File […]/trax/layers/combinators.py, line 226, in init_weights_and_state
in zip(self.sublayers, sublayer_signatures)]

File […]/trax/layers/combinators.py, line 225, in
for layer, signature

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 54
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 49
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 106, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer AttnOutput (in _forward_abstract):
layer created in file […]/, line 48
layer input shapes: ShapeDtype{shape:(2, 1250, 2), dtype:float32}

File […]/jax/interpreters/partial_eval.py, line 419, in abstract_eval_fun
lu.wrap_init(fun, params), avals, debug_info)

File […]/jax/interpreters/partial_eval.py, line 1510, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

File […]/jax/interpreters/partial_eval.py, line 1520, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)

File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer AttnOutput (in pure_fn):
layer created in file […]/, line 48
layer input shapes: ShapeDtype{shape:(2, 1250, 2), dtype:float32}

File […]/trax/layers/base.py, line 743, in forward
raw_output = self._forward_fn(inputs)

File […]/trax/layers/base.py, line 784, in _forward
return f(*xs)

File […]/, line 26, in compute_attention_output
x = jnp.reshape(x, (x.shape[-1], n_heads, seqlen, d_head))

File […]/_src/numpy/lax_numpy.py, line 1711, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays

File […]/_src/numpy/lax_numpy.py, line 1731, in _reshape
return lax.reshape(a, newshape, None)

File […]/_src/lax/lax.py, line 832, in reshape
dimensions=None if dimensions is None or same_dims else tuple(dimensions))

File […]/site-packages/jax/core.py, line 272, in bind
out = top_trace.process_primitive(self, tracers, params)

File […]/jax/interpreters/partial_eval.py, line 1317, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)

File […]/_src/lax/lax.py, line 2274, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

File […]/_src/lax/lax.py, line 4096, in _reshape_shape_rule
if not core.same_shape_sizes(np.shape(operand), new_sizes):

File […]/site-packages/jax/core.py, line 1412, in same_shape_sizes
return 1 == divide_shape_sizes(s1, s2)

File […]/site-packages/jax/core.py, line 1409, in divide_shape_sizes
return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):])

File […]/site-packages/jax/core.py, line 1324, in divide_shape_sizes
raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")

jax.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (2, 1250, 2) and (2, 2, 1250, 2)

Hi @Ranita_Pal

The error indicates a shape mismatch. Make sure that:

  • in UNQ_C5 you set the right dimensions, d_feature and d_head, where appropriate (see the sketch after this list);
  • the expected outputs for exercises # UNQ_C2 and # UNQ_C4 match your outputs.
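For intuition, here is a minimal sanity check for those dimensions, using made-up numbers that happen to match the shapes in the traceback (d_feature = 4, n_heads = 2). It only illustrates the usual multi-head convention d_head = d_feature // n_heads; it is not the assignment's solution:

    d_feature, n_heads = 4, 2        # hypothetical values, matching the traceback shapes
    d_head = d_feature // n_heads    # usual multi-head convention: 4 // 2 = 2

    # The dense projections operate on d_feature (the model width), while the
    # per-head reshapes operate on d_head; the two must stay consistent.
    assert n_heads * d_head == d_feature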

Let me know if that solves your problems.

Cheers

All unit tests passed up to the one for # UNQ_C8, and my output matches the expected output. I have checked the dimensions again, but I am unable to understand what needs to be changed here.

The most probable cause for getting these wrong shapes is in UNQ_C4:

        # Reshape x using jnp.reshape() to shape (n_batch, n_heads, seqlen, d_head)
        # Use '-1' for `n_batch` in this case
        # ... your solution

Do not use x.shape[-1] here: x is a tensor of shape (n_batch * n_heads, seqlen, d_head), so x.shape[-1] is d_head, which is not what is asked of you and not the correct dimension for the reshape.
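To make the shape arithmetic concrete, here is a minimal sketch of just that reshape, with the numbers from the traceback (n_batch = 1, n_heads = 2, seqlen = 1250, d_head = 2). It only shows why x.shape[-1] blows up and why '-1' works; it is not the full compute_attention_output solution:

    import jax.numpy as jnp

    n_batch, n_heads, seqlen, d_head = 1, 2, 1250, 2

    # x arrives with the heads folded into the batch dimension:
    # shape (n_batch * n_heads, seqlen, d_head) -> here (2, 1250, 2)
    x = jnp.zeros((n_batch * n_heads, seqlen, d_head))

    # Wrong: x.shape[-1] is d_head (= 2), so the target shape (2, 2, 1250, 2)
    # asks for 10000 elements while x only has 5000 -> the
    # "Cannot divide evenly the sizes of shapes" error above.
    # jnp.reshape(x, (x.shape[-1], n_heads, seqlen, d_head))

    # With -1, reshape infers n_batch from the remaining dimensions:
    y = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
    print(y.shape)  # (1, 2, 1250, 2)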

Cheers

Thanks, I understand now, and that stage works. It would be better if the unit tests caught this at that point, instead of two steps down the line. Also, when I tried to give n_batch as x.shape[0] / n_heads, that ran into an error as well.

Yes, that is why the code comment is there: using x.shape[0] / n_heads fails some of the unit tests.
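As a side note, one plausible reason that variant breaks (my assumption, not something the grader reports): in Python 3, x.shape[0] / n_heads is true division and returns a float, and jnp.reshape only accepts integer dimensions. A tiny sketch:

    import jax.numpy as jnp

    x = jnp.zeros((2, 1250, 2))   # hypothetical (n_batch * n_heads, seqlen, d_head)
    n_heads = 2

    # x.shape[0] / n_heads == 1.0 (a float), which jnp.reshape rejects;
    # -1 (or integer division with //) keeps the dimension an integer:
    y = jnp.reshape(x, (-1, n_heads, 1250, 2))
    print(y.shape)  # (1, 2, 1250, 2)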

I agree. For now these tests are not perfect, but at least they are sometimes helpful :slight_smile:

Cheers

Thank you so much for your help!


Thank you guys! I had the same problem, and this thread helped!

It seems I had problems in all the cases mentioned:

  • in UNQ_C5 you set the right dimensions, d_feature and d_head, where appropriate;
  • make sure that the expected outputs for exercises # UNQ_C2 and # UNQ_C4 match your outputs.

Now it works fine. Thanks for this post!