Hi @wnyforever, I have the same issue as @Elemento. Look at my implementations below; thanks for your help.
GRADED FUNCTION: train_model
def train_model(classifier, train_task, eval_task, n_steps, output_dir):
    """Build a Trax training loop for the classifier, run it, and return it.

    Args:
        classifier: the model you are building.
        train_task: Training task.
        eval_task: Evaluation task. Received as a list, so it is forwarded
            unchanged as ``eval_tasks``.
        n_steps: number of steps passed to ``training_loop.run``.
        output_dir: folder to save your files.

    Returns:
        The ``training.Loop`` instance after ``run`` has been called on it
        (the trax trainer, which holds the model).
    """
    rnd.seed(31)  # Do NOT modify this random seed. This makes the notebook easier to replicate

    ### START CODE HERE (Replace instances of 'None' with your code) ###
    training_loop = training.Loop(
        classifier,  # The learning model
        train_task,  # The training task
        eval_tasks=eval_task,  # The evaluation task (already a list)
        output_dir=output_dir,  # The output directory
        random_seed=31,  # Do not modify this random seed in order to ensure reproducibility and for grading purposes.
    )

    training_loop.run(n_steps=n_steps)

    # Return the training_loop, since it has the model.
    return training_loop
# Do not modify this cell.
# Take a look on how the eval_task is inside square brackets and
# take that into account for you train_model implementation
training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)
LayerError Traceback (most recent call last)
2 # Take a look on how the eval_task is inside square brackets and
3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)
in train_model(classifier, train_task, eval_task, n_steps, output_dir)
20 eval_tasks=eval_task, # The evaluation task
21 output_dir=output_dir, # The output directory
—> 22 random_seed=31 # Do not modify this random seed in order to ensure reproducibility and for grading purposes.
23 )
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in init(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
305 self._rjust_len = max(map(len, loss_names + metric_names))
306 self._evaluator_per_task = tuple(
→ 307 self._init_evaluator(eval_task) for eval_task in self._eval_tasks)
309 if self._output_dir is None:
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in (.0)
305 self._rjust_len = max(map(len, loss_names + metric_names))
306 self._evaluator_per_task = tuple(
→ 307 self._init_evaluator(eval_task) for eval_task in self._eval_tasks)
309 if self._output_dir is None:
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _init_evaluator(self, eval_task)
364 “”“Initializes the per-task evaluator.”“”
365 model_with_metrics = _model_with_metrics(
→ 366 self._eval_model, eval_task)
367 if self._use_memory_efficient_trainer:
368 return _Evaluator(
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _model_with_metrics(model, eval_task)
1047 “”"
1048 return _model_with_ends(
→ 1049 model, eval_task.metrics, shapes.signature(eval_task.sample_batch)
1050 )
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _model_with_ends(model, end_layers, batch_signature)
1029 metrics_layer = tl.Branch(*end_layers)
1030 metrics_input_signature = model.output_signature(batch_signature)
→ 1031 _, _ = metrics_layer.init(metrics_input_signature)
1033 model_with_metrics = tl.Serial(model, metrics_layer)
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
309 name, trace = self._name, _short_traceback(skip=3)
310 raise LayerError(name, ‘init’, self._caller,
→ 311 input_signature, trace) from None
313 def init_from_file(self, file_name, weights_only=False, input_signature=None):
LayerError: Exception passing through layer Branch (in init):
layer created in file […]/trax/supervised/training.py, line 1029
layer input shapes: (ShapeDtype{shape:(13, 2), dtype:float32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})
File […]/trax/layers/combinators.py, line 106, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
LayerError: Exception passing through layer Parallel (in _forward_abstract):
layer created in file […]/trax/supervised/training.py, line 1029
layer input shapes: (ShapeDtype{shape:(13, 2), dtype:float32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(13, 2), dtype:float32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})
File […]/jax/interpreters/partial_eval.py, line 411, in abstract_eval_fun
lu.wrap_init(fun, params), avals, debug_info)
File […]/jax/interpreters/partial_eval.py, line 1252, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File […]/jax/interpreters/partial_eval.py, line 1262, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
LayerError: Exception passing through layer Parallel (in pure_fn):
layer created in file […]/trax/supervised/training.py, line 1029
layer input shapes: (ShapeDtype{shape:(13, 2), dtype:float32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(13, 2), dtype:float32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})
File […]/trax/layers/combinators.py, line 211, in forward
sub_outputs, sub_state = layer.pure_fn(x, w, s, r, use_cache=True)
LayerError: Exception passing through layer WeightedCategoryCrossEntropy (in pure_fn):
layer created in file […]/, line 21
layer input shapes: (ShapeDtype{shape:(13, 2), dtype:float32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})
File […]/trax/layers/base.py, line 743, in forward
raw_output = self._forward_fn(inputs)
File […]/trax/layers/base.py, line 784, in _forward
return f(*xs)
File […]/trax/layers/metrics.py, line 273, in f
model_output, targets, label_smoothing)
File […]/trax/layers/metrics.py, line 649, in _category_cross_entropy
return - jnp.sum(target_distributions * model_log_distributions, axis=-1)
File […]/site-packages/jax/core.py, line 506, in mul
def mul(self, other): return self.aval._mul(self, other)
File […]/_src/numpy/lax_numpy.py, line 5819, in deferring_binary_op
return binary_op(self, other)
File […]/src/numpy/lax_numpy.py, line 431, in fn
return lax_fn(x1, x2) if x1.dtype != bool else bool_lax_fn(x1, x2)
File […]/_src/lax/lax.py, line 348, in mul
return mul_p.bind(x, y)
File […]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File […]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File […]/_src/lax/lax.py, line 2125, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
File […]/_src/lax/lax.py, line 2221, in _broadcasting_shape_rule
raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
TypeError: mul got incompatible shapes for broadcasting: (16, 2), (13, 2).
Here is my classifier:
def classifier(vocab_size=9088, embedding_dim=256, output_dim=2, mode='train'):
    """Build an embedding -> mean -> dense classifier as a ``tl.Serial`` model.

    Args:
        vocab_size: Size of the vocabulary for the embedding layer.
        embedding_dim: Embedding dimension (``d_feature`` of the embedding).
        output_dim: Number of units of the dense output layer, one per output.
        mode: Not used inside this function; kept so existing callers that
            pass it keep working.

    Returns:
        The ``tl.Serial`` model combining the three layers.
    """
    ### START CODE HERE (Replace instances of 'None' with your code) ###
    # create embedding layer
    embed_layer = tl.Embedding(
        vocab_size=vocab_size,  # Size of the vocabulary
        d_feature=embedding_dim,  # Embedding dimension
    )

    # Create a mean layer, to create an "average" word embedding.
    # The mean must be taken over axis 1, not axis 0: with axis=0 the model
    # produced a (13, 2) output for 16 targets, and the loss failed with
    # "incompatible shapes for broadcasting: (16, 2), (13, 2)".
    # Assumes the embedded input is (batch, sequence, embedding_dim), which
    # is consistent with those reported shapes - TODO confirm against the
    # data generator.
    mean_layer = tl.Mean(axis=1)

    # Create a dense layer, one unit for each output
    dense_output_layer = tl.Dense(n_units=output_dim)

    # Use tl.Serial to combine all layers
    # and create the classifier
    # of type trax.layers.combinators.Serial
    model = tl.Serial(
        embed_layer,  # embedding layer
        mean_layer,  # mean layer
        dense_output_layer,  # dense output layer
    )

    # return the model of type
    return model