Error on C3W1 Exercise 06 Coding Assignment

I am getting ‘TypeError: mul got incompatible shapes for broadcasting: (32, 2), (16, 2)’ for Exercise 6 in this coding assignment. All previous exercises pass their test cases, and my axes are correctly defined. I have been trying to figure out the problem without luck. My lab ID is omyayqotxydh and I would appreciate some help - thank you in advance!

Hi @Brainism

Have you resolved your error? If you are still having trouble, you can send me a private message with your notebook attached and I can try to help. (Mentors cannot access notebooks by lab ID.)

Cheers

@Brainism Did you get this resolved? I had the same issue and could not figure out what went wrong.

Hey @Elemento

I have the same error as the OP with Exercise 6. Everything up to this point passes the unit tests, but this exercise has been really tricky. I am assuming there is something wrong with my classifier from Exercise 5.

I would appreciate a nudge in the right direction. Below is my error stack. Apologies in advance if I am pasting the traceback incorrectly.

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-46-9216004d763d> in <module>
      3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

<ipython-input-45-918bf0dfcb18> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
     24 
---> 25     training_loop.run(n_steps = n_steps)
     26     ### END CODE HERE ###

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
    434 
--> 435         loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
    436 

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
    632     (loss, stats) = trainer.one_step(
--> 633         batch, rng, step=step, learning_rate=learning_rate
    634     )

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
    147     (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 148         (weights, self._slots), step, self._opt_params, batch, state, rng)
    149 

/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    182     try:
--> 183       return fun(*args, **kwargs)
    184     except Exception as e:

/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
    426         device=device, backend=backend, name=flat_fun.__name__,
--> 427         donated_invars=donated_invars, inline=inline)
    428     out_pytree_def = out_tree()

/opt/conda/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1559   def bind(self, fun, *args, **params):
-> 1560     return call_bind(self, fun, *args, **params)
   1561 

/opt/conda/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1550   tracers = map(top_trace.full_raise, args)
-> 1551   outs = primitive.process(top_trace, fun, tracers, params)
   1552   return map(full_lower, apply_todos(env_trace_todo(), outs))

/opt/conda/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1562   def process(self, trace, fun, tracers, params):
-> 1563     return trace.process_call(self, fun, tracers, params)
   1564 

/opt/conda/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    605   def process_call(self, primitive, f, tracers, params):
--> 606     return primitive.impl(f, *tracers, **params)
    607   process_map = process_call

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(***failed resolving arguments***)
    592   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 593                                *unsafe_map(arg_spec, args))
    594   try:

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    261     else:
--> 262       ans = call(fun, *args)
    263       cache[key] = (ans, fun.stores)

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    667   jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
--> 668       fun, abstract_args, pe.debug_info_final(fun, "jit"))
    669   if any(isinstance(c, core.Tracer) for c in consts):

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, debug_info)
   1283     with core.new_sublevel():
-> 1284       jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1285     del fun, main

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1261     in_tracers = map(trace.new_arg, in_avals)
-> 1262     ans = fun.call_wrapped(*in_tracers)
   1263     out_tracers = map(trace.full_raise, ans)

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in single_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng)
    217       (loss, state), gradients = forward_and_backward_fn(
--> 218           batch, weights, state, rng)
    219       weights, slots, stats = optimizer.tree_update(

/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    182     try:
--> 183       return fun(*args, **kwargs)
    184     except Exception as e:

/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in value_and_grad_f(*args, **kwargs)
    883     else:
--> 884       ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
    885     _check_scalar(ans)

/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in _vjp(fun, has_aux, *primals)
   1964     flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree)
-> 1965     out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
   1966     out_tree, aux_tree = out_aux_trees()

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    115   else:
--> 116     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
    117 

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
    100   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 101   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
    102   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate)
    504     fun = trace_to_subjaxpr(fun, main, instantiate)
--> 505     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    506     assert not env

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
    605       raise LayerError(name, 'pure_fn',
--> 606                        self._caller, signature(x), trace) from None
    607 

UnfilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/supervised/training.py, line 1033
  layer input shapes: (ShapeDtype{shape:(16, 26), dtype:int32}, ShapeDtype{shape:(32,), dtype:int32}, ShapeDtype{shape:(32,), dtype:int8})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer WeightedCategoryCrossEntropy (in pure_fn):
  layer created in file [...]/<ipython-input-42-e22a181c30d5>, line 12
  layer input shapes: (ShapeDtype{shape:(16, 2), dtype:float32}, ShapeDtype{shape:(32,), dtype:int32}, ShapeDtype{shape:(32,), dtype:int8})

  File [...]/trax/layers/base.py, line 743, in forward
    raw_output = self._forward_fn(inputs)

  File [...]/trax/layers/base.py, line 784, in _forward
    return f(*xs)

  File [...]/trax/layers/metrics.py, line 273, in f
    model_output, targets, label_smoothing)

  File [...]/trax/layers/metrics.py, line 649, in _category_cross_entropy
    return - jnp.sum(target_distributions * model_log_distributions, axis=-1)

  File [...]/site-packages/jax/core.py, line 506, in __mul__
    def __mul__(self, other): return self.aval._mul(self, other)

  File [...]/_src/numpy/lax_numpy.py, line 5819, in deferring_binary_op
    return binary_op(self, other)

  File [...]/_src/numpy/lax_numpy.py, line 431, in fn
    return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)

  File [...]/_src/lax/lax.py, line 348, in mul
    return mul_p.bind(x, y)

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/ad.py, line 274, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)

  File [...]/jax/interpreters/ad.py, line 449, in standard_jvp
    val_out = primitive.bind(*primals, **params)

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 2125, in standard_abstract_eval
    return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

  File [...]/_src/lax/lax.py, line 2221, in _broadcasting_shape_rule
    raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))

TypeError: mul got incompatible shapes for broadcasting: (32, 2), (16, 2).

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

LayerError                                Traceback (most recent call last)
<ipython-input-46-9216004d763d> in <module>
      2 # Take a look on how the eval_task is inside square brackets and
      3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

<ipython-input-45-918bf0dfcb18> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
     23     ) 
     24 
---> 25     training_loop.run(n_steps = n_steps)
     26     ### END CODE HERE ###
     27 

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
    433           loss_acc, step_acc = 0.0, 0
    434 
--> 435         loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
    436 
    437         # optimizer_metrics and loss are replicated on self.n_devices, a few

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
    631 
    632     (loss, stats) = trainer.one_step(
--> 633         batch, rng, step=step, learning_rate=learning_rate
    634     )
    635 

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
    146     # NOTE: stats is a replicated dictionary of key to jnp arrays.
    147     (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 148         (weights, self._slots), step, self._opt_params, batch, state, rng)
    149 
    150     if logging.vlog_is_on(1) and ((step & step - 1) == 0):

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in single_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng)
    216       weights, slots = weights_and_slots
    217       (loss, state), gradients = forward_and_backward_fn(
--> 218           batch, weights, state, rng)
    219       weights, slots, stats = optimizer.tree_update(
    220           step, gradients, weights, slots, opt_params, store_slots=False)

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
    604       name, trace = self._name, _short_traceback(skip=3)
    605       raise LayerError(name, 'pure_fn',
--> 606                        self._caller, signature(x), trace) from None
    607 
    608   def output_signature(self, input_signature):

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/supervised/training.py, line 1033
  layer input shapes: (ShapeDtype{shape:(16, 26), dtype:int32}, ShapeDtype{shape:(32,), dtype:int32}, ShapeDtype{shape:(32,), dtype:int8})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer WeightedCategoryCrossEntropy (in pure_fn):
  layer created in file [...]/<ipython-input-42-e22a181c30d5>, line 12
  layer input shapes: (ShapeDtype{shape:(16, 2), dtype:float32}, ShapeDtype{shape:(32,), dtype:int32}, ShapeDtype{shape:(32,), dtype:int8})

  File [...]/trax/layers/base.py, line 743, in forward
    raw_output = self._forward_fn(inputs)

  File [...]/trax/layers/base.py, line 784, in _forward
    return f(*xs)

  File [...]/trax/layers/metrics.py, line 273, in f
    model_output, targets, label_smoothing)

  File [...]/trax/layers/metrics.py, line 649, in _category_cross_entropy
    return - jnp.sum(target_distributions * model_log_distributions, axis=-1)

  File [...]/site-packages/jax/core.py, line 506, in __mul__
    def __mul__(self, other): return self.aval._mul(self, other)

  File [...]/_src/numpy/lax_numpy.py, line 5819, in deferring_binary_op
    return binary_op(self, other)

  File [...]/_src/numpy/lax_numpy.py, line 431, in fn
    return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)

  File [...]/_src/lax/lax.py, line 348, in mul
    return mul_p.bind(x, y)

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/ad.py, line 274, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)

  File [...]/jax/interpreters/ad.py, line 449, in standard_jvp
    val_out = primitive.bind(*primals, **params)

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 2125, in standard_abstract_eval
    return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

  File [...]/_src/lax/lax.py, line 2221, in _broadcasting_shape_rule
    raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))

TypeError: mul got incompatible shapes for broadcasting: (32, 2), (16, 2).

Hey @krisoye,
One of the issues I can see clearly in your error stack is that you have used a global variable, namely model, in your implementation. You are supposed to use the function's parameter, classifier, instead. Unless you are explicitly told to use a global variable, make sure you haven't used one in your implementations. Let me know if this helps.

Cheers,
Elemento

Hey @Elemento,

I’m not following. I don’t see a global reference to model in my train_model function. Is this the function you were referring to? Would you like me to PM you my notebook?

Hey @krisoye,
I am referring to the line of code quoted above. There, you are supposed to use the function's parameter classifier instead of the global variable model. Note that I have highlighted the variables classifier and model for easier distinction; see the sketch below. Let me know if this helps.
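
For reference, here is a minimal sketch of what the corrected function body might look like, assuming the notebook's standard training.Loop setup (your argument list may include extras such as random_seed, so treat this as illustrative rather than the exact assignment solution):

```python
from trax.supervised import training

def train_model(classifier, train_task, eval_task, n_steps, output_dir):
    # Build the Loop from the `classifier` parameter, NOT the global `model`.
    training_loop = training.Loop(
        classifier,
        train_task,
        eval_tasks=eval_task,   # already passed in as a list, e.g. [eval_task]
        output_dir=output_dir,
    )
    training_loop.run(n_steps=n_steps)
    return training_loop
```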

Cheers,
Elemento

Hi,
I am having a similar issue. I tried modifying the axis in tl.Mean() as people in other threads suggested, but the issue persists (although the mismatched dimensions do change).
To be honest, I am not sure what the suggested solution in the last post is, since that cell is read-only.

Hey @guszejnov,
Welcome, and we are glad that you could become a part of our community! 🥳

The issue lies in your implementation of the data_generator function. When defining target_pos and target_neg, you are using pos_index and neg_index. Those indices are used to traverse the data_pos and data_neg arrays and, as you can see, do not indicate the number of positive and negative examples in a single batch. Instead, n_to_take indicates how many positive and negative examples go into a single batch, which is what you need when defining the number of positive and negative targets; see the sketch below.
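
To make the distinction concrete, here is a minimal sketch of the target-building lines, using the notebook's variable names (the batch_size value is hypothetical, for illustration only):

```python
import numpy as np

batch_size = 16              # hypothetical value for illustration
n_to_take = batch_size // 2  # half the batch is positive, half negative

# The targets must have one entry per example in the batch. pos_index and
# neg_index only track how far we have traversed data_pos and data_neg
# overall, so using them here yields a targets array whose length drifts
# away from the model's batch size.
target_pos = [1] * n_to_take  # one label per positive example in the batch
target_neg = [0] * n_to_take  # one label per negative example in the batch
targets = np.array(target_pos + target_neg)
```

A length mismatch of this kind would be consistent with the traceback above, where the loss layer receives model output of shape (16, 2) but a targets array of shape (32,).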

Let us know if this resolves your error.

Cheers,
Elemento

Thanks!
Yes, that was the issue.

Thanks!
I've been stuck on this forever.

Thanks! This was quite tricky to debug, and it is honestly surprising that it didn't break the unit tests until this point.