When I run the training loop below:
# Should take around 1.5 minutes
!rm -f ~/model/model.pkl.gz
loop = training_loop(TransformerLM, train_batch_stream, eval_batch_stream)
loop.run(10)
I keep getting the following error:
LayerError Traceback (most recent call last)
<ipython-input-42-bf4e14290aae> in <module>
1 # Should take around 1.5 minutes
2 get_ipython().system('rm -f ~/model/model.pkl.gz')
----> 3 loop = training_loop(TransformerLM, train_batch_stream, eval_batch_stream)
4 loop.run(10)
<ipython-input-40-c6ec63e508e6> in training_loop(TransformerLM, train_gen, eval_gen, output_dir)
40 train_task,
41 eval_tasks=[eval_task],
---> 42 output_dir=output_dir)
43
44 return loop
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
246 if not use_memory_efficient_trainer:
247 if _is_uninitialized(self._model):
--> 248 self._model.init(self._batch_signature)
249 self._eval_model.rng = self.new_rng()
250 if _is_uninitialized(self._eval_model):
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
309 name, trace = self._name, _short_traceback(skip=3)
310 raise LayerError(name, 'init', self._caller,
--> 311 input_signature, trace) from None
312
313 def init_from_file(self, file_name, weights_only=False, input_signature=None):
LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/<ipython-input-37-6265e024f03d>, line 63
layer input shapes: (ShapeDtype{shape:(2, 1024), dtype:int64}, ShapeDtype{shape:(2, 1024), dtype:int64}, ShapeDtype{shape:(2, 1024), dtype:int64})
File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))
LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/<ipython-input-34-a2ac93ee377e>, line 54
layer input shapes: ShapeDtype{shape:(2, 1024, 4), dtype:float32}
File [...]/trax/layers/combinators.py, line 106, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
LayerError: Exception passing through layer Add (in _forward_abstract):
layer created in file [...]/<ipython-input-34-a2ac93ee377e>, line 54
layer input shapes: (ShapeDtype{shape:(2, 1024, 4), dtype:float32}, ShapeDtype{shape:(4, 1024, 4), dtype:float32})
File [...]/jax/interpreters/partial_eval.py, line 404, in abstract_eval_fun
_, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals, transform_name)
File [...]/jax/interpreters/partial_eval.py, line 1178, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File [...]/jax/interpreters/partial_eval.py, line 1188, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
LayerError: Exception passing through layer Add (in pure_fn):
layer created in file [...]/<ipython-input-34-a2ac93ee377e>, line 54
layer input shapes: (ShapeDtype{shape:(2, 1024, 4), dtype:float32}, ShapeDtype{shape:(4, 1024, 4), dtype:float32})
File [...]/trax/layers/base.py, line 743, in forward
raw_output = self._forward_fn(inputs)
File [...]/trax/layers/base.py, line 784, in _forward
return f(*xs)
File [...]/trax/layers/combinators.py, line 843, in <lambda>
return Fn('Add', lambda x0, x1: x0 + x1)
File [...]/site-packages/jax/core.py, line 505, in __add__
def __add__(self, other): return self.aval._add(self, other)
File [...]/_src/numpy/lax_numpy.py, line 5666, in deferring_binary_op
return binary_op(self, other)
File [...]/_src/numpy/lax_numpy.py, line 427, in fn
return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)
File [...]/_src/lax/lax.py, line 340, in add
return add_p.bind(x, y)
File [...]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 1049, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File [...]/_src/lax/lax.py, line 2115, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
File [...]/_src/lax/lax.py, line 2211, in _broadcasting_shape_rule
raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
TypeError: add got incompatible shapes for broadcasting: (2, 1024, 4), (4, 1024, 4).
My code has passed all previous unit tests.
Could someone help me? Thanks!