C4 W1 3.3 Loop Error with Fn

I checked my architecture, and it matches the required architecture exactly. Still, when running 3.3 I get the errors below.
I can see from the traceback that it has to do with my

tl.Fn('PrepareAttentionInput', prepare_attention_input, n_out=4),
but the instructions state:

Note: Pass the prepare_attention_input function as the f parameter in tl.Fn without any arguments or parenthesis.
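
As I read that note, I am already doing exactly that: passing the function object itself and letting Trax call it later with the layer's inputs. For example (my own illustration of the two cases, not from the notebook):

# Pass the function object; Trax calls it with the layer's inputs at run time.
tl.Fn('PrepareAttentionInput', prepare_attention_input, n_out=4)

# By contrast, writing prepare_attention_input(...) here would call the
# function immediately and hand its return value to tl.Fn instead of the function.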

Also, when I submit my code for grading, all of the exercises up to 3.3 pass, so I don’t understand where the issue is. As I said, I have checked my architecture against the expected architecture multiple times.
So what am I doing wrong?

Any help will be appreciated

LayerError                                Traceback (most recent call last)
<ipython-input-49-0c4a3449f2b4> in <module>
      9                               train_task,
     10                               eval_tasks=[eval_task],
---> 11                               output_dir=output_dir)

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    246     if not use_memory_efficient_trainer:
    247       if _is_uninitialized(self._model):
--> 248         self._model.init(self._batch_signature)
    249       self._eval_model.rng = self.new_rng()
    250       if _is_uninitialized(self._eval_model):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
    309       name, trace = self._name, _short_traceback(skip=3)
    310       raise LayerError(name, 'init', self._caller,
--> 311                        input_signature, trace) from None
    312 
    313   def init_from_file(self, file_name, weights_only=False, input_signature=None):

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/<ipython-input-33-d3261b6bd1aa>, line 64
  layer input shapes: (ShapeDtype{shape:(16, 128), dtype:int64}, ShapeDtype{shape:(16, 128), dtype:int64}, ShapeDtype{shape:(16, 128), dtype:float32})

  File [...]/trax/layers/combinators.py, line 106, in init_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer PrepareAttentionInput (in _forward_abstract):
  layer created in file [...]/<ipython-input-33-d3261b6bd1aa>, line 48
  layer input shapes: (ShapeDtype{shape:(16, 128, 1024), dtype:float32}, ShapeDtype{shape:(16, 128, 1024), dtype:float32}, ShapeDtype{shape:(16, 128), dtype:int32})

  File [...]/jax/interpreters/partial_eval.py, line 419, in abstract_eval_fun
    lu.wrap_init(fun, params), avals, debug_info)

  File [...]/jax/interpreters/partial_eval.py, line 1510, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

  File [...]/jax/interpreters/partial_eval.py, line 1520, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)

  File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer PrepareAttentionInput (in pure_fn):
  layer created in file [...]/<ipython-input-33-d3261b6bd1aa>, line 48
  layer input shapes: (ShapeDtype{shape:(16, 128, 1024), dtype:float32}, ShapeDtype{shape:(16, 128, 1024), dtype:float32}, ShapeDtype{shape:(16, 128), dtype:int32})

  File [...]/trax/layers/base.py, line 743, in forward
    raw_output = self._forward_fn(inputs)

  File [...]/trax/layers/base.py, line 784, in _forward
    return f(*xs)

  File [...]/<ipython-input-17-107d40b1f8c6>, line 27, in prepare_attention_input
    mask = np.where(inputs > 0,1,0)

  File [...]/<__array_function__ internals>, line 6, in where
  File [...]/site-packages/jax/core.py, line 483, in __array__

    raise TracerArrayConversionError(self)
jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(bool[16,128])>with<DynamicJaxprTrace(level=1/0)>

While tracing the function pure_fn at /opt/conda/lib/python3.7/site-packages/trax/layers/base.py:542 for eval_shape, this concrete value was not available in Python because it depends on the value of the argument 'x'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

I got the same thing! My prepare_attention_input function passed all the tests, but I think something goes wrong when it is used inside the NMTAttn function.

@D_Ben @joaopedrovtp - I had the same issue. I think it has to do with how you create your mask matrix. The last line of @D_Ben’s error trace was

this concrete value was not available in Python because it depends on the value of the argument ‘x’.

This makes me wonder if you’re using some sort of implicit or explicit loop to zero out elements of the mask matrix. You might want to try vectorizing with something like

mask = fastnp.where(…)
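
For example, a minimal sketch of the idea (assuming inputs holds the padded token ids, with 0 used for padding):

from trax.fastmath import numpy as fastnp

# 1 wherever there is a real token, 0 wherever there is padding.
# fastnp stays inside JAX, so the layer remains traceable: no Python loop
# and no plain-NumPy call on the traced array.
mask = fastnp.where(inputs > 0, 1, 0)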


Thank you @Steven1 for the hint. I was using np.where().

@D_Ben Did you try fastnp?

Which package has fastnp? @Steven1
I’m using the Coursera environment.

It’s in the original list of imports:

from termcolor import colored
import random
import numpy as np

import trax
from trax import layers as tl
from trax.fastmath import numpy as fastnp
from trax.supervised import training

Both np and fastnp are imported
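
A quick way to see the difference is to wrap the same computation in tl.Fn and initialize it with an abstract shape, the way training.Loop does. Just a sketch, using a made-up Mask layer:

import numpy as np
from trax import layers as tl
from trax.fastmath import numpy as fastnp
from trax.shapes import ShapeDtype

# Plain NumPy tries to convert the abstract JAX tracer into a concrete array,
# which raises TracerArrayConversionError during init.
bad = tl.Fn('Mask', lambda x: np.where(x > 0, 1, 0))

# trax.fastmath.numpy stays inside JAX, so shape-only tracing works.
good = tl.Fn('Mask', lambda x: fastnp.where(x > 0, 1, 0))

good.init(ShapeDtype((16, 128), dtype=np.int32))   # fine
# bad.init(ShapeDtype((16, 128), dtype=np.int32))  # fails with a LayerError like the one above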

Thank you @Steven1!! I will try it right now.

Cool @D_Ben - Let me know if it works


Thank you @Steven1, it worked! :grinning:

Thanks @Steven1, it worked! They should give a hint about using fastnp.where(…), because you can do the exact same thing another way, e.g. with jax.numpy.ndarray.at.

@joaopedrovtp - Yeah, if you look through the list of questions, you can find at least two others where this came up.