Course 3 week 1 exercise 6

Hi @drew_Frances,

I took a look at your assignment, which was named C3_W1_Assignment_2022_04_10_06_00_02.

Your exercise 4 was passing all the tests, so I jumped right to exercise 8.

You’ll be surprised to know what the issue is.

You have a typo in that exercise: it should be example_weight, not example_weights.

Fix that typo and you are good.

Cheers,
Mubsi

Hi @Mubsi

You are right! I'm wondering why the unit test did not pick that up. I'm also wondering why I got 4/2 passed. Again, thanks for your help!

Thank you,
Drew

Hi @drew_Frances,

Because example_weights is also a global variable, defined in cell #15, so the misspelled name still resolved to that global when you ran the unit test.
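A minimal sketch of why the unit test could not catch this, assuming the test runs in the same notebook session: inside a function, a name that is neither a parameter nor a local variable is looked up in the module's globals, so a misspelling that happens to match an existing global runs without error. The function weighted_sum is hypothetical (it is not the assignment's code), and running the same source with exec in two different namespaces is only a stand-in for "notebook session" versus "fresh environment", not how the grader is actually implemented.

```python
# Hypothetical reconstruction of the bug: the parameter is `example_weight`,
# but the body mistakenly reads `example_weights`.
BUGGY_SOURCE = '''
def weighted_sum(values, example_weight):
    return sum(v * w for v, w in zip(values, example_weights))  # typo
'''


def load_function(source, module_globals):
    """Define weighted_sum inside its own globals dict and return it."""
    namespace = dict(module_globals)
    exec(source, namespace)
    return namespace["weighted_sum"]


# "Notebook session": an earlier cell happened to define the typo'd name.
# Here the global holds the same values as the argument, so the result looks right.
notebook_fn = load_function(BUGGY_SOURCE, {"example_weights": [1, 0, 1]})
print(notebook_fn([5, 7, 9], [1, 0, 1]))  # 14

# "Fresh environment": no such global exists, so the typo surfaces.
fresh_fn = load_function(BUGGY_SOURCE, {})
try:
    fresh_fn([5, 7, 9], [1, 0, 1])
except NameError as err:
    print("NameError:", err)  # name 'example_weights' is not defined
```

This is also why the notebook's submission notes warn against using global variables inside graded functions.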

And I did not see any unit test in which you were passing 2/4 tests. Maybe you mean something else?

Regards,
Mubsi

@Mubsi I am experiencing the same issue; my lab ID is vkawojmg. I'd really appreciate it if you could help me with this. Thanks.

Hi @CaoCao,

I left a couple of comments in your Ex 2; it seems to be working now.

Happy learning,
Mubsi

Thank you! Now it works! :smiley:

Hey @Mubsi,
I am getting an error in exercise 6 in training_loop:

TypeError                                 Traceback (most recent call last)

in
2 # Take a look on how the eval_task is inside square brackets and
3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

in train_model(classifier, train_task, eval_task, n_steps, output_dir)
12 trainer - trax trainer
13 '''
---> 14 rnd.seed(31) # Do NOT modify this random seed. This makes the notebook easier to replicate
15
16 ### START CODE HERE (Replace instances of 'None' with your code) ###

TypeError: 'int' object is not callable

Hello, I'm facing this issue:

Grader Error: Grader feedback not found

What should I do? No matter what I do, it keeps showing this.

Hi @tejakalyanam,

At the start of the notebook there’s:

**Important Note on Submission to the AutoGrader**

Before submitting your assignment to the AutoGrader, please make sure you are not doing the following:

1. You have not added any extra print statement(s) in the assignment.
2. You have not added any extra code cell(s) in the assignment.
3. You have not changed any of the function parameters.
4. You are not using any global variables inside your graded exercises. Unless specifically instructed to do so, please refrain from it and use the local variables instead.
5. You are not changing the assignment code where it is not required, like creating extra variables.

If you do any of the above, you will get a "Grader not found" (or similarly unexpected) error upon submitting your assignment. Before asking for help with or debugging the errors in your assignment, check for these first. If this is the case and you don't remember the changes you have made, you can get a fresh copy of the assignment by following these instructions.

This clearly states that if you do any of the listed things, you will run into the "Grader not found" error.

Please make sure you are not doing any of these listed things.

Best,
Mubsi

Thanks for the reply, I have completed the assignment.

Hello, I am also having an issue with Week 1 exercise 6. I believe I am not understanding whether the function should pass on the list, [eval_task], or an item of that list. All of my helper functions up to this point pass all tests.

Thanks!

Hi @Alejandro_Castro,

Could you share the error you are getting?

Thanks,
Mubsi

---------------------------------------------------------------------------
LayerError                                Traceback (most recent call last)
<ipython-input-81-9216004d763d> in <module>
      2 # Take a look on how the eval_task is inside square brackets and
      3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

<ipython-input-80-77fe42176b01> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
     20                                 eval_tasks = eval_task[0], # The evaluation task
     21                                 output_dir = output_dir, # The output directory
---> 22                                 random_seed=31 # Do not modify this random seed in order to ensure reproducibility and for grading purposes.
     23     ) 
     24 

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    278 
    279     # Create the optimizer for the training loss function.
--> 280     self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
    281 
    282     # Sync layers weights/state in memory effcient trainer layers.

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in <genexpr>(.0)
    278 
    279     # Create the optimizer for the training loss function.
--> 280     self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
    281 
    282     # Sync layers weights/state in memory effcient trainer layers.

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _init_trainer(self, task)
    339           self._model,
    340           [task.loss_layer],
--> 341           shapes.signature(task.sample_batch)
    342       )
    343       if base.N_WEIGHTS_SHARDS > 1:

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _model_with_ends(model, end_layers, batch_signature)
   1028   # TODO(jonni): Redo this function as part of an initialization refactor?
   1029   metrics_layer = tl.Branch(*end_layers)
-> 1030   metrics_input_signature = model.output_signature(batch_signature)
   1031   _, _ = metrics_layer.init(metrics_input_signature)
   1032 

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in output_signature(self, input_signature)
    608   def output_signature(self, input_signature):
    609     """Returns output signature this layer would give for `input_signature`."""
--> 610     return self._forward_abstract(input_signature)[0]  # output only, not state
    611 
    612   def _forward_abstract(self, input_signature):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in _forward_abstract(self, input_signature)
    640       name, trace = self._name, _short_traceback(skip=7)
    641       raise LayerError(name, '_forward_abstract', self._caller, input_signature,
--> 642                        trace) from None
    643 
    644   # pylint: disable=protected-access

LayerError: Exception passing through layer Serial (in _forward_abstract):
  layer created in file [...]/<ipython-input-29-587fa2da4e34>, line 29
  layer input shapes: (ShapeDtype{shape:(16, 15), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})

  File [...]/jax/interpreters/partial_eval.py, line 411, in abstract_eval_fun
    lu.wrap_init(fun, params), avals, debug_info)

  File [...]/jax/interpreters/partial_eval.py, line 1252, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

  File [...]/jax/interpreters/partial_eval.py, line 1262, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)

  File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/<ipython-input-29-587fa2da4e34>, line 29
  layer input shapes: (ShapeDtype{shape:(16, 15), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Mean (in pure_fn):
  layer created in file [...]/<ipython-input-29-587fa2da4e34>, line 14
  layer input shapes: ShapeDtype{shape:(16, 15, 256), dtype:float32}

  File [...]/trax/layers/base.py, line 743, in forward
    raw_output = self._forward_fn(inputs)

  File [...]/trax/layers/base.py, line 784, in _forward
    return f(*xs)

  File [...]/trax/layers/core.py, line 704, in <lambda>
    return Fn('Mean', lambda x: jnp.mean(x, axis=axis, keepdims=keepdims))

  File [...]/_src/numpy/lax_numpy.py, line 2154, in mean
    normalizer = _axis_size(a, axis)

  File [...]/_src/numpy/lax_numpy.py, line 2139, in _axis_size
    size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))

  File [...]/jax/_src/util.py, line 391, in maybe_named_axis
    return if_named(axis) if named else if_pos(pos)

  File [...]/_src/numpy/lax_numpy.py, line 2139, in <lambda>
    size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))

IndexError: tuple index out of range

Hi @Mubsi,

My issue was the same as the OP's. I had to fix the axis in the tl.Mean call in Q5, which then fixed Q6. I agree with other students that, if possible, the Q5 tests should fail when the axis is wrong.
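To see what the axis controls, here is a shape-only sketch in plain Python (no Trax). The traceback above shows the Mean layer receiving a (16, 15, 256) array, which I am assuming is (batch size, padded tweet length, embedding dimension); averaging over the word axis (axis 1 under that assumption) leaves one 256-dimensional vector per tweet. The helper mean_output_shape is hypothetical and only mimics the shape rule of a mean with keepdims=False.

```python
def mean_output_shape(shape, axis):
    """Shape left after averaging an array of `shape` over `axis` (keepdims=False).

    Looking up shape[axis] mirrors the tuple lookup that failed in the traceback:
    an axis outside the array's rank raises IndexError.
    """
    _ = shape[axis]  # IndexError: tuple index out of range for a bad axis
    axis = axis % len(shape)  # allow negative axes such as -1
    return tuple(dim for i, dim in enumerate(shape) if i != axis)


embedded = (16, 15, 256)  # (batch, padded length, embedding dim), per the traceback
print(mean_output_shape(embedded, 1))  # (16, 256): one vector per tweet

try:
    mean_output_shape(embedded, 3)
except IndexError as err:
    print("IndexError:", err)  # tuple index out of range
```

Note that an axis that is wrong but still in range, such as 2, does not raise here: it returns (16, 15), averaging over the embedding dimension instead of over the words.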

Best,
Alejandro


I also had a weird error on this exercise, and I can't seem to track it down. Any help would be appreciated. I guess I should tag @Mubsi? My lab ID is kwuyifpmghvf.

This is the error I'm getting (start_indices must have an integer type); it's the same as someone above, but all my unit tests for the earlier exercises pass, so I'm really not sure what it is.

Thanks!

Step      1: Total number of trainable weights: 2327042
Step      1: Ran 1 train steps in 1.71 secs
Step      1: train WeightedCategoryCrossEntropy |  0.69717056
Step      1: eval  WeightedCategoryCrossEntropy |  0.67841506
Step      1: eval      WeightedCategoryAccuracy |  0.68750000
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-140-7c2f23c889cd> in <module>
      7 
----> 8 w1_unittest.test_train_model(train_model(classifier(), train_task, [eval_task], 10, './model_test/'))

<ipython-input-138-918bf0dfcb18> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
     24 
---> 25     training_loop.run(n_steps = n_steps)
     26     ### END CODE HERE ###

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
    434 
--> 435         loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
    436 

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
    632     (loss, stats) = trainer.one_step(
--> 633         batch, rng, step=step, learning_rate=learning_rate
    634     )

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
    147     (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 148         (weights, self._slots), step, self._opt_params, batch, state, rng)
    149 

/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    182     try:
--> 183       return fun(*args, **kwargs)
    184     except Exception as e:

/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
    426         device=device, backend=backend, name=flat_fun.__name__,
--> 427         donated_invars=donated_invars, inline=inline)
    428     out_pytree_def = out_tree()

/opt/conda/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1559   def bind(self, fun, *args, **params):
-> 1560     return call_bind(self, fun, *args, **params)
   1561 

/opt/conda/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1550   tracers = map(top_trace.full_raise, args)
-> 1551   outs = primitive.process(top_trace, fun, tracers, params)
   1552   return map(full_lower, apply_todos(env_trace_todo(), outs))

/opt/conda/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1562   def process(self, trace, fun, tracers, params):
-> 1563     return trace.process_call(self, fun, tracers, params)
   1564 

/opt/conda/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    605   def process_call(self, primitive, f, tracers, params):
--> 606     return primitive.impl(f, *tracers, **params)
    607   process_map = process_call

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(***failed resolving arguments***)
    592   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 593                                *unsafe_map(arg_spec, args))
    594   try:

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    261     else:
--> 262       ans = call(fun, *args)
    263       cache[key] = (ans, fun.stores)

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    667   jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
--> 668       fun, abstract_args, pe.debug_info_final(fun, "jit"))
    669   if any(isinstance(c, core.Tracer) for c in consts):

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, debug_info)
   1283     with core.new_sublevel():
-> 1284       jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1285     del fun, main

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1261     in_tracers = map(trace.new_arg, in_avals)
-> 1262     ans = fun.call_wrapped(*in_tracers)
   1263     out_tracers = map(trace.full_raise, ans)

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in single_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng)
    217       (loss, state), gradients = forward_and_backward_fn(
--> 218           batch, weights, state, rng)
    219       weights, slots, stats = optimizer.tree_update(

/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    182     try:
--> 183       return fun(*args, **kwargs)
    184     except Exception as e:

/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in value_and_grad_f(*args, **kwargs)
    883     else:
--> 884       ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
    885     _check_scalar(ans)

/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in _vjp(fun, has_aux, *primals)
   1964     flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree)
-> 1965     out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
   1966     out_tree, aux_tree = out_aux_trees()

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    115   else:
--> 116     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
    117 

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
    100   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 101   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
    102   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate)
    504     fun = trace_to_subjaxpr(fun, main, instantiate)
--> 505     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    506     assert not env

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
    605       raise LayerError(name, 'pure_fn',
--> 606                        self._caller, signature(x), trace) from None
    607 

UnfilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/supervised/training.py, line 1033
  layer input shapes: (ShapeDtype{shape:(16, 9), dtype:float32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/<ipython-input-131-3ce0dd0c78b6>, line 29
  layer input shapes: ShapeDtype{shape:(16, 9), dtype:float32}

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Embedding_9088_256 (in pure_fn):
  layer created in file [...]/<ipython-input-131-3ce0dd0c78b6>, line 10
  layer input shapes: ShapeDtype{shape:(16, 9), dtype:float32}

  File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
    y = forward(self, x, *args, **kwargs)

  File [...]/trax/layers/core.py, line 182, in forward
    embedded = jnp.take(self.weights, x, axis=0, mode='clip')

  File [...]/_src/numpy/lax_numpy.py, line 4736, in take
    slice_sizes=tuple(slice_sizes))

  File [...]/_src/lax/lax.py, line 988, in gather
    indices_are_sorted=bool(indices_are_sorted))

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/ad.py, line 274, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)

  File [...]/jax/interpreters/ad.py, line 449, in standard_jvp
    val_out = primitive.bind(*primals, **params)

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 143, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 150, in default_process_primitive
    return primitive.bind(*consts, **params)

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 2125, in standard_abstract_eval
    return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

  File [...]/_src/lax/lax.py, line 4354, in _gather_dtype_rule
    raise ValueError("start_indices must have an integer type")

ValueError: start_indices must have an integer type

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

LayerError                                Traceback (most recent call last)
<ipython-input-140-7c2f23c889cd> in <module>
      6     pass
      7 
----> 8 w1_unittest.test_train_model(train_model(classifier(), train_task, [eval_task], 10, './model_test/'))

<ipython-input-138-918bf0dfcb18> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
     23     ) 
     24 
---> 25     training_loop.run(n_steps = n_steps)
     26     ### END CODE HERE ###
     27 

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
    433           loss_acc, step_acc = 0.0, 0
    434 
--> 435         loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
    436 
    437         # optimizer_metrics and loss are replicated on self.n_devices, a few

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
    631 
    632     (loss, stats) = trainer.one_step(
--> 633         batch, rng, step=step, learning_rate=learning_rate
    634     )
    635 

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
    146     # NOTE: stats is a replicated dictionary of key to jnp arrays.
    147     (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 148         (weights, self._slots), step, self._opt_params, batch, state, rng)
    149 
    150     if logging.vlog_is_on(1) and ((step & step - 1) == 0):

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in single_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng)
    216       weights, slots = weights_and_slots
    217       (loss, state), gradients = forward_and_backward_fn(
--> 218           batch, weights, state, rng)
    219       weights, slots, stats = optimizer.tree_update(
    220           step, gradients, weights, slots, opt_params, store_slots=False)

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
    604       name, trace = self._name, _short_traceback(skip=3)
    605       raise LayerError(name, 'pure_fn',
--> 606                        self._caller, signature(x), trace) from None
    607 
    608   def output_signature(self, input_signature):

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/supervised/training.py, line 1033
  layer input shapes: (ShapeDtype{shape:(16, 9), dtype:float32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/<ipython-input-131-3ce0dd0c78b6>, line 29
  layer input shapes: ShapeDtype{shape:(16, 9), dtype:float32}

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Embedding_9088_256 (in pure_fn):
  layer created in file [...]/<ipython-input-131-3ce0dd0c78b6>, line 10
  layer input shapes: ShapeDtype{shape:(16, 9), dtype:float32}

  File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
    y = forward(self, x, *args, **kwargs)

  File [...]/trax/layers/core.py, line 182, in forward
    embedded = jnp.take(self.weights, x, axis=0, mode='clip')

  File [...]/_src/numpy/lax_numpy.py, line 4736, in take
    slice_sizes=tuple(slice_sizes))

  File [...]/_src/lax/lax.py, line 988, in gather
    indices_are_sorted=bool(indices_are_sorted))

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/ad.py, line 274, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)

  File [...]/jax/interpreters/ad.py, line 449, in standard_jvp
    val_out = primitive.bind(*primals, **params)

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 143, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 150, in default_process_primitive
    return primitive.bind(*consts, **params)

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 2125, in standard_abstract_eval
    return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

  File [...]/_src/lax/lax.py, line 4354, in _gather_dtype_rule
    raise ValueError("start_indices must have an integer type")

ValueError: start_indices must have an integer type

Hi @Mubsi,
I am getting the error below and haven't been able to trace what is causing it.
I also checked the previous outputs; there is no mismatch.

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-90-9216004d763d> in <module>
      3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

<ipython-input-87-918bf0dfcb18> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
     24 
---> 25     training_loop.run(n_steps = n_steps)
     26     ### END CODE HERE ###

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
    434 
--> 435         loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
    436 

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
    632     (loss, stats) = trainer.one_step(
--> 633         batch, rng, step=step, learning_rate=learning_rate
    634     )

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
    147     (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 148         (weights, self._slots), step, self._opt_params, batch, state, rng)
    149 

/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    182     try:
--> 183       return fun(*args, **kwargs)
    184     except Exception as e:

/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
    426         device=device, backend=backend, name=flat_fun.__name__,
--> 427         donated_invars=donated_invars, inline=inline)
    428     out_pytree_def = out_tree()

/opt/conda/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1559   def bind(self, fun, *args, **params):
-> 1560     return call_bind(self, fun, *args, **params)
   1561 

/opt/conda/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1549       params_tuple, out_axes_transforms)
-> 1550   tracers = map(top_trace.full_raise, args)
   1551   outs = primitive.process(top_trace, fun, tracers, params)

/opt/conda/lib/python3.7/site-packages/jax/_src/util.py in safe_map(f, *args)
     40     assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
---> 41   return list(map(f, *args))
     42 

/opt/conda/lib/python3.7/site-packages/jax/core.py in full_raise(self, val)
    384         raise escaped_tracer_error(
--> 385             val, f"Can't lift sublevels {val._trace.sublevel} to {sublevel}")
    386     elif val._trace.level < level:

UnfilteredStackTrace: jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: Can't lift sublevels 1 to 0.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

UnexpectedTracerError                     Traceback (most recent call last)
<ipython-input-90-9216004d763d> in <module>
      2 # Take a look on how the eval_task is inside square brackets and
      3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

<ipython-input-87-918bf0dfcb18> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
     23     ) 
     24 
---> 25     training_loop.run(n_steps = n_steps)
     26     ### END CODE HERE ###
     27 

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
    433           loss_acc, step_acc = 0.0, 0
    434 
--> 435         loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
    436 
    437         # optimizer_metrics and loss are replicated on self.n_devices, a few

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
    631 
    632     (loss, stats) = trainer.one_step(
--> 633         batch, rng, step=step, learning_rate=learning_rate
    634     )
    635 

/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
    146     # NOTE: stats is a replicated dictionary of key to jnp arrays.
    147     (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 148         (weights, self._slots), step, self._opt_params, batch, state, rng)
    149 
    150     if logging.vlog_is_on(1) and ((step & step - 1) == 0):

/opt/conda/lib/python3.7/site-packages/jax/core.py in full_raise(self, val)
    383       else:
    384         raise escaped_tracer_error(
--> 385             val, f"Can't lift sublevels {val._trace.sublevel} to {sublevel}")
    386     elif val._trace.level < level:
    387       if val._trace.sublevel > sublevel:

UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: Can't lift sublevels 1 to 0.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.

Thanks,
Nikhil

Hi @Mubsi
Found the issue. Code is running fine now.
Thanks

Tagging @Elemento, who seems to be answering other posts. Link to the original post, from just under 2 weeks ago: Course 3 week 1 exercise 6 - #37 by Terwiliger

My exercise 6 train_model function hits the stack trace mentioned above, but it does run through at least 1 step correctly before encountering an error:

ValueError: start_indices must have an integer type

I have run through this every which way and I'm not sure what is causing this error. My suspicion is that either the data_generator function or the tl.Mean layer is returning a float that can't be coerced into an int for some reason.
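The traceback shows the Embedding layer receiving ShapeDtype{shape:(16, 9), dtype:float32}, whereas in the earlier traceback in this thread the model input was int32. An embedding is a row lookup (the trace shows jnp.take(self.weights, x, ...)), so its inputs must be integer token ids. Below is a pure-Python analogy, assuming (hypothetically) that the floats come from padding the batch with a float value; pad_batch and embedding_lookup are illustrative names, not functions from the notebook.

```python
def pad_batch(token_ids_list, pad_id=0):
    """Right-pad each tweet's token ids to the longest length in the batch.

    pad_id should be an int: a float pad such as 0.0 turns the padded
    positions into floats, which an integer row lookup cannot use.
    """
    max_len = max(len(ids) for ids in token_ids_list)
    return [ids + [pad_id] * (max_len - len(ids)) for ids in token_ids_list]


def embedding_lookup(table, batch):
    """Hypothetical stand-in for an embedding layer: one table row per token id."""
    return [[table[token] for token in row] for row in batch]


table = [[0.0, 0.0], [0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # toy 4-word vocabulary

good = pad_batch([[1, 2, 3], [2]])
print(good)  # [[1, 2, 3], [2, 0, 0]]
print(embedding_lookup(table, good)[1])  # [[0.3, 0.4], [0.0, 0.0], [0.0, 0.0]]

bad = pad_batch([[1, 2, 3], [2]], pad_id=0.0)
try:
    embedding_lookup(table, bad)
except TypeError as err:
    print("lookup failed:", err)  # list indices must be integers or slices, not float
```

In the notebook the batches are arrays rather than lists, so the symptom shows up as a float32 dtype instead of a Python float, but if this hypothesis is right the fix is the same: make sure the ids reaching the embedding layer are integers.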

I would really appreciate any help - I’ve spent hours tweaking settings and dimensions and other things, and I’ve lost half a month of course time just trying to get past this.

Thank you.

Hey @Terwiliger,
Can you please DM me your implementation of the train_model function, so that I can try to figure out the error?

Cheers,
Elemento

Hello Mubsi,
I am also having the problem with [eval_task].
My notebook ID is zsmmfdtgprnd.
I set eval_tasks = eval_tasks and passed all the tests without any issues, but the grader rejected all my answers because I used a global variable.
I really do not understand this problem.
Can you help me? Thanks for helping.
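The notebook comment ("Take a look on how the eval_task is inside square brackets...") is the key: the caller already wraps the task in a list, so inside train_model the parameter eval_task is a list. Any right-hand side other than that parameter, such as eval_tasks, can only resolve to a global variable, which matches the grader's complaint. The sketch below uses a hypothetical EvalTask stand-in (not the trax class) and only shows the list structure each choice produces; it does not model how trax's training.Loop handles its eval_tasks argument.

```python
class EvalTask:
    """Hypothetical stand-in for an evaluation task; only its name matters here."""

    def __init__(self, name):
        self.name = name


def loop_eval_tasks(eval_task):
    # The notebook calls train_model(..., [eval_task], ...), so here the
    # parameter eval_task is already a list of tasks: pass it through as-is.
    return eval_task


def loop_eval_tasks_rewrapped(eval_task):
    # Wrapping the parameter again produces a list inside a list.
    return [eval_task]


task = EvalTask("eval")
print(type(loop_eval_tasks([task])[0]).__name__)  # EvalTask
print(type(loop_eval_tasks_rewrapped([task])[0]).__name__)  # list
```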