Layer Error in Section 3.3 (Loop fcn)

Hi,

Running the code in Section 3.3 (Loop Fcn) results in an error. All tests pass up to this point.

Can someone help?


LayerError                                Traceback (most recent call last)
<ipython-input-...> in <module>
      9                               train_task,
     10                               eval_tasks=[eval_task],
---> 11                               output_dir=output_dir)

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    246     if not use_memory_efficient_trainer:
    247       if _is_uninitialized(self._model):
--> 248         self._model.init(self._batch_signature)
    249       self._eval_model.rng = self.new_rng()
    250       if _is_uninitialized(self._eval_model):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
    309       name, trace = self._name, _short_traceback(skip=3)
    310       raise LayerError(name, 'init', self._caller,
--> 311                        input_signature, trace) from None
    312 
    313   def init_from_file(self, file_name, weights_only=False, input_signature=None):

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/<ipython-input-...>, line 64
  layer input shapes: (ShapeDtype{shape:(128, 16), dtype:int64}, ShapeDtype{shape:(128, 16), dtype:int64}, ShapeDtype{shape:(128, 16), dtype:float32})

  File [...]/trax/layers/combinators.py, line 106, in init_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer PrepareAttentionInput (in _forward_abstract):
  layer created in file [...]/<ipython-input-...>, line 48
  layer input shapes: (ShapeDtype{shape:(128, 16, 1024), dtype:float32}, ShapeDtype{shape:(128, 16, 1024), dtype:float32}, ShapeDtype{shape:(128, 16), dtype:int32})

  File [...]/jax/interpreters/partial_eval.py, line 419, in abstract_eval_fun
    lu.wrap_init(fun, params), avals, debug_info)

  File [...]/jax/interpreters/partial_eval.py, line 1510, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

  File [...]/jax/interpreters/partial_eval.py, line 1520, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)

  File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer PrepareAttentionInput (in pure_fn):
  layer created in file [...]/<ipython-input-...>, line 48
  layer input shapes: (ShapeDtype{shape:(128, 16, 1024), dtype:float32}, ShapeDtype{shape:(128, 16, 1024), dtype:float32}, ShapeDtype{shape:(128, 16), dtype:int32})

  File [...]/trax/layers/base.py, line 743, in forward
    raw_output = self._forward_fn(inputs)

  File [...]/trax/layers/base.py, line 784, in _forward
    return f(*xs)

  File [...]/<ipython-input-...>, line 28, in prepare_attention_input
    mask = mask.at[np.equal(inputs, 0)].set(0)

  File [...]/site-packages/jax/core.py, line 483, in __array__
    raise TracerArrayConversionError(self)

jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[128,16])>with<DynamicJaxprTrace(level=1/0)>
While tracing the function pure_fn at /opt/conda/lib/python3.7/site-packages/trax/layers/base.py:542 for eval_shape, this concrete value was not available in Python because it depends on the value of the argument 'x'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
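
For anyone searching later: the last frame shows plain NumPy (np.equal) being applied to a value that JAX is tracing, which forces an array conversion. A minimal standalone sketch of the failure and the fix, outside the notebook (assuming only jax is installed; not the notebook code):

import jax
import jax.numpy as jnp
import numpy as np

def build_mask(inputs):
    # Plain NumPy tries to convert the traced value to a concrete
    # numpy.ndarray, raising TracerArrayConversionError under tracing.
    return np.equal(inputs, 0)

def build_mask_traceable(inputs):
    # The jax.numpy equivalent stays inside the traced computation.
    return jnp.equal(inputs, 0)

x = jnp.zeros((128, 16), dtype=jnp.int32)

# jax.eval_shape traces with abstract values, as trax does in model.init.
try:
    jax.eval_shape(build_mask, x)
except jax.errors.TracerArrayConversionError as e:
    print('fails:', type(e).__name__)

print('works:', jax.eval_shape(build_mask_traceable, x))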

Hi Fabian_Muller,

Were you able to resolve this? If not, feel free to send me your notebook as an attachment in a direct message so I can have a look at what’s going on.

Hi,
I’m getting an almost identical error with my code.
Was this ever solved? All my tests up to this point are passing.
Thank you!

LayerError                                Traceback (most recent call last)
<ipython-input-25-0c4a3449f2b4> in <module>
      9                               train_task,
     10                               eval_tasks=[eval_task],
---> 11                               output_dir=output_dir)

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    246     if not use_memory_efficient_trainer:
    247       if _is_uninitialized(self._model):
--> 248         self._model.init(self._batch_signature)
    249       self._eval_model.rng = self.new_rng()
    250       if _is_uninitialized(self._eval_model):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
    309       name, trace = self._name, _short_traceback(skip=3)
    310       raise LayerError(name, 'init', self._caller,
--> 311                        input_signature, trace) from None
    312 
    313   def init_from_file(self, file_name, weights_only=False, input_signature=None):

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/<ipython-input-18-d3261b6bd1aa>, line 64
  layer input shapes: (ShapeDtype{shape:(64, 32), dtype:int64}, ShapeDtype{shape:(64, 32), dtype:int64}, ShapeDtype{shape:(64, 32), dtype:float32})

  File [...]/trax/layers/combinators.py, line 106, in init_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer PrepareAttentionInput (in _forward_abstract):
  layer created in file [...]/<ipython-input-18-d3261b6bd1aa>, line 48
  layer input shapes: (ShapeDtype{shape:(64, 32, 1024), dtype:float32}, ShapeDtype{shape:(64, 32, 1024), dtype:float32}, ShapeDtype{shape:(64, 32), dtype:int32})

  File [...]/jax/interpreters/partial_eval.py, line 419, in abstract_eval_fun
    lu.wrap_init(fun, params), avals, debug_info)

  File [...]/jax/interpreters/partial_eval.py, line 1510, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

  File [...]/jax/interpreters/partial_eval.py, line 1520, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)

  File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer PrepareAttentionInput (in pure_fn):
  layer created in file [...]/<ipython-input-18-d3261b6bd1aa>, line 48
  layer input shapes: (ShapeDtype{shape:(64, 32, 1024), dtype:float32}, ShapeDtype{shape:(64, 32, 1024), dtype:float32}, ShapeDtype{shape:(64, 32), dtype:int32})

  File [...]/trax/layers/base.py, line 743, in forward
    raw_output = self._forward_fn(inputs)

  File [...]/trax/layers/base.py, line 784, in _forward
    return f(*xs)

  File [...]/<ipython-input-16-3168491a0f1a>, line 27, in prepare_attention_input
    mask = fastnp.array([[1 if x>0 else 0 for x in y] for y in inputs])

  File [...]/<ipython-input-16-3168491a0f1a>, line 27, in <listcomp>
    mask = fastnp.array([[1 if x>0 else 0 for x in y] for y in inputs])

  File [...]/<ipython-input-16-3168491a0f1a>, line 27, in <listcomp>
    mask = fastnp.array([[1 if x>0 else 0 for x in y] for y in inputs])

  File [...]/site-packages/jax/core.py, line 549, in __bool__
    def __bool__(self): return self.aval._bool(self)

  File [...]/site-packages/jax/core.py, line 1000, in error
    raise ConcretizationTypeError(arg, fname_context)

jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/0)>
The problem arose with the `bool` function. 

While tracing the function pure_fn at /opt/conda/lib/python3.7/site-packages/trax/layers/base.py:542 for eval_shape, this concrete value was not available in Python because it depends on the value of the argument 'x'.


See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
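
The listcomp frame is the culprit here: `1 if x > 0 else 0` calls bool() on a traced scalar, and no concrete value exists while trax is tracing the layer. A standalone sketch of a vectorized alternative (plain jax, not the notebook code; padding id assumed to be 0):

import jax
import jax.numpy as jnp

# Failing pattern, branching in Python on traced values:
#   mask = fastnp.array([[1 if x > 0 else 0 for x in y] for y in inputs])
# A vectorized comparison keeps the branch inside the traced computation:
def build_mask(inputs):
    return jnp.where(inputs > 0, 1, 0)

# eval_shape traces the function the way trax's model.init does.
print(jax.eval_shape(build_mask, jnp.zeros((64, 32), dtype=jnp.int32)))
# -> ShapeDtypeStruct with shape (64, 32)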

Hi fatoddsun,

Did you manage to resolve this? If not, feel free to send me your notebook as an attachment to a direct message so I can have a look.

Hello,

I have run into this same “layer error” in Section 3.3 (loop function). Please help me think through this.

Thanks

Hi rameshvasu,

The previously reported issues in this thread were different in nature. The fastest way for me to have a look at what the issue is in your case is for you to send me your notebook as an attachment to a direct message. I can then have a look at what’s going on.

Hi,

Could anyone solve this issue? I cannot proceed without solving it. My code also runs up to that point and passes all unit tests.

I am looking forward to hearing from you!

Thanks for getting back, @reinoudbosch. I figured it out myself: using fastnp.where() and NOT np.where() in the prepare_attention_input() function made the difference for me.
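
In case it helps anyone else, the relevant line looks roughly like this (a sketch, not the graded solution; fastnp is trax.fastmath.numpy, padding tokens are assumed to be 0, and make_padding_mask is a hypothetical helper name):

from trax import fastmath

fastnp = fastmath.numpy  # backend-aware numpy; safe inside traced trax layers

def make_padding_mask(inputs):  # hypothetical helper, for illustration only
    # 1.0 where there is a real token, 0.0 where the input is padding.
    # Plain numpy.where would try to materialize the traced array and
    # fail during model.init, exactly as in the tracebacks above.
    return fastnp.where(inputs > 0, 1.0, 0.0)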

Hi @Mahammad_Namazov,

Using fastnp.where() and NOT np.where() in the prepare_attention_input() function made the difference for me. Hope this helps you.
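
A quick way to sanity-check the fix without re-running the whole training cell, assuming your function is wrapped with tl.Fn as in the notebook (the function body and shapes here are placeholders, not the graded solution):

import numpy as np
import trax.layers as tl
from trax import fastmath

fastnp = fastmath.numpy

def prepare_attention_input(encoder_activations, decoder_activations, inputs):
    # Placeholder body: only the masking line matters for this bug.
    mask = fastnp.where(inputs > 0, 1.0, 0.0)
    return encoder_activations, decoder_activations, mask

layer = tl.Fn('PrepareAttentionInput', prepare_attention_input, n_out=3)
_, _, mask = layer((np.ones((2, 4, 8), np.float32),
                    np.ones((2, 4, 8), np.float32),
                    np.array([[5, 7, 0, 0], [3, 0, 0, 0]])))
print(mask)  # 1.0 over real tokens, 0.0 over padding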


Exactly! I solved it the same day. The error message is actually quite clear once you check the provided link, but I missed it at first 🙂. Thanks for your answer!