I passed the previous 4 tests, but I'm getting an error in Exercise 5. It seems to be related to the model from Exercise 4: there is a shape mismatch that I haven't been able to figure out. Help, please!
LayerError Traceback (most recent call last)
<ipython-input> in <module>
4 model.init_from_file('model.pkl.gz')
5 batch = next(data_generator(batch_size, max_length, lines, shuffle=False))
----> 6 preds = model(batch[0])
7 log_ppx = test_model(preds, batch[1])
8 print('The log perplexity and perplexity of your model are respectively', log_ppx, np.exp(log_ppx))
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
195 self.state = state # Needed if the model wasn't fully initialized.
196 state = self.state
--> 197 outputs, new_state = self.pure_fn(x, weights, state, rng)
198 self.state = new_state
199 return outputs
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
604 name, trace = self._name, _short_traceback(skip=3)
605 raise LayerError(name, 'pure_fn',
--> 606 self._caller, signature(x), trace) from None
607
608 def output_signature(self, input_signature):
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file […]/, line 21
layer input shapes: ShapeDtype{shape:(32, 64), dtype:int32}
File […]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Dense_256 (in pure_fn):
layer created in file […]/, line 20
layer input shapes: ShapeDtype{shape:(32, 64, 512), dtype:float32}
File […]/trax/layers/assert_shape.py, line 122, in forward_wrapper
y = forward(self, x, *args, **kwargs)
File […]/trax/layers/core.py, line 96, in forward
return jnp.dot(x, w) + b # Affine map.
File […]/_src/numpy/lax_numpy.py, line 4112, in dot
return lax.dot_general(a, b, (contract_dims, batch_dims), precision)
File […]/_src/lax/lax.py, line 702, in dot_general
preferred_element_type=preferred_element_type)
File […]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File […]/site-packages/jax/core.py, line 603, in process_primitive
return primitive.impl(*tracers, **params)
File […]/jax/interpreters/xla.py, line 248, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
File […]/jax/_src/util.py, line 186, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File […]/jax/_src/util.py, line 179, in cached
return f(*args, **kwargs)
File […]/jax/interpreters/xla.py, line 272, in xla_primitive_callable
aval_out = prim.abstract_eval(*avals, **params)
File […]/_src/lax/lax.py, line 2125, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
File […]/_src/lax/lax.py, line 3391, in _dot_general_shape_rule
raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [1024].
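For context on what the final `TypeError` means: trax's `Dense` layer computes `jnp.dot(x, w) + b`, which contracts the last axis of the activations with the first axis of the weight matrix. The traceback shows the `Dense_256` layer receiving activations of shape `(32, 64, 512)` while the weights loaded from `model.pkl.gz` expect a 1024-dimensional input, so the two contracting dimensions (512 vs. 1024) don't match. The error can be reproduced outside trax with a minimal sketch (the weight shapes `(1024, 256)` and `(512, 256)` are assumptions inferred from the `[1024]` contracting dimension and the `Dense_256` layer name):

```python
import jax.numpy as jnp

x = jnp.zeros((32, 64, 512))    # layer input shape reported in the traceback
w_bad = jnp.zeros((1024, 256))  # assumed loaded Dense weight: contracting dim 1024
w_ok = jnp.zeros((512, 256))    # weight whose first axis matches the 512-dim activations

try:
    jnp.dot(x, w_bad)           # same affine map trax's Dense performs
except TypeError as e:
    print(e)                    # dot_general requires contracting dimensions to have the same shape ...

print(jnp.dot(x, w_ok).shape)   # contracting dims match: result is (32, 64, 256)
```

Since the weights come from `init_from_file`, a mismatch like this usually means the architecture built in Exercise 4 produces a different feature size than the one the saved checkpoint was trained with; comparing the `d_model`/embedding dimension used in the Exercise 4 model against the loaded weight shapes is a reasonable place to start.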