I am getting ‘TypeError: mul got incompatible shapes for broadcasting: (32, 2), (16, 2)’ for exercise 6 in this coding assignment. All previous exercises pass all test cases, and my axes are correctly defined. I have been trying to figure out the problem without luck. My lab ID is omyayqotxydh, and I would appreciate some help. Thank you in advance!
Hi @Brainism
Have you solved your error? If you are still having trouble, you can private message me with your notebook attached and I can try to help. (Mentors cannot access notebooks by lab ID.)
Cheers
@Brainism Did you get this resolved? I had the same issue, and could not figure out what went wrong.
Hey @Elemento
I have the same error as OP with exercise 6. Everything up to this point passes the unit tests, but this exercise has been really tricky. I’m assuming there is something wrong with my classifier from exercise 5.
I would appreciate a nudge in the right direction. Below is my error stack. Apologies in advance if I am adding the traceback incorrectly.
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-46-9216004d763d> in <module>
3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)
<ipython-input-45-918bf0dfcb18> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
24
---> 25 training_loop.run(n_steps = n_steps)
26 ### END CODE HERE ###
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
434
--> 435 loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
436
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
632 (loss, stats) = trainer.one_step(
--> 633 batch, rng, step=step, learning_rate=learning_rate
634 )
/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
147 (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 148 (weights, self._slots), step, self._opt_params, batch, state, rng)
149
/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
182 try:
--> 183 return fun(*args, **kwargs)
184 except Exception as e:
/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
426 device=device, backend=backend, name=flat_fun.__name__,
--> 427 donated_invars=donated_invars, inline=inline)
428 out_pytree_def = out_tree()
/opt/conda/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
1559 def bind(self, fun, *args, **params):
-> 1560 return call_bind(self, fun, *args, **params)
1561
/opt/conda/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1550 tracers = map(top_trace.full_raise, args)
-> 1551 outs = primitive.process(top_trace, fun, tracers, params)
1552 return map(full_lower, apply_todos(env_trace_todo(), outs))
/opt/conda/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
1562 def process(self, trace, fun, tracers, params):
-> 1563 return trace.process_call(self, fun, tracers, params)
1564
/opt/conda/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
605 def process_call(self, primitive, f, tracers, params):
--> 606 return primitive.impl(f, *tracers, **params)
607 process_map = process_call
/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(***failed resolving arguments***)
592 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 593 *unsafe_map(arg_spec, args))
594 try:
/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
261 else:
--> 262 ans = call(fun, *args)
263 cache[key] = (ans, fun.stores)
/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
667 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
--> 668 fun, abstract_args, pe.debug_info_final(fun, "jit"))
669 if any(isinstance(c, core.Tracer) for c in consts):
/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, debug_info)
1283 with core.new_sublevel():
-> 1284 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
1285 del fun, main
/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1261 in_tracers = map(trace.new_arg, in_avals)
-> 1262 ans = fun.call_wrapped(*in_tracers)
1263 out_tracers = map(trace.full_raise, ans)
/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in single_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng)
217 (loss, state), gradients = forward_and_backward_fn(
--> 218 batch, weights, state, rng)
219 weights, slots, stats = optimizer.tree_update(
/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
182 try:
--> 183 return fun(*args, **kwargs)
184 except Exception as e:
/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in value_and_grad_f(*args, **kwargs)
883 else:
--> 884 ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
885 _check_scalar(ans)
/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in _vjp(fun, has_aux, *primals)
1964 flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree)
-> 1965 out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
1966 out_tree, aux_tree = out_aux_trees()
/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
115 else:
--> 116 out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
117
/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
100 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 101 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
102 out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate)
504 fun = trace_to_subjaxpr(fun, main, instantiate)
--> 505 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
506 assert not env
/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
605 raise LayerError(name, 'pure_fn',
--> 606 self._caller, signature(x), trace) from None
607
UnfilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/trax/supervised/training.py, line 1033
layer input shapes: (ShapeDtype{shape:(16, 26), dtype:int32}, ShapeDtype{shape:(32,), dtype:int32}, ShapeDtype{shape:(32,), dtype:int8})
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer WeightedCategoryCrossEntropy (in pure_fn):
layer created in file [...]/<ipython-input-42-e22a181c30d5>, line 12
layer input shapes: (ShapeDtype{shape:(16, 2), dtype:float32}, ShapeDtype{shape:(32,), dtype:int32}, ShapeDtype{shape:(32,), dtype:int8})
File [...]/trax/layers/base.py, line 743, in forward
raw_output = self._forward_fn(inputs)
File [...]/trax/layers/base.py, line 784, in _forward
return f(*xs)
File [...]/trax/layers/metrics.py, line 273, in f
model_output, targets, label_smoothing)
File [...]/trax/layers/metrics.py, line 649, in _category_cross_entropy
return - jnp.sum(target_distributions * model_log_distributions, axis=-1)
File [...]/site-packages/jax/core.py, line 506, in __mul__
def __mul__(self, other): return self.aval._mul(self, other)
File [...]/_src/numpy/lax_numpy.py, line 5819, in deferring_binary_op
return binary_op(self, other)
File [...]/_src/numpy/lax_numpy.py, line 431, in fn
return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)
File [...]/_src/lax/lax.py, line 348, in mul
return mul_p.bind(x, y)
File [...]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/ad.py, line 274, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
File [...]/jax/interpreters/ad.py, line 449, in standard_jvp
val_out = primitive.bind(*primals, **params)
File [...]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File [...]/_src/lax/lax.py, line 2125, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
File [...]/_src/lax/lax.py, line 2221, in _broadcasting_shape_rule
raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
TypeError: mul got incompatible shapes for broadcasting: (32, 2), (16, 2).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
LayerError Traceback (most recent call last)
<ipython-input-46-9216004d763d> in <module>
2 # Take a look on how the eval_task is inside square brackets and
3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)
<ipython-input-45-918bf0dfcb18> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
23 )
24
---> 25 training_loop.run(n_steps = n_steps)
26 ### END CODE HERE ###
27
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
433 loss_acc, step_acc = 0.0, 0
434
--> 435 loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
436
437 # optimizer_metrics and loss are replicated on self.n_devices, a few
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
631
632 (loss, stats) = trainer.one_step(
--> 633 batch, rng, step=step, learning_rate=learning_rate
634 )
635
/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
146 # NOTE: stats is a replicated dictionary of key to jnp arrays.
147 (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 148 (weights, self._slots), step, self._opt_params, batch, state, rng)
149
150 if logging.vlog_is_on(1) and ((step & step - 1) == 0):
/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in single_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng)
216 weights, slots = weights_and_slots
217 (loss, state), gradients = forward_and_backward_fn(
--> 218 batch, weights, state, rng)
219 weights, slots, stats = optimizer.tree_update(
220 step, gradients, weights, slots, opt_params, store_slots=False)
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
604 name, trace = self._name, _short_traceback(skip=3)
605 raise LayerError(name, 'pure_fn',
--> 606 self._caller, signature(x), trace) from None
607
608 def output_signature(self, input_signature):
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/trax/supervised/training.py, line 1033
layer input shapes: (ShapeDtype{shape:(16, 26), dtype:int32}, ShapeDtype{shape:(32,), dtype:int32}, ShapeDtype{shape:(32,), dtype:int8})
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer WeightedCategoryCrossEntropy (in pure_fn):
layer created in file [...]/<ipython-input-42-e22a181c30d5>, line 12
layer input shapes: (ShapeDtype{shape:(16, 2), dtype:float32}, ShapeDtype{shape:(32,), dtype:int32}, ShapeDtype{shape:(32,), dtype:int8})
File [...]/trax/layers/base.py, line 743, in forward
raw_output = self._forward_fn(inputs)
File [...]/trax/layers/base.py, line 784, in _forward
return f(*xs)
File [...]/trax/layers/metrics.py, line 273, in f
model_output, targets, label_smoothing)
File [...]/trax/layers/metrics.py, line 649, in _category_cross_entropy
return - jnp.sum(target_distributions * model_log_distributions, axis=-1)
File [...]/site-packages/jax/core.py, line 506, in __mul__
def __mul__(self, other): return self.aval._mul(self, other)
File [...]/_src/numpy/lax_numpy.py, line 5819, in deferring_binary_op
return binary_op(self, other)
File [...]/_src/numpy/lax_numpy.py, line 431, in fn
return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)
File [...]/_src/lax/lax.py, line 348, in mul
return mul_p.bind(x, y)
File [...]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/ad.py, line 274, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
File [...]/jax/interpreters/ad.py, line 449, in standard_jvp
val_out = primitive.bind(*primals, **params)
File [...]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File [...]/_src/lax/lax.py, line 2125, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
File [...]/_src/lax/lax.py, line 2221, in _broadcasting_shape_rule
raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
TypeError: mul got incompatible shapes for broadcasting: (32, 2), (16, 2).
Hey @krisoye,
Well, one of the issues I can see clearly in your error stack is that you have used a global variable, namely `model`, in your implementation. You are supposed to use the function’s parameter, `classifier`, instead. Unless you are explicitly told to use a global variable, make sure you haven’t used one in your implementations.
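To make this concrete, here is a minimal sketch of the relevant part of `train_model`, assuming the standard `trax.supervised.training.Loop` API; the keyword arguments shown are an illustration, not the exact assignment code. The key point is that the `classifier` parameter is what gets passed to the loop:

```python
from trax.supervised import training

def train_model(classifier, train_task, eval_task, n_steps, output_dir):
    # Use the `classifier` parameter here, not the global `model`.
    training_loop = training.Loop(
        classifier,            # the model passed into this function
        train_task,            # training task
        eval_tasks=eval_task,  # already a list, since the call site passes [eval_task]
        output_dir=output_dir, # where checkpoints are written
    )
    training_loop.run(n_steps=n_steps)
    return training_loop
```

Let me know if this helps.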
Cheers,
Elemento
Hey @Elemento,
I’m not following. I don’t see a global reference to `model` in my `train_model` function. Is this the function that you were referring to? Would you like me to PM you my notebook?
Hey @krisoye,
I am referring to the line of code above. In it, you are supposed to use the function’s parameter `classifier` instead of the global variable `model`. Note that I have highlighted the variables `classifier` and `model` for easier distinction. Let me know if this helps.
Cheers,
Elemento
Hi,
I am having a similar issue. I tried modifying the axis in `tl.Mean()` as people in other threads suggested, but the issue persists (although the mismatched dimensions do change).
To be honest, I am not sure what the solution suggested in the last post is, since that cell is read-only.
Hey @guszejnov,
Welcome, and we are glad that you could become a part of our community!
The issue lies in your implementation of the `data_generator` function. When defining `target_pos` and `target_neg`, you are using `pos_index` and `neg_index`. Those variables are used to traverse the `data_pos` and `data_neg` arrays, so they do not indicate the number of positive and negative examples in a single batch. Instead, `n_to_take` indicates the number of positive and negative examples in a single batch, and that is what you need for sizing the positive and negative targets; see the sketch below.
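For illustration only, here is a rough sketch of how the targets could be sized inside `data_generator`, assuming each batch is assembled from `n_to_take` positive and `n_to_take` negative examples (the surrounding generator code is omitted, and the exact construction in the assignment may differ):

```python
import numpy as np

n_to_take = 8  # e.g. half the batch size; value chosen here just for illustration

# pos_index / neg_index only track our position in data_pos / data_neg;
# the number of examples actually placed in this batch is n_to_take,
# so each target array must have length n_to_take.
target_pos = [1] * n_to_take  # one label per positive example in the batch
target_neg = [0] * n_to_take  # one label per negative example in the batch
targets = np.array(target_pos + target_neg)  # shape (2 * n_to_take,), matching the batch
```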
Let us know if this resolves your error.
Cheers,
Elemento
Thanks!
Yes, that was the issue.
Thanks!
I’ve been stuck on this forever.
Thanks! This was quite tricky to debug, and honestly it’s surprising that it didn’t break any unit tests before this point.