---------------------------------------------------------------------------
LayerError Traceback (most recent call last)
<ipython-input-40-d1e9ffd2f1b1> in <module>
1 # UNIT TEST
2 # test training_loop
----> 3 w2_tests.test_training_loop(training_loop, TransformerLM)
~/work/w2_tests.py in test_training_loop(target, TransformerLM)
813 os.remove("~/model/model.pkl.gz")
814
--> 815 output_loop = target(TransformerLM, my_gen(), my_gen())
816
817 try:
<ipython-input-39-fb48a407979c> in training_loop(TransformerLM, train_gen, eval_gen, output_dir)
41 train_task,
42 eval_tasks=[eval_task],
---> 43 output_dir=output_dir)
44
45 return loop
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
246 if not use_memory_efficient_trainer:
247 if _is_uninitialized(self._model):
--> 248 self._model.init(self._batch_signature)
249 self._eval_model.rng = self.new_rng()
250 if _is_uninitialized(self._eval_model):
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
309 name, trace = self._name, _short_traceback(skip=3)
310 raise LayerError(name, 'init', self._caller,
--> 311 input_signature, trace) from None
312
313 def init_from_file(self, file_name, weights_only=False, input_signature=None):
LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/<ipython-input-36-3a24aa910556>, line 65
layer input shapes: (ShapeDtype{shape:(1, 1250), dtype:int64}, ShapeDtype{shape:(1, 1250), dtype:int64}, ShapeDtype{shape:(1, 1250), dtype:int64})
File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))
LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/<ipython-input-33-c8371ee5498b>, line 54
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}
File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))
LayerError: Exception passing through layer Branch (in init):
layer created in file [...]/<ipython-input-33-c8371ee5498b>, line 54
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}
File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))
LayerError: Exception passing through layer Parallel (in init):
layer created in file [...]/<ipython-input-33-c8371ee5498b>, line 54
layer input shapes: (ShapeDtype{shape:(1, 1250, 4), dtype:float32}, ShapeDtype{shape:(1, 1250, 4), dtype:float32})
File [...]/trax/layers/combinators.py, line 226, in init_weights_and_state
in zip(self.sublayers, sublayer_signatures)]
File [...]/trax/layers/combinators.py, line 225, in <listcomp>
for layer, signature
LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/<ipython-input-33-c8371ee5498b>, line 54
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}
File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
sublayer.init(inputs, use_cache=True))
LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/<ipython-input-30-4966a376bb95>, line 50
layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}
File [...]/trax/layers/combinators.py, line 106, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
LayerError: Exception passing through layer DotProductAttn (in _forward_abstract):
layer created in file [...]/<ipython-input-30-4966a376bb95>, line 44
layer input shapes: (ShapeDtype{shape:(2, 1250, 2), dtype:float32}, ShapeDtype{shape:(2, 1250, 2), dtype:float32}, ShapeDtype{shape:(2, 1250, 2), dtype:float32})
File [...]/jax/interpreters/partial_eval.py, line 419, in abstract_eval_fun
lu.wrap_init(fun, params), avals, debug_info)
File [...]/jax/interpreters/partial_eval.py, line 1510, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File [...]/jax/interpreters/partial_eval.py, line 1520, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
LayerError: Exception passing through layer DotProductAttn (in pure_fn):
layer created in file [...]/<ipython-input-30-4966a376bb95>, line 44
layer input shapes: (ShapeDtype{shape:(2, 1250, 2), dtype:float32}, ShapeDtype{shape:(2, 1250, 2), dtype:float32}, ShapeDtype{shape:(2, 1250, 2), dtype:float32})
File [...]/trax/layers/base.py, line 743, in forward
raw_output = self._forward_fn(inputs)
File [...]/trax/layers/base.py, line 784, in _forward
return f(*xs)
File [...]/<ipython-input-24-4cfe8a154e5d>, line 21, in dot_product_self_attention
mask = np.tril(jnp.ones( (mask_size,mask_size)))
File [...]/<__array_function__ internals>, line 6, in tril
File [...]/numpy/lib/twodim_base.py, line 433, in tril
m = asanyarray(m)
File [...]/numpy/core/_asarray.py, line 136, in asanyarray
return array(a, dtype, copy=False, order=order, subok=True)
File [...]/site-packages/jax/core.py, line 483, in __array__
raise TracerArrayConversionError(self)
jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[1250,1250])>with<DynamicJaxprTrace(level=1/0)>
While tracing the function pure_fn at /opt/conda/lib/python3.7/site-packages/trax/layers/base.py:542 for eval_shape, this value became a tracer due to JAX operations on these lines:
operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
from line <ipython-input-24-4cfe8a154e5d>:21 (dot_product_self_attention)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
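I couldn’t figure out what is wrong. The error is raised at `mask = np.tril(jnp.ones((mask_size, mask_size)))` inside `dot_product_self_attention`. Please help!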