Hello!
I'm currently doing the Week 1 assignment of Course 4, and the following error occurs when I run the cell that loads the pre-trained model. I passed all the unit tests and managed to train my own model with no problems. Has anyone else encountered this issue? Thank you!
LayerError Traceback (most recent call last)
in
3
4 # initialize weights from a pre-trained model
----> 5 model.init_from_file("model.pkl.gz", weights_only=True)
6 model = tl.Accelerate(model)
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init_from_file(self, file_name, weights_only, input_signature)
347 input_signature = dictionary['input_signature']
348 if weights_only and input_signature is not None:
--> 349 self.init(input_signature)
350 weights_and_state_sig = self.weights_and_state_signature(input_signature)
351 weights, state = unflatten_weights_and_state(
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
309 name, trace = self._name, _short_traceback(skip=3)
310 raise LayerError(name, 'init', self._caller,
--> 311 input_signature, trace) from None
312
313 def init_from_file(self, file_name, weights_only=False, input_signature=None):
LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 64
layer input shapes: (ShapeDtype{shape:(1, 1), dtype:<class 'numpy.int32'>}, ShapeDtype{shape:(1, 1), dtype:<class 'numpy.int32'>}, ShapeDtype{shape:(1, 1), dtype:<class 'numpy.int32'>})
File […]/trax/layers/combinators.py, line 106, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
LayerError: Exception passing through layer PrepareAttentionInput (in _forward_abstract):
layer created in file […]/, line 48
layer input shapes: (ShapeDtype{shape:(1, 1, 1024), dtype:float32}, ShapeDtype{shape:(1, 1, 1024), dtype:float32}, ShapeDtype{shape:(1, 1), dtype:int32})
File […]/jax/interpreters/partial_eval.py, line 419, in abstract_eval_fun
lu.wrap_init(fun, params), avals, debug_info)
File […]/jax/interpreters/partial_eval.py, line 1510, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File […]/jax/interpreters/partial_eval.py, line 1520, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
LayerError: Exception passing through layer PrepareAttentionInput (in pure_fn):
layer created in file […]/, line 48
layer input shapes: (ShapeDtype{shape:(1, 1, 1024), dtype:float32}, ShapeDtype{shape:(1, 1, 1024), dtype:float32}, ShapeDtype{shape:(1, 1), dtype:int32})
File […]/trax/layers/base.py, line 743, in forward
raw_output = self._forward_fn(inputs)
File […]/trax/layers/base.py, line 784, in _forward
return f(*xs)
File […]/, line 31, in prepare_attention_input
mask = fastnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))
IndexError: tuple index out of range
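For context, the IndexError on the `mask.shape[...]` lookup usually means `mask` has fewer dimensions than the reshape expects at that point (it should be 2-D, batch by padded length). Below is a minimal NumPy sketch, not the assignment's actual code and with made-up values, just to show how a 1-D mask reproduces exactly this error:

```python
import numpy as np

# Hypothetical illustration: the reshape assumes mask is 2-D, (batch, padded_length),
# e.g. derived from the padded input token ids.
inputs = np.array([[5, 8, 3, 0, 0]])        # shape (1, 5); 0 marks padding
mask_2d = (inputs != 0)                     # shape (1, 5): reshape works
print(np.reshape(mask_2d, (mask_2d.shape[0], 1, 1, mask_2d.shape[1])).shape)  # (1, 1, 1, 5)

# If mask is accidentally built from a single sequence instead of the batch,
# it is 1-D, so mask.shape[1] does not exist and raises the same IndexError.
mask_1d = (inputs[0] != 0)                  # shape (5,)
try:
    np.reshape(mask_1d, (mask_1d.shape[0], 1, 1, mask_1d.shape[1]))
except IndexError as e:
    print("IndexError:", e)                 # tuple index out of range
```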