Problem validating Exercise 6 of C3_W1 (train_model function)

Hello,

I completed my train_model function as follows:

training_loop = training.Loop(
    classifier,
    train_task,
    eval_tasks=eval_task,
    output_dir=output_dir,
    random_seed=31
)

The cell runs without errors; however, during the validation step I get the following output:

Step      1: Total number of trainable weights: 2327042
Step      1: Ran 1 train steps in 1.64 secs
Step      1: train CrossEntropyLoss |  11.09774971
Step      1: eval  CrossEntropyLoss |  11.19594669
Step      1: eval          Accuracy |  8.00000000

Step     10: Ran 9 train steps in 6.50 secs
Step     10: train CrossEntropyLoss |  11.12563992
Step     10: eval  CrossEntropyLoss |  11.10262108
Step     10: eval          Accuracy |  8.00000000
**Wrong loss function. WeightedCategoryCrossEntropy_in3 was expected. Got CrossEntropyLoss_in3.**
**Wrong metrics in evaluations task. Expected ['WeightedCategoryCrossEntropy', 'WeightedCategoryAccuracy']. Got ['CrossEntropyLoss', 'Accuracy']**.
**4 Tests passed**
**2 Tests failed**

Do you have any idea what is causing these validation errors? All previous functions of the assignment passed their unit tests successfully.
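Reading the failed checks, my guess is that the tasks are supposed to be built with the weighted layer variants; presumably something along these lines (a sketch only, not my actual code; the generators and hyperparameters are placeholders):

    import trax
    from trax import layers as tl
    from trax.supervised import training

    # Hypothetical task definitions; the key detail is the weighted
    # loss/metric layers that the grader seems to expect.
    train_task = training.TrainTask(
        labeled_data=train_generator,   # placeholder: your batch generator
        loss_layer=tl.WeightedCategoryCrossEntropy(),
        optimizer=trax.optimizers.Adam(0.01),
        n_steps_per_checkpoint=10,
    )

    eval_task = training.EvalTask(
        labeled_data=eval_generator,    # placeholder: your eval generator
        metrics=[tl.WeightedCategoryCrossEntropy(),
                 tl.WeightedCategoryAccuracy()],
    )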

(Also posted in the other thread that reports errors for this exercise.)

Hi,

I wanted to report that I got a notification saying the notebook was changed. After switching to the new version and double-checking all the implementation code for this same training step, I'm seeing a suspicious dimension mismatch. This is the error I get when trying to train the model:

TypeError: mul got incompatible shapes for broadcasting: (32, 2), (16, 2).

There are a bunch of cells we are not supposed to touch, and after double-checking all possible values I can't seem to find the issue. Could this be a problem with the latest notebook update? (From the trace, the targets have batch dimension 32 while the model output has 16, so somewhere the batch sizes stop matching.)
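Just to illustrate the failure mode, the same broadcasting error can be reproduced directly in jax.numpy with the two shapes from the trace:

    import jax.numpy as jnp

    a = jnp.ones((32, 2))   # shape of the target distributions in the trace
    b = jnp.ones((16, 2))   # shape of the model's log-probabilities
    c = a * b  # TypeError: mul got incompatible shapes for broadcasting: (32, 2), (16, 2)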

Here is the full stack trace:

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-36-9216004d763d> in <module>
      3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

<ipython-input-35-2d2291e65784> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
     24 
---> 25     training_loop.run(n_steps = n_steps)
     26     ### END CODE HERE ###

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
    434 
--> 435         loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
    436 

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
    632     (loss, stats) = trainer.one_step(
--> 633         batch, rng, step=step, learning_rate=learning_rate
    634     )

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
    147     (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 148         (weights, self._slots), step, self._opt_params, batch, state, rng)
    149 

/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    182     try:
--> 183       return fun(*args, **kwargs)
    184     except Exception as e:

/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
    426         device=device, backend=backend, name=flat_fun.__name__,
--> 427         donated_invars=donated_invars, inline=inline)
    428     out_pytree_def = out_tree()

/opt/conda/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1559   def bind(self, fun, *args, **params):
-> 1560     return call_bind(self, fun, *args, **params)
   1561 

/opt/conda/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1550   tracers = map(top_trace.full_raise, args)
-> 1551   outs = primitive.process(top_trace, fun, tracers, params)
   1552   return map(full_lower, apply_todos(env_trace_todo(), outs))

/opt/conda/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1562   def process(self, trace, fun, tracers, params):
-> 1563     return trace.process_call(self, fun, tracers, params)
   1564 

/opt/conda/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    605   def process_call(self, primitive, f, tracers, params):
--> 606     return primitive.impl(f, *tracers, **params)
    607   process_map = process_call

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(***failed resolving arguments***)
    592   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 593                                *unsafe_map(arg_spec, args))
    594   try:

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    261     else:
--> 262       ans = call(fun, *args)
    263       cache[key] = (ans, fun.stores)

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    667   jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
--> 668       fun, abstract_args, pe.debug_info_final(fun, "jit"))
    669   if any(isinstance(c, core.Tracer) for c in consts):

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, debug_info)
   1283     with core.new_sublevel():
-> 1284       jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1285     del fun, main

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1261     in_tracers = map(trace.new_arg, in_avals)
-> 1262     ans = fun.call_wrapped(*in_tracers)
   1263     out_tracers = map(trace.full_raise, ans)

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in single_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng)
    217       (loss, state), gradients = forward_and_backward_fn(
--> 218           batch, weights, state, rng)
    219       weights, slots, stats = optimizer.tree_update(

/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    182     try:
--> 183       return fun(*args, **kwargs)
    184     except Exception as e:

/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in value_and_grad_f(*args, **kwargs)
    883     else:
--> 884       ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
    885     _check_scalar(ans)

/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in _vjp(fun, has_aux, *primals)
   1964     flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree)
-> 1965     out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
   1966     out_tree, aux_tree = out_aux_trees()

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    115   else:
--> 116     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
    117 

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
    100   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 101   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
    102   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate)
    504     fun = trace_to_subjaxpr(fun, main, instantiate)
--> 505     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    506     assert not env

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
    605       raise LayerError(name, 'pure_fn',
--> 606                        self._caller, signature(x), trace) from None
    607 

UnfilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/supervised/training.py, line 1033
  layer input shapes: (ShapeDtype{shape:(16, 26), dtype:int32}, ShapeDtype{shape:(32,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer WeightedCategoryCrossEntropy (in pure_fn):
  layer created in file [...]/<ipython-input-32-e22a181c30d5>, line 12
  layer input shapes: (ShapeDtype{shape:(16, 2), dtype:float32}, ShapeDtype{shape:(32,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})

  File [...]/trax/layers/base.py, line 743, in forward
    raw_output = self._forward_fn(inputs)

  File [...]/trax/layers/base.py, line 784, in _forward
    return f(*xs)

  File [...]/trax/layers/metrics.py, line 273, in f
    model_output, targets, label_smoothing)

  File [...]/trax/layers/metrics.py, line 649, in _category_cross_entropy
    return - jnp.sum(target_distributions * model_log_distributions, axis=-1)

  File [...]/site-packages/jax/core.py, line 506, in __mul__
    def __mul__(self, other): return self.aval._mul(self, other)

  File [...]/_src/numpy/lax_numpy.py, line 5819, in deferring_binary_op
    return binary_op(self, other)

  File [...]/_src/numpy/lax_numpy.py, line 431, in fn
    return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)

  File [...]/_src/lax/lax.py, line 348, in mul
    return mul_p.bind(x, y)

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/ad.py, line 274, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)

  File [...]/jax/interpreters/ad.py, line 449, in standard_jvp
    val_out = primitive.bind(*primals, **params)

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 2125, in standard_abstract_eval
    return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

  File [...]/_src/lax/lax.py, line 2221, in _broadcasting_shape_rule
    raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))

TypeError: mul got incompatible shapes for broadcasting: (32, 2), (16, 2).

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

LayerError                                Traceback (most recent call last)
<ipython-input-36-9216004d763d> in <module>
      2 # Take a look on how the eval_task is inside square brackets and
      3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

<ipython-input-35-2d2291e65784> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
     23     ) 
     24 
---> 25     training_loop.run(n_steps = n_steps)
     26     ### END CODE HERE ###
     27 

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
    433           loss_acc, step_acc = 0.0, 0
    434 
--> 435         loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
    436 
    437         # optimizer_metrics and loss are replicated on self.n_devices, a few

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
    631 
    632     (loss, stats) = trainer.one_step(
--> 633         batch, rng, step=step, learning_rate=learning_rate
    634     )
    635 

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
    146     # NOTE: stats is a replicated dictionary of key to jnp arrays.
    147     (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 148         (weights, self._slots), step, self._opt_params, batch, state, rng)
    149 
    150     if logging.vlog_is_on(1) and ((step & step - 1) == 0):

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in single_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng)
    216       weights, slots = weights_and_slots
    217       (loss, state), gradients = forward_and_backward_fn(
--> 218           batch, weights, state, rng)
    219       weights, slots, stats = optimizer.tree_update(
    220           step, gradients, weights, slots, opt_params, store_slots=False)

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
    604       name, trace = self._name, _short_traceback(skip=3)
    605       raise LayerError(name, 'pure_fn',
--> 606                        self._caller, signature(x), trace) from None
    607 
    608   def output_signature(self, input_signature):

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/supervised/training.py, line 1033
  layer input shapes: (ShapeDtype{shape:(16, 26), dtype:int32}, ShapeDtype{shape:(32,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer WeightedCategoryCrossEntropy (in pure_fn):
  layer created in file [...]/<ipython-input-32-e22a181c30d5>, line 12
  layer input shapes: (ShapeDtype{shape:(16, 2), dtype:float32}, ShapeDtype{shape:(32,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})

  File [...]/trax/layers/base.py, line 743, in forward
    raw_output = self._forward_fn(inputs)

  File [...]/trax/layers/base.py, line 784, in _forward
    return f(*xs)

  File [...]/trax/layers/metrics.py, line 273, in f
    model_output, targets, label_smoothing)

  File [...]/trax/layers/metrics.py, line 649, in _category_cross_entropy
    return - jnp.sum(target_distributions * model_log_distributions, axis=-1)

  File [...]/site-packages/jax/core.py, line 506, in __mul__
    def __mul__(self, other): return self.aval._mul(self, other)

  File [...]/_src/numpy/lax_numpy.py, line 5819, in deferring_binary_op
    return binary_op(self, other)

  File [...]/_src/numpy/lax_numpy.py, line 431, in fn
    return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)

  File [...]/_src/lax/lax.py, line 348, in mul
    return mul_p.bind(x, y)

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/ad.py, line 274, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)

  File [...]/jax/interpreters/ad.py, line 449, in standard_jvp
    val_out = primitive.bind(*primals, **params)

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 2125, in standard_abstract_eval
    return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

  File [...]/_src/lax/lax.py, line 2221, in _broadcasting_shape_rule
    raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))

TypeError: mul got incompatible shapes for broadcasting: (32, 2), (16, 2).

Hi,

I had a similar problem, but I fixed it by correcting an error in the Classifier model. When constructing the Serial model, the Mean layer should take only the single parameter axis=1; see the sketch below.
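For context, the classifier's Serial stack presumably looks something like this (the vocabulary size, embedding dimension, and number of classes are placeholders):

    from trax import layers as tl

    # Hypothetical classifier; the key detail is tl.Mean(axis=1), which
    # averages over the sequence positions and keeps the batch axis intact.
    def classifier(vocab_size=10000, embedding_dim=256, output_dim=2):
        return tl.Serial(
            tl.Embedding(vocab_size=vocab_size, d_feature=embedding_dim),
            tl.Mean(axis=1),        # mean over axis 1 (the words), not the batch
            tl.Dense(n_units=output_dim),
            tl.LogSoftmax(),
        )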

Hope this could help you.

Cheers,
Xiao


Thanks so much! I was also getting a similar error and adding axis=1 in the definition of the Mean layer solved it!


I was getting the following error:

The following axis names are available to collective operations

This was sorted out once I passed axis=1 to the Mean layer instead of embed_layer.