C4_W1_E7 ValueError in forward pass?

Hi all,

In Exercise 7 here, I’ve come across a really curious ValueError in the custom test case. This doesn’t show up in unit tests for Exercise 5 or 6, and the unit tests in the subsequent code block also don’t raise this error (though they do fail). I think I’m struggling to understand a few key concepts here:

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-106-2a8688a3fbbc> in <module>
      2 # Run it several times with each setting and see how often the output changes.
----> 3 sampling_decode("I love languages.", NMTAttn=model, temperature=0.0, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)

<ipython-input-105-8eacaff94bcf> in sampling_decode(input_sentence, NMTAttn, temperature, vocab_file, vocab_dir, next_symbol, tokenize, detokenize)
     39         # update the current output token by getting the index of the next word (hint: use next_symbol)
---> 40         cur_output, log_prob = next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature=temperature)
     41 

<ipython-input-103-d95edcfc7ac9> in next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature)
     35     # get the model prediction
---> 36     output, _ = NMTAttn((input_tokens, padded_with_batch))
     37 

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
    196     state = self.state
--> 197     outputs, new_state = self.pure_fn(x, weights, state, rng)
    198     self.state = new_state

/opt/conda/lib/python3.7/site-packages/trax/layers/acceleration.py in pure_fn(self, x, weights, state, rng, use_cache)
     76     if remainder == 0:  # If yes, run the accelerated sublayer.pure_fn.
---> 77       return self._jit_pure_fn(x, weights, state, rng)
     78     # If not, pad first.

/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
    418         device=device, backend=backend, name=flat_fun.__name__,
--> 419         donated_invars=donated_invars, inline=inline)
    420     out_pytree_def = out_tree()

/opt/conda/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1631   def bind(self, fun, *args, **params):
-> 1632     return call_bind(self, fun, *args, **params)
   1633 

/opt/conda/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1622   tracers = map(top_trace.full_raise, args)
-> 1623   outs = primitive.process(top_trace, fun, tracers, params)
   1624   return map(full_lower, apply_todos(env_trace_todo(), outs))

/opt/conda/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1634   def process(self, trace, fun, tracers, params):
-> 1635     return trace.process_call(self, fun, tracers, params)
   1636 

/opt/conda/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    626   def process_call(self, primitive, f, tracers, params):
--> 627     return primitive.impl(f, *tracers, **params)
    628   process_map = process_call

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(***failed resolving arguments***)
    687   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 688                                *unsafe_map(arg_spec, args))
    689   try:

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    262     else:
--> 263       ans = call(fun, *args)
    264       cache[key] = (ans, fun.stores)

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable_uncached(fun, device, backend, name, donated_invars, *arg_specs)
    759   return lower_xla_callable(fun, device, backend, name, donated_invars,
--> 760                             *arg_specs).compile().unsafe_call
    761 

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in lower_xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    771   jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
--> 772       fun, abstract_args, pe.debug_info_final(fun, "jit"))
    773   if any(isinstance(c, core.Tracer) for c in consts):

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, debug_info)
   1541     with core.new_sublevel():
-> 1542       jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1543     del fun, main

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1519     in_tracers = map(trace.new_arg, in_avals)
-> 1520     ans = fun.call_wrapped(*in_tracers)
   1521     out_tracers = map(trace.full_raise, ans)

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
    605       raise LayerError(name, 'pure_fn',
--> 606                        self._caller, signature(x), trace) from None
    607 

UnfilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/<ipython-input-39-86fefae4e242>, line 64
  layer input shapes: (ShapeDtype{shape:(1, 5), dtype:int32}, ShapeDtype{shape:(1, 1), dtype:int32})

  File [...]/trax/layers/base.py, line 707, in __setattr__
    super().__setattr__(attr, value)

  File [...]/trax/layers/base.py, line 454, in weights
    f'Number of weight elements ({len(weights)}) does not equal the '

ValueError: Number of weight elements (1) does not equal the number of sublayers (9) in: Serial_in2_out2[
  Select[0,1,0,1]_in2_out4
  Parallel_in2_out2[
    Serial[
      Embedding_33300_1024
      LSTM_1024
      LSTM_1024
    ]
    Serial[
      Serial[
        ShiftRight(1)
      ]
      Embedding_33300_1024
      LSTM_1024
    ]
  ]
  PrepareAttentionInput_in3_out4
  Serial_in4_out2[
    Branch_in4_out3[
      None
      Serial_in4_out2[
        _in4_out4
        Serial_in4_out2[
          Parallel_in3_out3[
            Dense_1024
            Dense_1024
            Dense_1024
          ]
          PureAttention_in4_out2
          Dense_1024
        ]
        _in2_out2
      ]
    ]
    Add_in2
  ]
  Select[0,2]_in3_out2
  LSTM_1024
  LSTM_1024
  Dense_33300
  LogSoftmax
].

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

LayerError                                Traceback (most recent call last)
<ipython-input-106-2a8688a3fbbc> in <module>
      1 # Test the function above. Try varying the temperature setting with values from 0 to 1.
      2 # Run it several times with each setting and see how often the output changes.
----> 3 sampling_decode("I love languages.", NMTAttn=model, temperature=0.0, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)

<ipython-input-105-8eacaff94bcf> in sampling_decode(input_sentence, NMTAttn, temperature, vocab_file, vocab_dir, next_symbol, tokenize, detokenize)
     38 
     39         # update the current output token by getting the index of the next word (hint: use next_symbol)
---> 40         cur_output, log_prob = next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature=temperature)
     41 
     42         # append the current output token to the list of output tokens

<ipython-input-103-d95edcfc7ac9> in next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature)
     34 
     35     # get the model prediction
---> 36     output, _ = NMTAttn((input_tokens, padded_with_batch))
     37 
     38     # get log probabilities slice for the next token

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
    195       self.state = state  # Needed if the model wasn't fully initialized.
    196     state = self.state
--> 197     outputs, new_state = self.pure_fn(x, weights, state, rng)
    198     self.state = new_state
    199     return outputs

/opt/conda/lib/python3.7/site-packages/trax/layers/acceleration.py in pure_fn(self, x, weights, state, rng, use_cache)
     75       remainder = x.shape[0] % self._n_devices
     76     if remainder == 0:  # If yes, run the accelerated sublayer.pure_fn.
---> 77       return self._jit_pure_fn(x, weights, state, rng)
     78     # If not, pad first.
     79     def pad(z):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
    604       name, trace = self._name, _short_traceback(skip=3)
    605       raise LayerError(name, 'pure_fn',
--> 606                        self._caller, signature(x), trace) from None
    607 
    608   def output_signature(self, input_signature):

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/<ipython-input-39-86fefae4e242>, line 64
  layer input shapes: (ShapeDtype{shape:(1, 5), dtype:int32}, ShapeDtype{shape:(1, 1), dtype:int32})

  File [...]/trax/layers/base.py, line 707, in __setattr__
    super().__setattr__(attr, value)

  File [...]/trax/layers/base.py, line 454, in weights
    f'Number of weight elements ({len(weights)}) does not equal the '

ValueError: Number of weight elements (1) does not equal the number of sublayers (9) in: Serial_in2_out2[
  [... same model structure as in the unfiltered traceback above ...]
].

If I’m reading the Trax source code correctly, this error is thrown when a Serial layer finds a different number of weight elements than it has sublayers:

      # Set sublayer weights.
      n_layers = len(self.sublayers)
      if len(weights) != n_layers:
        raise ValueError(
            f'Number of weight elements ({len(weights)}) does not equal the '
            f'number of sublayers ({n_layers}) in: {str(self)}.')
      for sublayer, sublayer_weights in zip(self.sublayers, weights):
        sublayer.weights = sublayer_weights

First of all, I’m not clear on how this setter is being called. And why would the list of weights in any given Serial layer suddenly be a different size during a forward pass than it was during training, and during the unit tests for Exercises 5, 6, and 7?
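For reference, here’s a minimal sketch that reproduces the same check tripping on a toy model (hypothetical layers, assuming the notebook’s trax version; not the assignment code):

    import numpy as np
    from trax import layers as tl
    from trax import shapes

    # A toy Serial model with three sublayers.
    model = tl.Serial(tl.Dense(4), tl.Relu(), tl.Dense(2))
    model.init(shapes.ShapeDtype((1, 8), dtype=np.float32))

    w = model.weights     # a tuple with one element per sublayer: len(w) == 3
    model.weights = w     # fine: 3 weight elements for 3 sublayers
    model.weights = (w,)  # ValueError: Number of weight elements (1) does not
                          #             equal the number of sublayers (3)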

Hi @Karan_Abrol

If the unit tests fail for the previous exercises, you should definitely look at those first. The error for Exercise 7 could be caused by multiple reasons, though it does indicate that the model is not of the right dimensions.

Let me know if you’re able to solve the problem.

Hey @arvyzukai

Apologies for not being clear: 5 and 6 run clean and pass, while 7’s unit tests fail, but not with this error. The subsequent test cases, however, do fail with this error. The model itself (Ex 4) matches the sample model exactly, and its unit tests pass:

Serial_in2_out2[
  Select[0,1,0,1]_in2_out4
  Parallel_in2_out2[
    Serial[
      Embedding_33300_1024
      LSTM_1024
      LSTM_1024
    ]
    Serial[
      Serial[
        ShiftRight(1)
      ]
      Embedding_33300_1024
      LSTM_1024
    ]
  ]
  PrepareAttentionInput_in3_out4
  Serial_in4_out2[
    Branch_in4_out3[
      None
      Serial_in4_out2[
        _in4_out4
        Serial_in4_out2[
          Parallel_in3_out3[
            Dense_1024
            Dense_1024
            Dense_1024
          ]
          PureAttention_in4_out2
          Dense_1024
        ]
        _in2_out2
      ]
    ]
    Add_in2
  ]
  Select[0,2]_in3_out2
  LSTM_1024
  LSTM_1024
  Dense_33300
  LogSoftmax
]

I’m curious to know: why would a Serial model suddenly have only one weight element rather than one per sublayer? Could this happen as a result of a mistake in the training loop, or in the forward pass?
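To make the question concrete, here’s a hypothetical sketch (toy layers, not the assignment model) of one way the counts can get out of sync: the weights tuple mirrors the sublayer nesting, so weights built for a differently nested structure no longer line up:

    import numpy as np
    from trax import layers as tl
    from trax import shapes

    sig = shapes.ShapeDtype((1, 8), dtype=np.float32)

    flat = tl.Serial(tl.Dense(4), tl.Dense(2))               # 2 sublayers
    nested = tl.Serial(tl.Serial(tl.Dense(4), tl.Dense(2)))  # 1 sublayer
    flat.init(sig)
    nested.init(sig)

    print(len(flat.weights), len(nested.weights))  # 2 1
    flat.weights = nested.weights  # ValueError: 1 weight element, 2 sublayers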

Looks like the problem was, in fact, in Ex 4, but I don’t understand why it was a problem.

# Step 2: copy input tokens and target tokens as they will be needed later.

In the solution for this step, I passed the parameter n_in=2, hoping to signify that two inputs go in. Removing it fixed the problem. I’m not sure why it would be wrong to set n_in=2.

Hi @Karan_Abrol

Because technically the tuple is a single input (even though it has two entries). For example, a tuple with 6 elements is technically a single input, but sometimes the code is left open to different interpretations (for more general use cases): you could, for example, check whether the input is a tuple of 6 elements and, if so, have the downstream code interpret it as 6 different inputs; otherwise the code should treat the input as a single input…
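As a rough sketch of that point (toy shapes, not the assignment code): trax combinators count inputs as elements on the data stack, and Select infers n_in from the indices it is given, so the tuple alone already supplies the two stack elements:

    import numpy as np
    from trax import layers as tl

    sel = tl.Select([0, 1, 0, 1])  # n_in is inferred as max(index) + 1 == 2

    tokens = np.zeros((1, 5), dtype=np.int32)
    targets = np.zeros((1, 1), dtype=np.int32)

    out = sel((tokens, targets))   # the 2-tuple supplies two stack elements
    print(len(out))                # 4 -> (tokens, targets, tokens, targets)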