C3_W1_Assignment - Exercise 6 - train_model

Hi all,

I am stuck on the training exercise (train_model) of this assignment. All previous exercises passed their tests, but now there seems to be a shape problem inside the Serial layer, even though that part passed its tests too. Any thoughts?

Thank you in advance.

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-46-9216004d763d> in <module>
      3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

<ipython-input-45-27a2fdc9a148> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
     24 
---> 25     training_loop.run(n_steps = n_steps)
     26     ### END CODE HERE ###

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
    434 
--> 435         loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
    436 

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
    632     (loss, stats) = trainer.one_step(
--> 633         batch, rng, step=step, learning_rate=learning_rate
    634     )

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
    147     (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 148         (weights, self._slots), step, self._opt_params, batch, state, rng)
    149 

/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    182     try:
--> 183       return fun(*args, **kwargs)
    184     except Exception as e:

/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
    426         device=device, backend=backend, name=flat_fun.__name__,
--> 427         donated_invars=donated_invars, inline=inline)
    428     out_pytree_def = out_tree()

/opt/conda/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1559   def bind(self, fun, *args, **params):
-> 1560     return call_bind(self, fun, *args, **params)
   1561 

/opt/conda/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1550   tracers = map(top_trace.full_raise, args)
-> 1551   outs = primitive.process(top_trace, fun, tracers, params)
   1552   return map(full_lower, apply_todos(env_trace_todo(), outs))

/opt/conda/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1562   def process(self, trace, fun, tracers, params):
-> 1563     return trace.process_call(self, fun, tracers, params)
   1564 

/opt/conda/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    605   def process_call(self, primitive, f, tracers, params):
--> 606     return primitive.impl(f, *tracers, **params)
    607   process_map = process_call

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(***failed resolving arguments***)
    592   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 593                                *unsafe_map(arg_spec, args))
    594   try:

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    261     else:
--> 262       ans = call(fun, *args)
    263       cache[key] = (ans, fun.stores)

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    667   jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
--> 668       fun, abstract_args, pe.debug_info_final(fun, "jit"))
    669   if any(isinstance(c, core.Tracer) for c in consts):

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, debug_info)
   1283     with core.new_sublevel():
-> 1284       jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1285     del fun, main

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1261     in_tracers = map(trace.new_arg, in_avals)
-> 1262     ans = fun.call_wrapped(*in_tracers)
   1263     out_tracers = map(trace.full_raise, ans)

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in single_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng)
    217       (loss, state), gradients = forward_and_backward_fn(
--> 218           batch, weights, state, rng)
    219       weights, slots, stats = optimizer.tree_update(

/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    182     try:
--> 183       return fun(*args, **kwargs)
    184     except Exception as e:

/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in value_and_grad_f(*args, **kwargs)
    883     else:
--> 884       ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
    885     _check_scalar(ans)

/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in _vjp(fun, has_aux, *primals)
   1964     flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree)
-> 1965     out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
   1966     out_tree, aux_tree = out_aux_trees()

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    115   else:
--> 116     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
    117 

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
    100   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 101   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
    102   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate)
    504     fun = trace_to_subjaxpr(fun, main, instantiate)
--> 505     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    506     assert not env

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
    605       raise LayerError(name, 'pure_fn',
--> 606                        self._caller, signature(x), trace) from None
    607 

UnfilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/supervised/training.py, line 1033
  layer input shapes: (ShapeDtype{shape:(16, 26), dtype:int32}, ShapeDtype{shape:(32,), dtype:int32}, ShapeDtype{shape:(32,), dtype:int32})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer WeightedCategoryCrossEntropy (in pure_fn):
  layer created in file [...]/<ipython-input-42-e22a181c30d5>, line 12
  layer input shapes: (ShapeDtype{shape:(16, 2), dtype:float32}, ShapeDtype{shape:(32,), dtype:int32}, ShapeDtype{shape:(32,), dtype:int32})

  File [...]/trax/layers/base.py, line 743, in forward
    raw_output = self._forward_fn(inputs)

  File [...]/trax/layers/base.py, line 784, in _forward
    return f(*xs)

  File [...]/trax/layers/metrics.py, line 273, in f
    model_output, targets, label_smoothing)

  File [...]/trax/layers/metrics.py, line 649, in _category_cross_entropy
    return - jnp.sum(target_distributions * model_log_distributions, axis=-1)

  File [...]/site-packages/jax/core.py, line 506, in __mul__
    def __mul__(self, other): return self.aval._mul(self, other)

  File [...]/_src/numpy/lax_numpy.py, line 5819, in deferring_binary_op
    return binary_op(self, other)

  File [...]/_src/numpy/lax_numpy.py, line 431, in fn
    return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)

  File [...]/_src/lax/lax.py, line 348, in mul
    return mul_p.bind(x, y)

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/ad.py, line 274, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)

  File [...]/jax/interpreters/ad.py, line 449, in standard_jvp
    val_out = primitive.bind(*primals, **params)

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 2125, in standard_abstract_eval
    return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

  File [...]/_src/lax/lax.py, line 2221, in _broadcasting_shape_rule
    raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))

TypeError: mul got incompatible shapes for broadcasting: (32, 2), (16, 2).

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

LayerError                                Traceback (most recent call last)
<ipython-input-46-9216004d763d> in <module>
      2 # Take a look on how the eval_task is inside square brackets and
      3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

<ipython-input-45-27a2fdc9a148> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
     23     ) 
     24 
---> 25     training_loop.run(n_steps = n_steps)
     26     ### END CODE HERE ###
     27 

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
    433           loss_acc, step_acc = 0.0, 0
    434 
--> 435         loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
    436 
    437         # optimizer_metrics and loss are replicated on self.n_devices, a few

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
    631 
    632     (loss, stats) = trainer.one_step(
--> 633         batch, rng, step=step, learning_rate=learning_rate
    634     )
    635 

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
    146     # NOTE: stats is a replicated dictionary of key to jnp arrays.
    147     (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 148         (weights, self._slots), step, self._opt_params, batch, state, rng)
    149 
    150     if logging.vlog_is_on(1) and ((step & step - 1) == 0):

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in single_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng)
    216       weights, slots = weights_and_slots
    217       (loss, state), gradients = forward_and_backward_fn(
--> 218           batch, weights, state, rng)
    219       weights, slots, stats = optimizer.tree_update(
    220           step, gradients, weights, slots, opt_params, store_slots=False)

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
    604       name, trace = self._name, _short_traceback(skip=3)
    605       raise LayerError(name, 'pure_fn',
--> 606                        self._caller, signature(x), trace) from None
    607 
    608   def output_signature(self, input_signature):

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/supervised/training.py, line 1033
  layer input shapes: (ShapeDtype{shape:(16, 26), dtype:int32}, ShapeDtype{shape:(32,), dtype:int32}, ShapeDtype{shape:(32,), dtype:int32})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer WeightedCategoryCrossEntropy (in pure_fn):
  layer created in file [...]/<ipython-input-42-e22a181c30d5>, line 12
  layer input shapes: (ShapeDtype{shape:(16, 2), dtype:float32}, ShapeDtype{shape:(32,), dtype:int32}, ShapeDtype{shape:(32,), dtype:int32})

  File [...]/trax/layers/base.py, line 743, in forward
    raw_output = self._forward_fn(inputs)

  File [...]/trax/layers/base.py, line 784, in _forward
    return f(*xs)

  File [...]/trax/layers/metrics.py, line 273, in f
    model_output, targets, label_smoothing)

  File [...]/trax/layers/metrics.py, line 649, in _category_cross_entropy
    return - jnp.sum(target_distributions * model_log_distributions, axis=-1)

  File [...]/site-packages/jax/core.py, line 506, in __mul__
    def __mul__(self, other): return self.aval._mul(self, other)

  File [...]/_src/numpy/lax_numpy.py, line 5819, in deferring_binary_op
    return binary_op(self, other)

  File [...]/_src/numpy/lax_numpy.py, line 431, in fn
    return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)

  File [...]/_src/lax/lax.py, line 348, in mul
    return mul_p.bind(x, y)

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/ad.py, line 274, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)

  File [...]/jax/interpreters/ad.py, line 449, in standard_jvp
    val_out = primitive.bind(*primals, **params)

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 2125, in standard_abstract_eval
    return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

  File [...]/_src/lax/lax.py, line 2221, in _broadcasting_shape_rule
    raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))

**TypeError: mul got incompatible shapes for broadcasting: (32, 2), (16, 2).**

Hi @pablonatal

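The most likely cause is your data_generator implementation. Look at the layer input shapes in the error: the model output has a batch dimension of 16 (shape (16, 2)), while the targets and weights have a batch dimension of 32. The generator is therefore yielding label and weight arrays whose length does not match the number of input examples. Double-check that your data_generator works correctly with different batch sizes. (Try experimenting with the cell below.)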

Also check how you build the lists of targets for the positive and negative examples (target_pos and target_neg). Make sure you use the n_to_take variable for their length, not len_data_pos and len_data_neg.

These are common mistakes. If you can't fix it yourself, you can send me your notebook in a private message (see how to download it) and I will take a look.

Hi @arvyzukai,

Thank you very much for your help; that was a good point to review. The code worked properly with different batch sizes, but the n_to_take variable was not set correctly. It is fixed now.

Thank you.
Palo