When I run the test cell for Exercise 5:
# UNQ_C6 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# Testing
# Evaluate the pretrained GRU language model on a single batch and print its
# log perplexity and perplexity.
# NOTE(review): GRULM() is built with its default hyperparameters; presumably
# the resulting layer stack must match the architecture the weights in
# 'model.pkl.gz' were saved from, otherwise init_from_file can hand weights to
# the wrong layers (e.g. a 1024-input weight ending up in a 512-input Dense) —
# confirm against the GRULM definition.
model = GRULM()
model.init_from_file('model.pkl.gz')
# Take the first batch from the generator; batch_size, max_length and lines
# are assumed to be defined in earlier notebook cells. shuffle=False presumably
# makes this batch reproducible — verify against data_generator.
batch = next(data_generator(batch_size, max_length, lines, shuffle=False))
# Forward pass on batch[0]; batch[1] is passed to test_model as the targets
# (assumed to be the token ids the model should predict — TODO confirm).
preds = model(batch[0])
log_ppx = test_model(preds, batch[1])
# Perplexity is reported as exp(log perplexity).
print('The log perplexity and perplexity of your model are respectively', log_ppx, np.exp(log_ppx))
I get the following error. Note that it is raised before my test_model function is called: it happens during the model's forward pass, at `preds = model(batch[0])`:
---------------------------------------------------------------------------
LayerError Traceback (most recent call last)
<ipython-input-26-a43abee17328> in <module>
4 model.init_from_file('model.pkl.gz')
5 batch = next(data_generator(batch_size, max_length, lines, shuffle=False))
----> 6 preds = model(batch[0])
7 log_ppx = test_model(preds, batch[1])
8 print('The log perplexity and perplexity of your model are respectively', log_ppx, np.exp(log_ppx))
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
195 self.state = state # Needed if the model wasn't fully initialized.
196 state = self.state
--> 197 outputs, new_state = self.pure_fn(x, weights, state, rng)
198 self.state = new_state
199 return outputs
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
604 name, trace = self._name, _short_traceback(skip=3)
605 raise LayerError(name, 'pure_fn',
--> 606 self._caller, signature(x), trace) from None
607
608 def output_signature(self, input_signature):
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/<ipython-input-15-5d2dd83bfc8b>, line 21
layer input shapes: ShapeDtype{shape:(32, 64), dtype:int32}
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Dense_256 (in pure_fn):
layer created in file [...]/<ipython-input-15-5d2dd83bfc8b>, line 20
layer input shapes: ShapeDtype{shape:(32, 64, 512), dtype:float32}
File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
y = forward(self, x, *args, **kwargs)
File [...]/trax/layers/core.py, line 96, in forward
return jnp.dot(x, w) + b # Affine map.
File [...]/_src/numpy/lax_numpy.py, line 4112, in dot
return lax.dot_general(a, b, (contract_dims, batch_dims), precision)
File [...]/_src/lax/lax.py, line 702, in dot_general
preferred_element_type=preferred_element_type)
File [...]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/site-packages/jax/core.py, line 603, in process_primitive
return primitive.impl(*tracers, **params)
File [...]/jax/interpreters/xla.py, line 248, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
File [...]/jax/_src/util.py, line 186, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File [...]/jax/_src/util.py, line 179, in cached
return f(*args, **kwargs)
File [...]/jax/interpreters/xla.py, line 272, in xla_primitive_callable
aval_out = prim.abstract_eval(*avals, **params)
File [...]/_src/lax/lax.py, line 2125, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
File [...]/_src/lax/lax.py, line 3391, in _dot_general_shape_rule
raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [1024].
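From the traceback, the final Dense_256 layer receives activations whose last dimension is 512, but the weight it is multiplied by expects an input dimension of 1024. Can anyone please help clarify what’s going on?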