Training Loop (Again!) in UNQ_C8

I know that numerous others have gotten stuck on UNQ_C8 with a cryptic error about incompatible tensor sizes, but since none of the existing open topics resolves the issue, I am opening it again in the hope that someone can provide a solution.

In a nutshell, I am passing all the tests up to UNQ_C8, but running into the following error in the training loop.

Please advise. Thanks in advance.


LayerError Traceback (most recent call last)
in <module>
1 # UNIT TEST
2 # test training_loop
----> 3 w2_tests.test_training_loop(training_loop, TransformerLM)

~/work/w2_tests.py in test_training_loop(target, TransformerLM)
813 os.remove("~/model/model.pkl.gz")
814
--> 815 output_loop = target(TransformerLM, my_gen(), my_gen())
816
817 try:

in training_loop(TransformerLM, train_gen, eval_gen, output_dir)
40 train_task,
41 eval_tasks=[eval_task],
--> 42 output_dir=output_dir)
43
44 return loop

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
246 if not use_memory_efficient_trainer:
247 if _is_uninitialized(self._model):
--> 248 self._model.init(self._batch_signature)
249 self._eval_model.rng = self.new_rng()
250 if _is_uninitialized(self._eval_model):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
309 name, trace = self._name, _short_traceback(skip=3)
310 raise LayerError(name, 'init', self._caller,
--> 311 input_signature, trace) from None
312
313 def init_from_file(self, file_name, weights_only=False, input_signature=None):

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 63
layer input shapes: (ShapeDtype{shape:(1, 1250), dtype:int64}, ShapeDtype{shape:(1, 1250), dtype:int64}, ShapeDtype{shape:(1, 1250), dtype:int64})

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 54
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Branch (in init):
layer created in file […]/, line 54
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Parallel (in init):
layer created in file […]/, line 54
layer input shapes: (ShapeDtype{shape:(1, 1250, 4), dtype:float32}, ShapeDtype{shape:(1, 1250, 4), dtype:float32})

File […]/trax/layers/combinators.py, line 226, in init_weights_and_state
in zip(self.sublayers, sublayer_signatures)]

File […]/trax/layers/combinators.py, line 225, in <listcomp>
for layer, signature

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 54
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 47
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 106, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer DotProductAttn (in _forward_abstract):
layer created in file […]/, line 42
layer input shapes: (ShapeDtype{shape:(2, 1250, 2), dtype:float32}, ShapeDtype{shape:(2, 1250, 2), dtype:float32}, ShapeDtype{shape:(2, 1250, 2), dtype:float32})

File […]/jax/interpreters/partial_eval.py, line 419, in abstract_eval_fun
lu.wrap_init(fun, params), avals, debug_info)

File […]/jax/interpreters/partial_eval.py, line 1510, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

File […]/jax/interpreters/partial_eval.py, line 1520, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)

File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer DotProductAttn (in pure_fn):
layer created in file […]/, line 42
layer input shapes: (ShapeDtype{shape:(2, 1250, 2), dtype:float32}, ShapeDtype{shape:(2, 1250, 2), dtype:float32}, ShapeDtype{shape:(2, 1250, 2), dtype:float32})

File […]/trax/layers/base.py, line 743, in forward
raw_output = self._forward_fn(inputs)

File […]/trax/layers/base.py, line 784, in _forward
return f(*xs)

File […]/, line 26, in dot_product_self_attention
return DotProductAttention(q, k, v, mask)

File […]/, line 43, in DotProductAttention
dots = jnp.exp(dots-logsumexp.T)

File […]/site-packages/jax/core.py, line 518, in __sub__
def __sub__(self, other): return self.aval._sub(self, other)

File […]/_src/numpy/lax_numpy.py, line 6585, in deferring_binary_op
return binary_op(self, other)

File […]/_src/numpy/lax_numpy.py, line 679, in <lambda>
fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))

File […]/_src/numpy/lax_numpy.py, line 581, in _promote_args
return _promote_shapes(fun_name, *_promote_dtypes(*args))

File […]/_src/numpy/lax_numpy.py, line 500, in _promote_shapes
result_rank = len(lax.broadcast_shapes(*shapes))

ValueError: Incompatible shapes for broadcasting: ((2, 1250, 1250), (1, 1250, 2))
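
For reference, here is a minimal reproduction of that final broadcast failure outside of Trax, using only the shapes reported in the traceback (the variable names below are mine, not the notebook's):

```python
import jax.numpy as jnp

# Shapes taken from the traceback; the contents are dummy values.
dots = jnp.zeros((2, 1250, 1250))   # attention scores: (batch*heads, q_len, k_len)
lse = jnp.zeros((2, 1250, 1))       # log-sum-exp over the last axis (keepdims=True)

print((dots - lse).shape)           # (2, 1250, 1250): broadcasts fine

# Transposing flips (2, 1250, 1) into (1, 1250, 2), which cannot broadcast
# against (2, 1250, 1250) -- the exact ValueError above.
try:
    dots - lse.T
except ValueError as e:
    print(e)
```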


Hi @Davit_Khachatryan

Most of the previous issues were solved by using local variables instead of global ones. Make sure that you do the same. If the problem is still there, you can Private Message me your notebook and I will take a look.

Cheers

@arvyzukai ,

Thank you for the feedback. I am a little confused by your note about global vs. local variables, so I would be grateful if you could take a look at my notebook. The ID is wkydroui. I look forward to hearing back from you, and thanks again!

Hi @Davit_Khachatryan

A quick Google search turned up this simple explanation. I am sure there are better examples, but you could start there.

In plain words: do not unnecessarily use a variable that is defined outside the function. For example, if you have:

def Siamese(vocab_size=41699, d_model=128, mode='train'):
    ...

do not use the global len(Vocab) inside it; use the vocab_size parameter instead.
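
As a rough sketch of the pitfall (the names below are illustrative, not the assignment's exact code), relying on a notebook-level global breaks a unit test that passes its own arguments:

```python
import trax.layers as tl

Vocab = {'<pad>': 0, 'hello': 1, 'world': 2}   # a global defined earlier in the notebook

def Siamese(vocab_size=41699, d_model=128, mode='train'):
    # BAD: len(Vocab) ties the layer to the notebook's global state, so a unit
    # test that calls Siamese(vocab_size=...) with a different value silently
    # builds a wrong-sized embedding and later fails with a shape mismatch.
    # embedding = tl.Embedding(vocab_size=len(Vocab), d_feature=d_model)

    # GOOD: use the local parameter that the caller controls.
    embedding = tl.Embedding(vocab_size=vocab_size, d_feature=d_model)
    return tl.Serial(embedding)
```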

P.S. I cannot access your assignment notebook by its ID. You can private message it to me if you still have problems.

Cheers

@arvyzukai,

I know what global and local variables are, but thanks for the resources in any case. I do not use any global variables inside my functions anyway. Any further guidance would be appreciated.

P.S. Does tagging your name at the beginning of the response serve as a “private message” on this platform?

Just click on my username, then in the pop-up window click on "Message", then attach your notebook (how to download) to the message.

Similar issue; I have PM'd you :slight_smile:

I’m having the exact same issue. Can someone please help?

Hi @Rama_Mahajanam

I private messaged you suggestions for your code.

For future readers: do not import additional libraries when you are not asked to. And in general, please check these points.
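
As a generic illustration of the broadcast issue seen above (plain JAX, not the assignment's exact code): the usual log-sum-exp trick keeps the reduced axis, so the normalizer lines up with the attention scores without any transpose:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def stable_softmax_last_axis(scores):
    # keepdims=True keeps the shape (..., L, 1), which broadcasts cleanly
    # against the (..., L, L) scores when subtracted.
    lse = logsumexp(scores, axis=-1, keepdims=True)
    return jnp.exp(scores - lse)

probs = stable_softmax_last_axis(jnp.zeros((2, 1250, 1250)))
print(probs.shape)   # (2, 1250, 1250); each row sums to 1
```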

Good luck!

How do I see the personal DM?