Week 1 Exercise 6: start_indices must have an integer type

I’m getting the following error and have no idea why :-/. I passed every exercise up until Exercise 6.

---------------------------------------------------------------------------
LayerError                                Traceback (most recent call last)
<ipython-input-434-9216004d763d> in <module>
      2 # Take a look on how the eval_task is inside square brackets and
      3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

<ipython-input-433-ba929cad1e07> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
     20                                 eval_tasks= eval_task, # The evaluation task
     21                                 output_dir=output_dir, # The output directory
---> 22                                 random_seed=31 # Do not modify this random seed in order to ensure reproducibility and for grading purposes.
     23     ) 
     24 

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    246     if not use_memory_efficient_trainer:
    247       if _is_uninitialized(self._model):
--> 248         self._model.init(self._batch_signature)
    249       self._eval_model.rng = self.new_rng()
    250       if _is_uninitialized(self._eval_model):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
    309       name, trace = self._name, _short_traceback(skip=3)
    310       raise LayerError(name, 'init', self._caller,
--> 311                        input_signature, trace) from None
    312 
    313   def init_from_file(self, file_name, weights_only=False, input_signature=None):

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/<ipython-input-422-840dd88c464c>, line 30
  layer input shapes: (ShapeDtype{shape:(16, 15), dtype:float32}, ShapeDtype{shape:(16,), dtype:float32}, ShapeDtype{shape:(16,), dtype:float32})

  File [...]/trax/layers/combinators.py, line 106, in init_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer Embedding_9088_256 (in _forward_abstract):
  layer created in file [...]/<ipython-input-422-840dd88c464c>, line 11
  layer input shapes: ShapeDtype{shape:(16, 15), dtype:float32}

  File [...]/jax/interpreters/partial_eval.py, line 411, in abstract_eval_fun
    lu.wrap_init(fun, params), avals, debug_info)

  File [...]/jax/interpreters/partial_eval.py, line 1252, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

  File [...]/jax/interpreters/partial_eval.py, line 1262, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)

  File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer Embedding_9088_256 (in pure_fn):
  layer created in file [...]/<ipython-input-422-840dd88c464c>, line 11
  layer input shapes: ShapeDtype{shape:(16, 15), dtype:float32}

  File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
    y = forward(self, x, *args, **kwargs)

  File [...]/trax/layers/core.py, line 182, in forward
    embedded = jnp.take(self.weights, x, axis=0, mode='clip')

  File [...]/_src/numpy/lax_numpy.py, line 4736, in take
    slice_sizes=tuple(slice_sizes))

  File [...]/_src/lax/lax.py, line 988, in gather
    indices_are_sorted=bool(indices_are_sorted))

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 2125, in standard_abstract_eval
    return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

  File [...]/_src/lax/lax.py, line 4354, in _gather_dtype_rule
    raise ValueError("start_indices must have an integer type")

ValueError: start_indices must have an integer type

Hi Alexis. Could you please show me the specific code snippet where you got the error? Also, please check the type of start_indices; you can do this by running type(start_indices) in a new cell in your notebook.

I got the same error. For me, it was caused by the input data type, which was float. This traces back to Exercise 2, data_generator(): the arrays it returns need to be integers. I think the unit test only checks the values, not the data type. Simply changing the data type of the returned array (i.e. inputs) solved the problem for me.
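The bottom of the traceback shows trax's Embedding layer looking up rows with jnp.take(self.weights, x, axis=0, mode='clip'), where x is the batch of token IDs, so the IDs are used directly as indices. As a rough analogy that needs only NumPy (not the actual trax/JAX code path), np.take also refuses a float index array instead of silently casting it. The table size and IDs below are made up for illustration:

```python
import numpy as np

# Made-up embedding table: 10 "words" x 4 dimensions (the course model is far larger).
weights = np.arange(40, dtype=np.float32).reshape(10, 4)

# A batch of 2 sequences with 3 token IDs each, accidentally stored as floats.
float_ids = np.array([[1, 2, 3], [4, 5, 0]], dtype=np.float32)
int_ids = float_ids.astype(np.int32)


def lookup(ids):
    """Row lookup analogous to take(weights, x, axis=0, mode='clip') in the traceback."""
    return np.take(weights, ids, axis=0, mode="clip")


# Float indices are rejected instead of being silently cast to integers.
try:
    lookup(float_ids)
    float_rejected = False
except (TypeError, ValueError) as err:
    float_rejected = True
    print("float IDs rejected:", err)

# Integer indices work: one 4-dim vector per token.
embedded = lookup(int_ids)
print(embedded.shape)  # (2, 3, 4)
```

The values in float_ids are whole numbers, but that does not matter: the check is on the array's dtype, which is why fixing the dtype at the source (the generator) is enough.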

I know this is late; I hope you found your answer. Hopefully this helps someone in the future.


start_indices is defined within the bowels of trax/jax, so it’s not possible to inspect it directly from the exercise notebook unless one digs through the source code of the imported packages…
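What can be inspected from the notebook is the array that gets fed into the Embedding layer, i.e. the inputs coming out of the generator. Note that type() only reports the container class, which is the same for integer and float arrays; the element type lives in .dtype (JAX arrays expose the same attribute). A minimal sketch with two made-up batches:

```python
import numpy as np

# Two made-up batches of token IDs: same values, different element types.
ids_int = np.array([[3, 1, 4], [1, 5, 9]])
ids_float = ids_int.astype(np.float32)

# type() only reports the container class, so it cannot tell these apart.
print(type(ids_int), type(ids_float))  # both <class 'numpy.ndarray'>
same_container = type(ids_int) is type(ids_float)

# .dtype reports the element type, which is what the embedding lookup depends on.
print(ids_int.dtype, ids_float.dtype)
int_is_integer = np.issubdtype(ids_int.dtype, np.integer)      # True
float_is_integer = np.issubdtype(ids_float.dtype, np.integer)  # False
```

The default integer dtype printed for ids_int depends on the platform and NumPy version, but np.issubdtype(..., np.integer) is True for any of them.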

I ran into this too. The issue for me was that the inputs array returned by data_generator() in Exercise 2 did not specify any data type. Specifying an integer data type fixed the problem, and my model now trains.
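One likely way to end up with float inputs, assuming the generator builds its padded batch with a NumPy constructor such as np.zeros: when no dtype is passed, np.zeros defaults to float64. The pad_batch helper below is hypothetical (it is not the course's data_generator()) and only shows where an explicit integer dtype goes:

```python
import numpy as np


def pad_batch(token_lists, pad_id=0, dtype=np.int32):
    """Hypothetical helper (not the course's data_generator()).

    Pads lists of token IDs to the longest length in the batch and returns
    a 2-D array whose dtype is set explicitly.
    """
    max_len = max(len(tokens) for tokens in token_lists)
    batch = np.full((len(token_lists), max_len), pad_id, dtype=dtype)
    for row, tokens in enumerate(token_lists):
        batch[row, :len(tokens)] = tokens
    return batch


token_lists = [[12, 7, 301], [5, 44], [9]]

# Leaving the dtype out: np.zeros defaults to float64, which the Embedding layer rejects.
unspecified = np.zeros((len(token_lists), 3))
print(unspecified.dtype)  # float64

# Explicit integer dtype: safe to feed to an embedding lookup.
inputs = pad_batch(token_lists)
print(inputs.dtype)     # int32
print(inputs.tolist())  # [[12, 7, 301], [5, 44, 0], [9, 0, 0]]
```

int32 is just one common choice for token IDs; any integer dtype avoids this particular error.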

Upvoted. This approach fixed my similar issue.