Hi @ai_curious:
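Yes, there was more. Here are two stack traces: one with `eval_tasks = eval_task`, the other with `eval_tasks = eval_task[0]`. Both fail the same way during model initialization: the Embedding layer receives float32 inputs of shape (16, 15) and raises `ValueError: start_indices must have an integer type`.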
Cheers,
Drew
LayerError Traceback (most recent call last)
in <module>
2 # Take a look on how the eval_task is inside square brackets and
3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)
in train_model(classifier, train_task, eval_task, n_steps, output_dir)
22 eval_tasks= eval_task, # The evaluation task
23 output_dir=output_dir, # The output directory
---> 24 random_seed=31 # Do not modify this random seed in order to ensure reproducibility and for grading purposes.
25 )
26
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
246 if not use_memory_efficient_trainer:
247 if _is_uninitialized(self._model):
--> 248 self._model.init(self._batch_signature)
249 self._eval_model.rng = self.new_rng()
250 if _is_uninitialized(self._eval_model):
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
309 name, trace = self._name, _short_traceback(skip=3)
310 raise LayerError(name, 'init', self._caller,
--> 311 input_signature, trace) from None
312
313 def init_from_file(self, file_name, weights_only=False, input_signature=None):
LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 30
layer input shapes: (ShapeDtype{shape:(16, 15), dtype:float32}, ShapeDtype{shape:(16,), dtype:float32}, ShapeDtype{shape:(16,), dtype:float32})
File […]/trax/layers/combinators.py, line 106, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
LayerError: Exception passing through layer Embedding_9088_256 (in _forward_abstract):
layer created in file […]/, line 11
layer input shapes: ShapeDtype{shape:(16, 15), dtype:float32}
File […]/jax/interpreters/partial_eval.py, line 411, in abstract_eval_fun
lu.wrap_init(fun, params), avals, debug_info)
File […]/jax/interpreters/partial_eval.py, line 1252, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File […]/jax/interpreters/partial_eval.py, line 1262, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
LayerError: Exception passing through layer Embedding_9088_256 (in pure_fn):
layer created in file […]/, line 11
layer input shapes: ShapeDtype{shape:(16, 15), dtype:float32}
File […]/trax/layers/assert_shape.py, line 122, in forward_wrapper
y = forward(self, x, *args, **kwargs)
File […]/trax/layers/core.py, line 182, in forward
embedded = jnp.take(self.weights, x, axis=0, mode='clip')
File […]/_src/numpy/lax_numpy.py, line 4736, in take
slice_sizes=tuple(slice_sizes))
File […]/_src/lax/lax.py, line 988, in gather
indices_are_sorted=bool(indices_are_sorted))
File […]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File […]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File […]/_src/lax/lax.py, line 2125, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
File […]/_src/lax/lax.py, line 4354, in _gather_dtype_rule
raise ValueError("start_indices must have an integer type")
ValueError: start_indices must have an integer type
LayerError Traceback (most recent call last)
in <module>
2 # Take a look on how the eval_task is inside square brackets and
3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)
in train_model(classifier, train_task, eval_task, n_steps, output_dir)
22 eval_tasks= eval_task[0], # The evaluation task
23 output_dir=output_dir, # The output directory
---> 24 random_seed=31 # Do not modify this random seed in order to ensure reproducibility and for grading purposes.
25 )
26
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
246 if not use_memory_efficient_trainer:
247 if _is_uninitialized(self._model):
--> 248 self._model.init(self._batch_signature)
249 self._eval_model.rng = self.new_rng()
250 if _is_uninitialized(self._eval_model):
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
309 name, trace = self._name, _short_traceback(skip=3)
310 raise LayerError(name, 'init', self._caller,
--> 311 input_signature, trace) from None
312
313 def init_from_file(self, file_name, weights_only=False, input_signature=None):
LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 30
layer input shapes: (ShapeDtype{shape:(16, 15), dtype:float32}, ShapeDtype{shape:(16,), dtype:float32}, ShapeDtype{shape:(16,), dtype:float32})
File […]/trax/layers/combinators.py, line 106, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
LayerError: Exception passing through layer Embedding_9088_256 (in _forward_abstract):
layer created in file […]/, line 11
layer input shapes: ShapeDtype{shape:(16, 15), dtype:float32}
File […]/jax/interpreters/partial_eval.py, line 411, in abstract_eval_fun
lu.wrap_init(fun, params), avals, debug_info)
File […]/jax/interpreters/partial_eval.py, line 1252, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File […]/jax/interpreters/partial_eval.py, line 1262, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
LayerError: Exception passing through layer Embedding_9088_256 (in pure_fn):
layer created in file […]/, line 11
layer input shapes: ShapeDtype{shape:(16, 15), dtype:float32}
File […]/trax/layers/assert_shape.py, line 122, in forward_wrapper
y = forward(self, x, *args, **kwargs)
File […]/trax/layers/core.py, line 182, in forward
embedded = jnp.take(self.weights, x, axis=0, mode='clip')
File […]/_src/numpy/lax_numpy.py, line 4736, in take
slice_sizes=tuple(slice_sizes))
File […]/_src/lax/lax.py, line 988, in gather
indices_are_sorted=bool(indices_are_sorted))
File […]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File […]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File […]/_src/lax/lax.py, line 2125, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
File […]/_src/lax/lax.py, line 4354, in _gather_dtype_rule
raise ValueError("start_indices must have an integer type")
ValueError: start_indices must have an integer type