I can't find the problem in UNQ_C5 & UNQ_C8

Hi~
I am stuck on two problems and would appreciate any help you can offer. :grinning:

[CODE REMOVED BY MODERATOR]


**Problem 2**
When I run the unit test against my UNQ_C8 code, the following error occurs:
TypeError: reshape total size must be unchanged, got new_sizes (1, 1250, 2, 2) for shape (1, 1250, 512).
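As far as I can tell, the error itself just says the target shape holds a different number of elements than the input. Here is a minimal reproduction in plain JAX, separate from the assignment code:

```python
import jax.numpy as jnp

x = jnp.zeros((1, 1250, 512))     # 1 * 1250 * 512 = 640,000 elements

jnp.reshape(x, (1, 1250, 8, 64))  # OK: 8 * 64 = 512, total size unchanged
jnp.reshape(x, (1, 1250, 2, 2))   # TypeError: target holds only 5,000 elements
```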

Full error message:

LayerError Traceback (most recent call last)
<ipython-input> in <module>
1 # UNIT TEST
2 # test training_loop
----> 3 w2_tests.test_training_loop(training_loop, TransformerLM)

~/work/w2_tests.py in test_training_loop(target, TransformerLM)
813 os.remove("~/model/model.pkl.gz")
814
--> 815 output_loop = target(TransformerLM, my_gen(), my_gen())
816
817 try:

<ipython-input> in training_loop(TransformerLM, train_gen, eval_gen, output_dir)
41 train_task,
42 eval_tasks=[eval_task],
--> 43 output_dir=output_dir)
44
45 return loop

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
246 if not use_memory_efficient_trainer:
247 if _is_uninitialized(self._model):
--> 248 self._model.init(self._batch_signature)
249 self._eval_model.rng = self.new_rng()
250 if _is_uninitialized(self._eval_model):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
309 name, trace = self._name, _short_traceback(skip=3)
310 raise LayerError(name, 'init', self._caller,
--> 311 input_signature, trace) from None
312
313 def init_from_file(self, file_name, weights_only=False, input_signature=None):

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 63
layer input shapes: (ShapeDtype{shape:(1, 1250), dtype:int64}, ShapeDtype{shape:(1, 1250), dtype:int64}, ShapeDtype{shape:(1, 1250), dtype:int64})

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 54
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Branch (in init):
layer created in file […]/, line 54
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Parallel (in init):
layer created in file […]/, line 54
layer input shapes: (ShapeDtype{shape:(1, 1250, 4), dtype:float32}, ShapeDtype{shape:(1, 1250, 4), dtype:float32})

File […]/trax/layers/combinators.py, line 226, in init_weights_and_state
in zip(self.sublayers, sublayer_signatures)]

File […]/trax/layers/combinators.py, line 225, in <listcomp>
for layer, signature

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 54
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 49
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Branch (in init):
layer created in file […]/, line 40
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Parallel (in init):
layer created in file […]/, line 40
layer input shapes: (ShapeDtype{shape:(1, 1250, 4), dtype:float32}, ShapeDtype{shape:(1, 1250, 4), dtype:float32}, ShapeDtype{shape:(1, 1250, 4), dtype:float32})

File […]/trax/layers/combinators.py, line 226, in init_weights_and_state
in zip(self.sublayers, sublayer_signatures)]

File […]/trax/layers/combinators.py, line 225, in <listcomp>
for layer, signature

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 40
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

File […]/trax/layers/combinators.py, line 106, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer AttnHeads (in _forward_abstract):
layer created in file […]/, line 33
layer input shapes: ShapeDtype{shape:(1, 1250, 512), dtype:float32}

File […]/jax/interpreters/partial_eval.py, line 419, in abstract_eval_fun
lu.wrap_init(fun, params), avals, debug_info)

File […]/jax/interpreters/partial_eval.py, line 1510, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

File […]/jax/interpreters/partial_eval.py, line 1520, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)

File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer AttnHeads (in pure_fn):
layer created in file […]/, line 33
layer input shapes: ShapeDtype{shape:(1, 1250, 512), dtype:float32}

File […]/trax/layers/base.py, line 743, in forward
raw_output = self._forward_fn(inputs)

File […]/trax/layers/base.py, line 784, in _forward
return f(*xs)

File […]/, line 29, in compute_attention_heads
x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))

File […]/_src/numpy/lax_numpy.py, line 1711, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays

File […]/_src/numpy/lax_numpy.py, line 1731, in _reshape
return lax.reshape(a, newshape, None)

File […]/_src/lax/lax.py, line 832, in reshape
dimensions=None if dimensions is None or same_dims else tuple(dimensions))

File […]/site-packages/jax/core.py, line 272, in bind
out = top_trace.process_primitive(self, tracers, params)

File […]/jax/interpreters/partial_eval.py, line 1317, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)

File […]/_src/lax/lax.py, line 2274, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

File […]/_src/lax/lax.py, line 4098, in _reshape_shape_rule
raise TypeError(msg.format(new_sizes, np.shape(operand)))

TypeError: reshape total size must be unchanged, got new_sizes (1, 1250, 2, 2) for shape (1, 1250, 512).

Hello, please don't post code solutions here; it's against the code of conduct. It looks like you have hardcoded d_feature in your solution: the reshape target (1, 1250, 2, 2) holds only 1 × 1250 × 2 × 2 = 5,000 elements, while the input (1, 1250, 512) holds 640,000, so the last two dimensions have to multiply back to 512 (n_heads × d_head = d_feature).
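In broad strokes, and deliberately not the assignment code (`split_heads` is just an illustrative name), the head-splitting reshape should derive d_head from the input's feature dimension rather than from a hardcoded constant:

```python
import jax.numpy as jnp

def split_heads(x, n_heads):
    # x: (batch, seqlen, d_feature); assumes d_feature % n_heads == 0.
    batch_size, seqlen, d_feature = x.shape
    d_head = d_feature // n_heads  # derived from the input, not hardcoded
    return jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))

# (1, 1250, 512) with 8 heads -> (1, 1250, 8, 64); total size is unchanged.
print(split_heads(jnp.zeros((1, 1250, 512)), 8).shape)
```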

I hadn't thought about that :hushed:
Thank you for letting me know!