Lab ID: lvedxyqk
I am not sure what went wrong: all tests passed successfully up to this point, but then I got this error:
LayerError                                Traceback (most recent call last)
in <module>
      9     train_task,
     10     eval_tasks=[eval_task],
---> 11     output_dir=output_dir)

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    246     if not use_memory_efficient_trainer:
    247       if _is_uninitialized(self._model):
--> 248         self._model.init(self._batch_signature)
    249     self._eval_model.rng = self.new_rng()
    250     if _is_uninitialized(self._eval_model):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
    309       name, trace = self._name, _short_traceback(skip=3)
    310       raise LayerError(name, 'init', self._caller,
--> 311                        input_signature, trace) from None
    312
    313   def init_from_file(self, file_name, weights_only=False, input_signature=None):

LayerError: Exception passing through layer Serial (in init):
  layer created in file […]/, line 64
  layer input shapes: (ShapeDtype{shape:(16, 128), dtype:int64}, ShapeDtype{shape:(16, 128), dtype:int64}, ShapeDtype{shape:(16, 128), dtype:float32})

  File […]/trax/layers/combinators.py, line 106, in init_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer PrepareAttentionInput (in _forward_abstract):
  layer created in file […]/, line 48
  layer input shapes: (ShapeDtype{shape:(16, 128, 1024), dtype:float32}, ShapeDtype{shape:(16, 128, 1024), dtype:float32}, ShapeDtype{shape:(16, 128), dtype:int32})

  File […]/jax/interpreters/partial_eval.py, line 419, in abstract_eval_fun
    lu.wrap_init(fun, params), avals, debug_info)

  File […]/jax/interpreters/partial_eval.py, line 1510, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

  File […]/jax/interpreters/partial_eval.py, line 1520, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)

  File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File […]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer PrepareAttentionInput (in pure_fn):
  layer created in file […]/, line 48
  layer input shapes: (ShapeDtype{shape:(16, 128, 1024), dtype:float32}, ShapeDtype{shape:(16, 128, 1024), dtype:float32}, ShapeDtype{shape:(16, 128), dtype:int32})

  File […]/trax/layers/base.py, line 743, in forward
    raw_output = self._forward_fn(inputs)

  File […]/trax/layers/base.py, line 784, in _forward
    return f(*xs)

  File […]/, line 27, in prepare_attention_input
    mask = np.where(inputs, np.full_like(inputs, 1), np.full_like(inputs, 0))

  File […]/<__array_function__ internals>, line 6, in full_like

  File […]/numpy/core/numeric.py, line 382, in full_like
    res = empty_like(a, dtype=dtype, order=order, subok=subok, shape=shape)

  File […]/<__array_function__ internals>, line 6, in empty_like

  File […]/site-packages/jax/core.py, line 483, in __array__
    raise TracerArrayConversionError(self)

jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[16,128])> with <DynamicJaxprTrace(level=1/0)>
While tracing the function pure_fn at /opt/conda/lib/python3.7/site-packages/trax/layers/base.py:542 for eval_shape, this concrete value was not available in Python because it depends on the value of the argument 'x'.
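
For context, the failing frame is the mask construction in prepare_attention_input: plain NumPy's np.full_like tries to convert the traced JAX array into a concrete numpy.ndarray, which is not allowed while Trax runs eval_shape to initialize the model. Below is a minimal sketch of a traceable way to build the same 0/1 mask, assuming inputs holds the padded token ids from the traceback (0 marking padding); the helper name build_mask is hypothetical, not part of the assignment.

# Hypothetical helper: build the padding mask with traceable ops only.
from trax.fastmath import numpy as jnp  # dispatches to jax.numpy under tracing

def build_mask(inputs):
    # 1 where a real token is present, 0 where the position is padding.
    return jnp.where(inputs != 0, jnp.ones_like(inputs), jnp.zeros_like(inputs))

The key point is that every array operation inside a Trax layer should go through trax.fastmath.numpy (or jax.numpy) rather than plain numpy, so the JAX tracer never has to be materialized as a concrete array.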