I'm also getting a strange error on this exercise and can't track it down. Any help would be appreciated — should I tag @Mubsi? My lab ID is kwuyifpmghvf.
The error I'm getting is `ValueError: start_indices must have an integer type`. Someone above reported the same error. All my unit tests for the earlier exercises pass, so I'm not sure what's causing it.
Thanks!
Step 1: Total number of trainable weights: 2327042
Step 1: Ran 1 train steps in 1.71 secs
Step 1: train WeightedCategoryCrossEntropy | 0.69717056
Step 1: eval WeightedCategoryCrossEntropy | 0.67841506
Step 1: eval WeightedCategoryAccuracy | 0.68750000
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-140-7c2f23c889cd> in <module>
7
----> 8 w1_unittest.test_train_model(train_model(classifier(), train_task, [eval_task], 10, './model_test/'))
<ipython-input-138-918bf0dfcb18> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
24
---> 25 training_loop.run(n_steps = n_steps)
26 ### END CODE HERE ###
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
434
--> 435 loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
436
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
632 (loss, stats) = trainer.one_step(
--> 633 batch, rng, step=step, learning_rate=learning_rate
634 )
/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
147 (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 148 (weights, self._slots), step, self._opt_params, batch, state, rng)
149
/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
182 try:
--> 183 return fun(*args, **kwargs)
184 except Exception as e:
/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
426 device=device, backend=backend, name=flat_fun.__name__,
--> 427 donated_invars=donated_invars, inline=inline)
428 out_pytree_def = out_tree()
/opt/conda/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
1559 def bind(self, fun, *args, **params):
-> 1560 return call_bind(self, fun, *args, **params)
1561
/opt/conda/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1550 tracers = map(top_trace.full_raise, args)
-> 1551 outs = primitive.process(top_trace, fun, tracers, params)
1552 return map(full_lower, apply_todos(env_trace_todo(), outs))
/opt/conda/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
1562 def process(self, trace, fun, tracers, params):
-> 1563 return trace.process_call(self, fun, tracers, params)
1564
/opt/conda/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
605 def process_call(self, primitive, f, tracers, params):
--> 606 return primitive.impl(f, *tracers, **params)
607 process_map = process_call
/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(***failed resolving arguments***)
592 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 593 *unsafe_map(arg_spec, args))
594 try:
/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
261 else:
--> 262 ans = call(fun, *args)
263 cache[key] = (ans, fun.stores)
/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
667 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
--> 668 fun, abstract_args, pe.debug_info_final(fun, "jit"))
669 if any(isinstance(c, core.Tracer) for c in consts):
/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, debug_info)
1283 with core.new_sublevel():
-> 1284 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
1285 del fun, main
/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1261 in_tracers = map(trace.new_arg, in_avals)
-> 1262 ans = fun.call_wrapped(*in_tracers)
1263 out_tracers = map(trace.full_raise, ans)
/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in single_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng)
217 (loss, state), gradients = forward_and_backward_fn(
--> 218 batch, weights, state, rng)
219 weights, slots, stats = optimizer.tree_update(
/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
182 try:
--> 183 return fun(*args, **kwargs)
184 except Exception as e:
/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in value_and_grad_f(*args, **kwargs)
883 else:
--> 884 ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
885 _check_scalar(ans)
/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in _vjp(fun, has_aux, *primals)
1964 flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree)
-> 1965 out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
1966 out_tree, aux_tree = out_aux_trees()
/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
115 else:
--> 116 out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
117
/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
100 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 101 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
102 out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate)
504 fun = trace_to_subjaxpr(fun, main, instantiate)
--> 505 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
506 assert not env
/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
605 raise LayerError(name, 'pure_fn',
--> 606 self._caller, signature(x), trace) from None
607
UnfilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/trax/supervised/training.py, line 1033
layer input shapes: (ShapeDtype{shape:(16, 9), dtype:float32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/<ipython-input-131-3ce0dd0c78b6>, line 29
layer input shapes: ShapeDtype{shape:(16, 9), dtype:float32}
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Embedding_9088_256 (in pure_fn):
layer created in file [...]/<ipython-input-131-3ce0dd0c78b6>, line 10
layer input shapes: ShapeDtype{shape:(16, 9), dtype:float32}
File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
y = forward(self, x, *args, **kwargs)
File [...]/trax/layers/core.py, line 182, in forward
embedded = jnp.take(self.weights, x, axis=0, mode='clip')
File [...]/_src/numpy/lax_numpy.py, line 4736, in take
slice_sizes=tuple(slice_sizes))
File [...]/_src/lax/lax.py, line 988, in gather
indices_are_sorted=bool(indices_are_sorted))
File [...]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/ad.py, line 274, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
File [...]/jax/interpreters/ad.py, line 449, in standard_jvp
val_out = primitive.bind(*primals, **params)
File [...]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 143, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 150, in default_process_primitive
return primitive.bind(*consts, **params)
File [...]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File [...]/_src/lax/lax.py, line 2125, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
File [...]/_src/lax/lax.py, line 4354, in _gather_dtype_rule
raise ValueError("start_indices must have an integer type")
ValueError: start_indices must have an integer type
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
LayerError Traceback (most recent call last)
<ipython-input-140-7c2f23c889cd> in <module>
6 pass
7
----> 8 w1_unittest.test_train_model(train_model(classifier(), train_task, [eval_task], 10, './model_test/'))
<ipython-input-138-918bf0dfcb18> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
23 )
24
---> 25 training_loop.run(n_steps = n_steps)
26 ### END CODE HERE ###
27
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
433 loss_acc, step_acc = 0.0, 0
434
--> 435 loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
436
437 # optimizer_metrics and loss are replicated on self.n_devices, a few
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
631
632 (loss, stats) = trainer.one_step(
--> 633 batch, rng, step=step, learning_rate=learning_rate
634 )
635
/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
146 # NOTE: stats is a replicated dictionary of key to jnp arrays.
147 (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 148 (weights, self._slots), step, self._opt_params, batch, state, rng)
149
150 if logging.vlog_is_on(1) and ((step & step - 1) == 0):
/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in single_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng)
216 weights, slots = weights_and_slots
217 (loss, state), gradients = forward_and_backward_fn(
--> 218 batch, weights, state, rng)
219 weights, slots, stats = optimizer.tree_update(
220 step, gradients, weights, slots, opt_params, store_slots=False)
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
604 name, trace = self._name, _short_traceback(skip=3)
605 raise LayerError(name, 'pure_fn',
--> 606 self._caller, signature(x), trace) from None
607
608 def output_signature(self, input_signature):
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/trax/supervised/training.py, line 1033
layer input shapes: (ShapeDtype{shape:(16, 9), dtype:float32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/<ipython-input-131-3ce0dd0c78b6>, line 29
layer input shapes: ShapeDtype{shape:(16, 9), dtype:float32}
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Embedding_9088_256 (in pure_fn):
layer created in file [...]/<ipython-input-131-3ce0dd0c78b6>, line 10
layer input shapes: ShapeDtype{shape:(16, 9), dtype:float32}
File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
y = forward(self, x, *args, **kwargs)
File [...]/trax/layers/core.py, line 182, in forward
embedded = jnp.take(self.weights, x, axis=0, mode='clip')
File [...]/_src/numpy/lax_numpy.py, line 4736, in take
slice_sizes=tuple(slice_sizes))
File [...]/_src/lax/lax.py, line 988, in gather
indices_are_sorted=bool(indices_are_sorted))
File [...]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/ad.py, line 274, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
File [...]/jax/interpreters/ad.py, line 449, in standard_jvp
val_out = primitive.bind(*primals, **params)
File [...]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 143, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 150, in default_process_primitive
return primitive.bind(*consts, **params)
File [...]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File [...]/_src/lax/lax.py, line 2125, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
File [...]/_src/lax/lax.py, line 4354, in _gather_dtype_rule
raise ValueError("start_indices must have an integer type")
ValueError: start_indices must have an integer type