Course Three Week 1 Assignment 1 Ex 6 - Problem with Implementing Training

Hi Folks:

I am implementing the training model and get a big stack trace:

LayerError Traceback (most recent call last)
in
6 pass
7
----> 8 w1_unittest.test_train_model(train_model(classifier(), train_task, [eval_task], 10, './model_test/'))

in train_model(classifier, train_task, eval_task, n_steps, output_dir)
22 eval_tasks=eval_task[0], # The evaluation task
23 output_dir=output_dir, # The output directory
---> 24 random_seed=31 # Do not modify this random seed in order to ensure reproducibility and for grading purposes.

ValueError: start_indices must have an integer type

At this point, I don't know where to look to fix the problem. I followed the advice about setting the Mean layer's axis to 1 (columns). In exercise 4, I get 8/2 on my tests; in exercise 2, I get 10/1.
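For reference, here is a small NumPy sketch of what averaging over axis 1 does to an embedded batch. The sizes are read off the traces in this thread (a batch of 16 tweets, 15 tokens each, and 256 features from the layer name Embedding_9088_256); NumPy's `mean` stands in for the trax Mean layer, so treat this as an illustration of the shapes, not of trax internals. On a `(batch, seq_len, d_model)` array, axis 1 is the token axis, so each tweet collapses to one 256-dimensional vector.

```python
import numpy as np

# Shapes assumed from the traces: 16 tweets, 15 tokens each, 256-dim embeddings.
batch_size, seq_len, d_model = 16, 15, 256
rng = np.random.default_rng(0)
embedded = rng.normal(size=(batch_size, seq_len, d_model))

# Averaging over axis 1 (the token axis) gives one vector per tweet.
pooled = embedded.mean(axis=1)
print(pooled.shape)  # (16, 256)

# Averaging over the last axis instead collapses the features, leaving (16, 15).
print(embedded.mean(axis=-1).shape)  # (16, 15)
```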

My lab id is tqxhipjp

Thanks,
Drew

The comment in the code there is misleading. If you take it literally, you will infer that you should be providing a single evaluation task. Don’t take it literally…it should be a list.

https://trax-ml.readthedocs.io/en/latest/trax.supervised.html?highlight=training%20loop#trax.supervised.training.Loop

Notice that if a function is called as my_func([eval_task]), the parameter name in my_func should reflect that a multivalued object is expected: it should be eval_tasks, plural, not eval_task, singular. That attention to detail would have made it easier to know what to pass down to the training loop. When it comes to variable names, nomen est omen should be the guiding principle. Notice also that the comment in the calling cell, outside the function, points out explicitly that eval_task is passed as a list, so take that into account… something that could have been achieved more elegantly just by naming the variable properly.
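A minimal sketch of the naming point, using a hypothetical helper (`loop_kwargs` is not part of the assignment or of trax) that only gathers the keyword arguments one would hand to `trax.supervised.training.Loop`. The plural parameter name makes it obvious that the list should be passed through unchanged rather than indexed with `[0]`; the type check is an illustrative guard of my own, not something Loop is claimed to do.

```python
from collections.abc import Sequence
from typing import Any, Dict


def loop_kwargs(train_task: Any, eval_tasks: Sequence, output_dir: str) -> Dict[str, Any]:
    """Hypothetical helper: collect the arguments for a training loop.

    The plural name `eval_tasks` documents that a list such as [eval_task]
    is expected and should be forwarded as-is.
    """
    # Illustrative guard: reject a bare task (or a string) instead of a list of tasks.
    if isinstance(eval_tasks, (str, bytes)) or not isinstance(eval_tasks, Sequence):
        raise TypeError("eval_tasks should be a list of tasks, e.g. [eval_task]")
    return {
        "tasks": train_task,
        "eval_tasks": list(eval_tasks),  # passed through unchanged, not eval_tasks[0]
        "output_dir": output_dir,
        "random_seed": 31,  # the notebook asks not to change this seed
    }


kwargs = loop_kwargs("train_task", ["eval_task"], "./model_test/")
print(kwargs["eval_tasks"])  # ['eval_task']
```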

@ai_curious (my bad) Thanks for the response!

I pasted the wrong output (I was experimenting). I re-checked the code and added a set_trace(): eval_task is a list. I get the following:

LayerError Traceback (most recent call last)
in
2 # Take a look on how the eval_task is inside square brackets and
3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

in train_model(classifier, train_task, eval_task, n_steps, output_dir)
22 eval_tasks= eval_task, # The evaluation task
23 output_dir=output_dir, # The output directory
---> 24 random_seed=31 # Do not modify this random seed in order to ensure reproducibility and for grading purposes.
25 )
26

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
246 if not use_memory_efficient_trainer:
247 if _is_uninitialized(self._model):

Again, how do I proceed? And again, thanks for your help!

Cheers,
Drew

I can’t tell what the reported error is. Was there more in the trace?

Hi @ai_curious:

Yes, there was more. Here are two stack traces: one with eval_tasks = eval_task, the other with eval_tasks = eval_task[0].

Cheers,
Drew

LayerError Traceback (most recent call last)
in
2 # Take a look on how the eval_task is inside square brackets and
3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

in train_model(classifier, train_task, eval_task, n_steps, output_dir)
22 eval_tasks= eval_task, # The evaluation task
23 output_dir=output_dir, # The output directory
---> 24 random_seed=31 # Do not modify this random seed in order to ensure reproducibility and for grading purposes.
25 )
26

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
246 if not use_memory_efficient_trainer:
247 if _is_uninitialized(self._model):
--> 248 self._model.init(self._batch_signature)
249 self._eval_model.rng = self.new_rng()
250 if _is_uninitialized(self._eval_model):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
309 name, trace = self._name, _short_traceback(skip=3)
310 raise LayerError(name, 'init', self._caller,
--> 311 input_signature, trace) from None
312
313 def init_from_file(self, file_name, weights_only=False, input_signature=None):

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 30
layer input shapes: (ShapeDtype{shape:(16, 15), dtype:float32}, ShapeDtype{shape:(16,), dtype:float32}, ShapeDtype{shape:(16,), dtype:float32})

File […]/trax/layers/combinators.py, line 106, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer Embedding_9088_256 (in _forward_abstract):
layer created in file […]/, line 11
layer input shapes: ShapeDtype{shape:(16, 15), dtype:float32}

File […]/jax/interpreters/partial_eval.py, line 411, in abstract_eval_fun
lu.wrap_init(fun, params), avals, debug_info)

File […]/jax/interpreters/partial_eval.py, line 1252, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

File […]/jax/interpreters/partial_eval.py, line 1262, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)

File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer Embedding_9088_256 (in pure_fn):
layer created in file […]/, line 11
layer input shapes: ShapeDtype{shape:(16, 15), dtype:float32}

File […]/trax/layers/assert_shape.py, line 122, in forward_wrapper
y = forward(self, x, *args, **kwargs)

File […]/trax/layers/core.py, line 182, in forward
embedded = jnp.take(self.weights, x, axis=0, mode='clip')

File […]/_src/numpy/lax_numpy.py, line 4736, in take
slice_sizes=tuple(slice_sizes))

File […]/_src/lax/lax.py, line 988, in gather
indices_are_sorted=bool(indices_are_sorted))

File […]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)

File […]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)

File […]/_src/lax/lax.py, line 2125, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

File […]/_src/lax/lax.py, line 4354, in _gather_dtype_rule
raise ValueError("start_indices must have an integer type")

ValueError: start_indices must have an integer type


LayerError Traceback (most recent call last)
in
2 # Take a look on how the eval_task is inside square brackets and
3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

in train_model(classifier, train_task, eval_task, n_steps, output_dir)
22 eval_tasks= eval_task[0], # The evaluation task
23 output_dir=output_dir, # The output directory
---> 24 random_seed=31 # Do not modify this random seed in order to ensure reproducibility and for grading purposes.
25 )
26

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
246 if not use_memory_efficient_trainer:
247 if _is_uninitialized(self._model):
--> 248 self._model.init(self._batch_signature)
249 self._eval_model.rng = self.new_rng()
250 if _is_uninitialized(self._eval_model):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
309 name, trace = self._name, _short_traceback(skip=3)
310 raise LayerError(name, 'init', self._caller,
--> 311 input_signature, trace) from None
312
313 def init_from_file(self, file_name, weights_only=False, input_signature=None):

LayerError: Exception passing through layer Serial (in init):
layer created in file […]/, line 30
layer input shapes: (ShapeDtype{shape:(16, 15), dtype:float32}, ShapeDtype{shape:(16,), dtype:float32}, ShapeDtype{shape:(16,), dtype:float32})

File […]/trax/layers/combinators.py, line 106, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer Embedding_9088_256 (in _forward_abstract):
layer created in file […]/, line 11
layer input shapes: ShapeDtype{shape:(16, 15), dtype:float32}

File […]/jax/interpreters/partial_eval.py, line 411, in abstract_eval_fun
lu.wrap_init(fun, params), avals, debug_info)

File […]/jax/interpreters/partial_eval.py, line 1252, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

File […]/jax/interpreters/partial_eval.py, line 1262, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)

File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer Embedding_9088_256 (in pure_fn):
layer created in file […]/, line 11
layer input shapes: ShapeDtype{shape:(16, 15), dtype:float32}

File […]/trax/layers/assert_shape.py, line 122, in forward_wrapper
y = forward(self, x, *args, **kwargs)

File […]/trax/layers/core.py, line 182, in forward
embedded = jnp.take(self.weights, x, axis=0, mode='clip')

File […]/_src/numpy/lax_numpy.py, line 4736, in take
slice_sizes=tuple(slice_sizes))

File […]/_src/lax/lax.py, line 988, in gather
indices_are_sorted=bool(indices_are_sorted))

File […]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)

File […]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)

File […]/_src/lax/lax.py, line 2125, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

File […]/_src/lax/lax.py, line 4354, in _gather_dtype_rule
raise ValueError("start_indices must have an integer type")

ValueError: start_indices must have an integer type

It looks like something is going wrong with inputs to the function. Maybe with the contents of train_task or eval_task? Did you carefully compare the generated output with expected results after the # Test the train_generator code block?
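The traces point in that direction: the Embedding layer receives inputs of shape (16, 15) with dtype float32, and its forward pass calls `jnp.take(self.weights, x, axis=0, mode='clip')`, which needs integer indices. One way float token ids can sneak in is padding a batch with `np.zeros`, whose default dtype is float64 (JAX then typically demotes it to float32). The NumPy sketch below illustrates that assumption only; `pad_batch_buggy`, `pad_batch` and `embed` are hypothetical names, not functions from the assignment, and the notebook's own generator may go wrong in a different way.

```python
import numpy as np


def pad_batch_buggy(token_lists):
    """Hypothetical padding step: np.zeros defaults to float64, so ids become floats."""
    max_len = max(len(t) for t in token_lists)
    batch = np.zeros((len(token_lists), max_len))
    for i, tokens in enumerate(token_lists):
        batch[i, : len(tokens)] = tokens
    return batch


def pad_batch(token_lists, pad_id=0):
    """Same padding, but the array is created with an integer dtype up front."""
    max_len = max(len(t) for t in token_lists)
    batch = np.full((len(token_lists), max_len), pad_id, dtype=np.int32)
    for i, tokens in enumerate(token_lists):
        batch[i, : len(tokens)] = tokens
    return batch


def embed(weights, token_ids):
    """Toy embedding lookup mirroring take(weights, ids, axis=0, mode='clip')."""
    # Reject float ids explicitly, much like the gather in the trace does.
    if not np.issubdtype(token_ids.dtype, np.integer):
        raise ValueError(f"token ids must be integers, got {token_ids.dtype}")
    return np.take(weights, token_ids, axis=0, mode="clip")


tweets = [[3, 7, 1], [4, 2]]  # two already-tokenized toy tweets
weights = np.random.default_rng(0).normal(size=(10, 4))  # vocab of 10, 4 features

print(pad_batch_buggy(tweets).dtype)  # float64
ids = pad_batch(tweets)
print(ids.dtype, embed(weights, ids).shape)  # int32 (2, 3, 4)
```

Casting an existing batch with `.astype(np.int32)` would also get past the error, but fixing the generator where the array is created keeps every later consumer of it correct.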

From a related thread we learned that the answer to this question was “No, we did not carefully compare the generated output with the expected results.” Or, if we did, we decided to ignore the discrepancy and hope for the best later on. Kids, don’t try this at home. If the helper functions aren’t generating expected output when you run them standalone the chances are extremely high they won’t generate expected output when you reuse them later.
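In that spirit, a cheap habit is to pull one batch out of a generator and check it before wiring it into the training loop. The hypothetical `batch_problems` helper below checks only the properties implicated in this thread (2-D integer inputs, one target and one weight per example); it is a sketch that complements, rather than replaces, comparing against the notebook's expected output.

```python
import numpy as np


def batch_problems(inputs, targets, example_weights):
    """Hypothetical sanity check for one (inputs, targets, weights) batch.

    Returns a list of human-readable problems; an empty list means
    every check below passed.
    """
    inputs = np.asarray(inputs)
    targets = np.asarray(targets)
    example_weights = np.asarray(example_weights)
    problems = []

    # Token ids feed an embedding lookup, so they must form a 2-D integer array.
    if inputs.ndim != 2:
        problems.append(f"inputs should be 2-D (batch, seq_len), got shape {inputs.shape}")
    if not np.issubdtype(inputs.dtype, np.integer):
        problems.append(f"inputs should hold integer token ids, got dtype {inputs.dtype}")

    # One label and one weight per example in the batch.
    batch_size = inputs.shape[0] if inputs.ndim > 0 else 0
    if targets.shape != (batch_size,):
        problems.append(f"targets should have shape ({batch_size},), got {targets.shape}")
    if example_weights.shape != targets.shape:
        problems.append(
            f"weights shape {example_weights.shape} does not match targets shape {targets.shape}"
        )
    return problems


# Shapes from the traces in this thread: (16, 15) inputs, (16,) targets and weights.
print(batch_problems(np.zeros((16, 15), dtype=np.float32), np.zeros(16), np.ones(16)))
# ['inputs should hold integer token ids, got dtype float32']
print(batch_problems(np.zeros((16, 15), dtype=np.int32), np.zeros(16), np.ones(16)))
# []
```

Run on the batch described in the traces above, it flags exactly the float32 inputs that the Embedding layer later chokes on.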