I get the following error and I can't figure out the reason. Can you help?
---------------------------------------------------------------------------
LayerError Traceback (most recent call last)
<ipython-input-80-9216004d763d> in <module>
2 # Take a look on how the eval_task is inside square brackets and
3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)
<ipython-input-79-f5a1eb0b2c7f> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
20 eval_tasks=eval_task[0], # The evaluation task
21 output_dir=output_dir, # The output directory
---> 22 random_seed=31 # Do not modify this random seed in order to ensure reproducibility and for grading purposes.
23 )
24
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
278
279 # Create the optimizer for the training loss function.
--> 280 self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
281
282 # Sync layers weights/state in memory effcient trainer layers.
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in <genexpr>(.0)
278
279 # Create the optimizer for the training loss function.
--> 280 self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
281
282 # Sync layers weights/state in memory effcient trainer layers.
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _init_trainer(self, task)
339 self._model,
340 [task.loss_layer],
--> 341 shapes.signature(task.sample_batch)
342 )
343 if base.N_WEIGHTS_SHARDS > 1:
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _model_with_ends(model, end_layers, batch_signature)
1028 # TODO(jonni): Redo this function as part of an initialization refactor?
1029 metrics_layer = tl.Branch(*end_layers)
-> 1030 metrics_input_signature = model.output_signature(batch_signature)
1031 _, _ = metrics_layer.init(metrics_input_signature)
1032
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in output_signature(self, input_signature)
608 def output_signature(self, input_signature):
609 """Returns output signature this layer would give for `input_signature`."""
--> 610 return self._forward_abstract(input_signature)[0] # output only, not state
611
612 def _forward_abstract(self, input_signature):
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in _forward_abstract(self, input_signature)
640 name, trace = self._name, _short_traceback(skip=7)
641 raise LayerError(name, '_forward_abstract', self._caller, input_signature,
--> 642 trace) from None
643
644 # pylint: disable=protected-access
LayerError: Exception passing through layer Serial (in _forward_abstract):
layer created in file [...]/<ipython-input-34-1421af8c6f34>, line 30
layer input shapes: (ShapeDtype{shape:(16, 15), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(16,), dtype:float32})
File [...]/jax/interpreters/partial_eval.py, line 411, in abstract_eval_fun
lu.wrap_init(fun, params), avals, debug_info)
File [...]/jax/interpreters/partial_eval.py, line 1252, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File [...]/jax/interpreters/partial_eval.py, line 1262, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/<ipython-input-34-1421af8c6f34>, line 30
layer input shapes: (ShapeDtype{shape:(16, 15), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(16,), dtype:float32})
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Mean (in pure_fn):
layer created in file [...]/<ipython-input-34-1421af8c6f34>, line 15
layer input shapes: ShapeDtype{shape:(16, 15, 256), dtype:float32}
File [...]/trax/layers/base.py, line 743, in forward
raw_output = self._forward_fn(inputs)
File [...]/trax/layers/base.py, line 784, in _forward
return f(*xs)
File [...]/trax/layers/core.py, line 704, in <lambda>
return Fn('Mean', lambda x: jnp.mean(x, axis=axis, keepdims=keepdims))
File [...]/_src/numpy/lax_numpy.py, line 2154, in mean
normalizer = _axis_size(a, axis)
File [...]/_src/numpy/lax_numpy.py, line 2139, in _axis_size
size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))
File [...]/jax/_src/util.py, line 391, in maybe_named_axis
return if_named(axis) if named else if_pos(pos)
File [...]/_src/numpy/lax_numpy.py, line 2139, in <lambda>
size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))
File [...]/_src/lax/parallel.py, line 86, in psum
axis_index_groups=axis_index_groups)
File [...]/_src/lax/parallel.py, line 723, in psum_bind
size = prod([core.axis_frame(name).size for name in named_axes]) # type: ignore
File [...]/_src/lax/parallel.py, line 723, in <listcomp>
size = prod([core.axis_frame(name).size for name in named_axes]) # type: ignore
File [...]/site-packages/jax/core.py, line 1681, in axis_frame
f'unbound axis name: {axis_name}. The following axis names (e.g. defined '
NameError: unbound axis name: Embedding_9088_256. The following axis names (e.g. defined by pmap) are available to collective operations: []