C4W1 Error when loading pre-trained model

Hello!

I’m currently doing the Week 1 assignment of Course 4, and the following error occurs when I run the cell that loads the pre-trained model. I passed all the tests and managed to train my own model with no problems. Has anyone else encountered this issue? Thank you!


LayerError Traceback (most recent call last)
in
3
4 # initialize weights from a pre-trained model
----> 5 model.init_from_file("model.pkl.gz", weights_only=True)
6 model = tl.Accelerate(model)

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init_from_file(self, file_name, weights_only, input_signature)
347 input_signature = dictionary['input_signature']
348 if weights_only and input_signature is not None:
--> 349 self.init(input_signature)
350 weights_and_state_sig = self.weights_and_state_signature(input_signature)
351 weights, state = unflatten_weights_and_state(

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
309 name, trace = self._name, _short_traceback(skip=3)
310 raise LayerError(name, ‘init’, self._caller,
--> 311 input_signature, trace) from None
312
313 def init_from_file(self, file_name, weights_only=False, input_signature=None):

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 64
layer input shapes: (ShapeDtype{shape:(1, 1), dtype:<class 'numpy.int32'>}, ShapeDtype{shape:(1, 1), dtype:<class 'numpy.int32'>}, ShapeDtype{shape:(1, 1), dtype:<class 'numpy.int32'>})

File […]/trax/layers/combinators.py, line 106, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer PrepareAttentionInput (in _forward_abstract):
layer created in file […]/, line 48
layer input shapes: (ShapeDtype{shape:(1, 1, 1024), dtype:float32}, ShapeDtype{shape:(1, 1, 1024), dtype:float32}, ShapeDtype{shape:(1, 1), dtype:int32})

File […]/jax/interpreters/partial_eval.py, line 419, in abstract_eval_fun
lu.wrap_init(fun, params), avals, debug_info)

File […]/jax/interpreters/partial_eval.py, line 1510, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

File […]/jax/interpreters/partial_eval.py, line 1520, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)

File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer PrepareAttentionInput (in pure_fn):
layer created in file […]/, line 48
layer input shapes: (ShapeDtype{shape:(1, 1, 1024), dtype:float32}, ShapeDtype{shape:(1, 1, 1024), dtype:float32}, ShapeDtype{shape:(1, 1), dtype:int32})

File […]/trax/layers/base.py, line 743, in forward
raw_output = self._forward_fn(inputs)

File […]/trax/layers/base.py, line 784, in _forward
return f(*xs)

File […]/, line 31, in prepare_attention_input
mask = fastnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))

IndexError: tuple index out of range

Hi ericakido,

It looks like there’s a problem with your implementation of NMTAttn, as the weights do not seem to fit the model. Have you been able to resolve this yet? If not, feel free to send me your notebook as an attachment in a direct message.
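
For what it’s worth, the final IndexError in the traceback (“tuple index out of range” on mask.shape[1]) usually means the mask reaching prepare_attention_input has fewer dimensions than the reshape expects, which points back at how the mask is built in your NMTAttn. A minimal sketch of the failure, using plain NumPy in place of trax’s fastnp and purely illustrative shapes (not the assignment’s actual values):

import numpy as np

# Expected: a 2-D mask of shape (batch_size, padded_input_length).
mask = np.ones((1, 1))
out = np.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))
print(out.shape)  # (1, 1, 1, 1) -- broadcastable over heads and queries

# If an earlier step collapses the mask to 1-D (or a scalar),
# mask.shape[1] does not exist, so indexing the shape tuple fails
# before reshape is even called:
bad_mask = np.ones((1,))
try:
    np.reshape(bad_mask, (bad_mask.shape[0], 1, 1, bad_mask.shape[1]))
except IndexError as err:
    print(err)  # tuple index out of range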

Hi, reinoudbosch! Thank you for the reply.

I managed to solve the issue. I don’t quite recall what the problem was, but it was a very silly mistake on my part xD Thank you!