UNQ_C8, training_loop error

Hi, I have checked all the threads related to UNQ_C8, but my issue is a bit different.
Here is my Lab ID: noyqifnk
Below is the code for my training loop:

### START CODE HERE (REPLACE INSTANCES OF 'None' WITH YOUR CODE) ###
train_task = training.TrainTask( 
  labeled_data=train_gen, # The training generator
  loss_layer=tl.CrossEntropyLoss(), # Loss function (Don't forget to instantiate!)
  optimizer=trax.optimizers.Adam(0.01), # Optimizer (Don't forget to set LR to 0.01)
  lr_schedule=lr_schedule,
  n_steps_per_checkpoint=10 
)

eval_task = training.EvalTask( 
  labeled_data=eval_gen, # The evaluation generator
  metrics=[tl.CrossEntropyLoss(), tl.Accuracy()] # CrossEntropyLoss and Accuracy (Don't forget to instantiate both!)
)

### END CODE HERE ###
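For context, the tasks above are then passed to training.Loop at the end of training_loop (the same lines that appear in the traceback below). Roughly, as a sketch, with the model arguments left as placeholders rather than my exact values:

loop = training.Loop(
    TransformerLM(mode='train'),   # model under test (placeholder arguments)
    train_task,                    # training task defined above
    eval_tasks=[eval_task],        # evaluation task defined above
    output_dir=output_dir)         # directory where checkpoints are written

return loop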

And here is the error:

LayerError Traceback (most recent call last)
in
1 # UNIT TEST
2 # test training_loop
----> 3 w2_tests.test_training_loop(training_loop, TransformerLM)

~/work/w2_tests.py in test_training_loop(target, TransformerLM)
813 os.remove("~/model/model.pkl.gz")
814
---> 815 output_loop = target(TransformerLM, my_gen(), my_gen())
816
817 try:

in training_loop(TransformerLM, train_gen, eval_gen, output_dir)
40 train_task,
41 eval_tasks=[eval_task],
---> 42 output_dir=output_dir)
43
44 return loop

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
246 if not use_memory_efficient_trainer:
247 if _is_uninitialized(self._model):
---> 248 self._model.init(self._batch_signature)
249 self._eval_model.rng = self.new_rng()
250 if _is_uninitialized(self._eval_model):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
309 name, trace = self._name, _short_traceback(skip=3)
310 raise LayerError(name, 'init', self._caller,
---> 311 input_signature, trace) from None
312
313 def init_from_file(self, file_name, weights_only=False, input_signature=None):

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 62
layer input shapes: (ShapeDtype{shape:(1, 1250), dtype:int64}, ShapeDtype{shape:(1, 1250), dtype:int64}, ShapeDtype{shape:(1, 1250), dtype:int64})

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 54
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Branch (in init):
layer created in file […]/, line 54
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Parallel (in init):
layer created in file […]/, line 54
layer input shapes: (ShapeDtype{shape:(1, 1250, 4), dtype:float32}, ShapeDtype{shape:(1, 1250, 4), dtype:float32})

File […]/trax/layers/combinators.py, line 226, in init_weights_and_state
in zip(self.sublayers, sublayer_signatures)]

File […]/trax/layers/combinators.py, line 225, in <listcomp>
for layer, signature

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 54
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 47
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Branch (in init):
layer created in file […]/, line 39
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Parallel (in init):
layer created in file […]/, line 39
layer input shapes: (ShapeDtype{shape:(1, 1250, 4), dtype:float32}, ShapeDtype{shape:(1, 1250, 4), dtype:float32}, ShapeDtype{shape:(1, 1250, 4), dtype:float32})

File […]/trax/layers/combinators.py, line 226, in init_weights_and_state
in zip(self.sublayers, sublayer_signatures)]

File […]/trax/layers/combinators.py, line 225, in <listcomp>
for layer, signature

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 39
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 106, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer AttnHeads (in _forward_abstract):
layer created in file […]/, line 32
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/jax/interpreters/partial_eval.py, line 419, in abstract_eval_fun
lu.wrap_init(fun, params), avals, debug_info)

File […]/jax/interpreters/partial_eval.py, line 1510, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

File […]/jax/interpreters/partial_eval.py, line 1520, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)

File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer AttnHeads (in pure_fn):
layer created in file […]/, line 32
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/base.py, line 743, in forward
raw_output = self._forward_fn(inputs)

File […]/trax/layers/base.py, line 784, in _forward
return f(*xs)

File […]/, line 36, in compute_attention_heads
x = jnp.reshape(x, (-1, n_heads * batch_size,d_head))

File […]/_src/numpy/lax_numpy.py, line 1711, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays

File […]/_src/numpy/lax_numpy.py, line 1729, in _reshape
newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)

File […]/_src/numpy/lax_numpy.py, line 1725, in _compute_newshape
for d in newshape)

File […]/_src/numpy/lax_numpy.py, line 1725, in <genexpr>
for d in newshape)

File […]/site-packages/jax/core.py, line 1409, in divide_shape_sizes
return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):])

File […]/site-packages/jax/core.py, line 1323, in divide_shape_sizes
if sz1 % sz2:

ZeroDivisionError: integer division or modulo by zero

@balaji.ambresh @arvyzukai Can you please help here?

Please click my name and message your notebook as an attachment.

compute_attention_heads has two bugs:

  1. When computing batch_size, why are you performing additional calculations? Take a look at the shape of x and read the batch size directly from it.
  2. The shape specified in the last reshape operation is incorrect.

Hope this helps.
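If it helps, the usual head-splitting pattern looks roughly like this generic sketch (not taken from any particular notebook; n_heads and d_head are the values the closure receives):

def compute_attention_heads(x):
    # x arrives with shape (batch_size, seqlen, n_heads * d_head)
    batch_size = x.shape[0]   # read the batch size straight off x
    seqlen = x.shape[1]
    # split the last axis into separate heads
    x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))
    # bring the head axis next to the batch axis
    x = jnp.transpose(x, (0, 2, 1, 3))
    # merge batch and heads: final shape (batch_size * n_heads, seqlen, d_head)
    return jnp.reshape(x, (-1, seqlen, d_head))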

Got it, Thank you @balaji.ambresh :slight_smile:

@Joao_Victor_Pereira

Your implementation of CausalAttention is incorrect with respect to invoking compute_attention_output_closure. Please fix that to pass the tests.


I didn’t notice I was passing the wrong parameter for d_head to both compute_attention_output_closure and compute_attention_heads_closure. Thank you :grin:
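For anyone who lands here with the same mistake: d_head is derived from d_feature and n_heads inside CausalAttention, and the same pair of values goes to both closures. A rough sketch (assuming the closures take (n_heads, d_head); the tl.Fn layer names are only illustrative):

d_head = d_feature // n_heads   # derived from the model width, not a free parameter

attention_heads = tl.Fn(
    'AttnHeads', compute_attention_heads_closure(n_heads, d_head), n_out=1)
attention_output = tl.Fn(
    'AttnOutput', compute_attention_output_closure(n_heads, d_head), n_out=1)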