Exercise 5 - UNQ_C6 error

When running the exercise 5 test:

# Testing 
model = GRULM()
batch = next(data_generator(batch_size, max_length, lines, shuffle=False))
preds = model(batch[0])
log_ppx = test_model(preds, batch[1])
print('The log perplexity and perplexity of your model are respectively', log_ppx, np.exp(log_ppx))

I receive the following error. Note that it happens before my test_model function is even called:

LayerError                                Traceback (most recent call last)
<ipython-input-26-a43abee17328> in <module>
      4 model.init_from_file('model.pkl.gz')
      5 batch = next(data_generator(batch_size, max_length, lines, shuffle=False))
----> 6 preds = model(batch[0])
      7 log_ppx = test_model(preds, batch[1])
      8 print('The log perplexity and perplexity of your model are respectively', log_ppx, np.exp(log_ppx))

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
    195       self.state = state  # Needed if the model wasn't fully initialized.
    196     state = self.state
--> 197     outputs, new_state = self.pure_fn(x, weights, state, rng)
    198     self.state = new_state
    199     return outputs

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
    604       name, trace = self._name, _short_traceback(skip=3)
    605       raise LayerError(name, 'pure_fn',
--> 606                        self._caller, signature(x), trace) from None
    608   def output_signature(self, input_signature):

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/<ipython-input-15-5d2dd83bfc8b>, line 21
  layer input shapes: ShapeDtype{shape:(32, 64), dtype:int32}

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Dense_256 (in pure_fn):
  layer created in file [...]/<ipython-input-15-5d2dd83bfc8b>, line 20
  layer input shapes: ShapeDtype{shape:(32, 64, 512), dtype:float32}

  File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
    y = forward(self, x, *args, **kwargs)

  File [...]/trax/layers/core.py, line 96, in forward
    return jnp.dot(x, w) + b  # Affine map.

  File [...]/_src/numpy/lax_numpy.py, line 4112, in dot
    return lax.dot_general(a, b, (contract_dims, batch_dims), precision)

  File [...]/_src/lax/lax.py, line 702, in dot_general

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/site-packages/jax/core.py, line 603, in process_primitive
    return primitive.impl(*tracers, **params)

  File [...]/jax/interpreters/xla.py, line 248, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)

  File [...]/jax/_src/util.py, line 186, in wrapper
    return cached(config._trace_context(), *args, **kwargs)

  File [...]/jax/_src/util.py, line 179, in cached
    return f(*args, **kwargs)

  File [...]/jax/interpreters/xla.py, line 272, in xla_primitive_callable
    aval_out = prim.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 2125, in standard_abstract_eval
    return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

  File [...]/_src/lax/lax.py, line 3391, in _dot_general_shape_rule
    raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))

TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [1024].

Can anyone please help clarify what's going on?

Hey @elisio,
Welcome to the community. I believe this error indicates a bug in your implementation of either the GRULM model or the data_generator function. Can you please confirm whether you are passing the test cases and getting the expected outputs for both of these functions? If so, please DM me your implementations of both.

For DM, click on my name and select “Message”.


Hey @elisio,
There are 2 errors in your implementation of the GRULM model. The first lies in ShiftRight. Although we do have to specify the mode, it is passed in as a function parameter, and we have to use that parameter rather than hard-coding a value, so that, depending on the scenario, we can change the argument and use the model in different modes.
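To illustrate the pattern being described, here is a plain-Python sketch. The names below are illustrative stand-ins, not the real trax API: the point is only that the constructor's own mode parameter must be forwarded to the sub-layer instead of a hard-coded string.

```python
# Stand-in for a layer whose behavior depends on `mode`.
def ShiftRight(mode='train'):
    return ('ShiftRight', mode)

def GRULM_sketch(vocab_size=256, d_model=512, n_layers=2, mode='train'):
    """Illustrative model constructor, not the real GRULM."""
    layers = [
        # Wrong: ShiftRight(mode='train')  -- the model would then be
        # stuck in training mode no matter how it is constructed.
        ShiftRight(mode=mode),  # right: forward the caller's mode
    ]
    return layers

print(GRULM_sketch(mode='eval')[0])  # ('ShiftRight', 'eval')
```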

The second error lies in the way you have implemented the stack of GRU layers. Check this thread, which is highly relevant to your discussion, and you will see how to fix the issue.
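For intuition on the traceback itself: the dot_general error says the Dense layer's weight matrix was built for a 1024-dim feature axis while the layer actually receives 512 features per token. The minimal NumPy reproduction below uses the shapes reported in the traceback; the 2 * d_model on the weight side is only one plausible way to end up with 1024 (the exact cause depends on your layer stack).

```python
import numpy as np

# Shapes taken from the traceback: (batch, seq_len, d_model) input,
# vocab_size output. The 1024 weight dimension is illustrative.
batch_size, seq_len, d_model, vocab_size = 32, 64, 512, 256
x = np.ones((batch_size, seq_len, d_model), dtype=np.float32)

# A weight matrix expecting 1024 input features cannot contract
# with a 512-dim feature axis -- the same mismatch as the error.
w_bad = np.zeros((2 * d_model, vocab_size), dtype=np.float32)
try:
    np.dot(x, w_bad)
except ValueError as e:
    print('mismatch:', e)

# When the weight's first dimension matches d_model, the affine
# map succeeds and yields per-token logits over the vocabulary.
w_good = np.zeros((d_model, vocab_size), dtype=np.float32)
print(np.dot(x, w_good).shape)  # (32, 64, 256)
```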