Hi all,
In Exercise 7 of this assignment, I’ve come across a really curious ValueError in the custom test case. This error doesn’t show up in the unit tests for Exercise 5 or 6, and the unit tests in the subsequent code block also don’t raise it (though they do fail). I think I’m struggling to understand a few key concepts here:
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-106-2a8688a3fbbc> in <module>
2 # Run it several times with each setting and see how often the output changes.
----> 3 sampling_decode("I love languages.", NMTAttn=model, temperature=0.0, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)
<ipython-input-105-8eacaff94bcf> in sampling_decode(input_sentence, NMTAttn, temperature, vocab_file, vocab_dir, next_symbol, tokenize, detokenize)
39 # update the current output token by getting the index of the next word (hint: use next_symbol)
---> 40 cur_output, log_prob = next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature=temperature)
41
<ipython-input-103-d95edcfc7ac9> in next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature)
35 # get the model prediction
---> 36 output, _ = NMTAttn((input_tokens, padded_with_batch))
37
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
196 state = self.state
--> 197 outputs, new_state = self.pure_fn(x, weights, state, rng)
198 self.state = new_state
/opt/conda/lib/python3.7/site-packages/trax/layers/acceleration.py in pure_fn(self, x, weights, state, rng, use_cache)
76 if remainder == 0: # If yes, run the accelerated sublayer.pure_fn.
---> 77 return self._jit_pure_fn(x, weights, state, rng)
78 # If not, pad first.
/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
418 device=device, backend=backend, name=flat_fun.__name__,
--> 419 donated_invars=donated_invars, inline=inline)
420 out_pytree_def = out_tree()
/opt/conda/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
1631 def bind(self, fun, *args, **params):
-> 1632 return call_bind(self, fun, *args, **params)
1633
/opt/conda/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1622 tracers = map(top_trace.full_raise, args)
-> 1623 outs = primitive.process(top_trace, fun, tracers, params)
1624 return map(full_lower, apply_todos(env_trace_todo(), outs))
/opt/conda/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
1634 def process(self, trace, fun, tracers, params):
-> 1635 return trace.process_call(self, fun, tracers, params)
1636
/opt/conda/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
626 def process_call(self, primitive, f, tracers, params):
--> 627 return primitive.impl(f, *tracers, **params)
628 process_map = process_call
/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(***failed resolving arguments***)
687 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 688 *unsafe_map(arg_spec, args))
689 try:
/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
262 else:
--> 263 ans = call(fun, *args)
264 cache[key] = (ans, fun.stores)
/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable_uncached(fun, device, backend, name, donated_invars, *arg_specs)
759 return lower_xla_callable(fun, device, backend, name, donated_invars,
--> 760 *arg_specs).compile().unsafe_call
761
/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in lower_xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
771 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
--> 772 fun, abstract_args, pe.debug_info_final(fun, "jit"))
773 if any(isinstance(c, core.Tracer) for c in consts):
/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, debug_info)
1541 with core.new_sublevel():
-> 1542 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
1543 del fun, main
/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1519 in_tracers = map(trace.new_arg, in_avals)
-> 1520 ans = fun.call_wrapped(*in_tracers)
1521 out_tracers = map(trace.full_raise, ans)
/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
605 raise LayerError(name, 'pure_fn',
--> 606 self._caller, signature(x), trace) from None
607
UnfilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/<ipython-input-39-86fefae4e242>, line 64
layer input shapes: (ShapeDtype{shape:(1, 5), dtype:int32}, ShapeDtype{shape:(1, 1), dtype:int32})
File [...]/trax/layers/base.py, line 707, in __setattr__
super().__setattr__(attr, value)
File [...]/trax/layers/base.py, line 454, in weights
f'Number of weight elements ({len(weights)}) does not equal the '
ValueError: Number of weight elements (1) does not equal the number of sublayers (9) in: Serial_in2_out2[
Select[0,1,0,1]_in2_out4
Parallel_in2_out2[
Serial[
Embedding_33300_1024
LSTM_1024
LSTM_1024
]
Serial[
Serial[
ShiftRight(1)
]
Embedding_33300_1024
LSTM_1024
]
]
PrepareAttentionInput_in3_out4
Serial_in4_out2[
Branch_in4_out3[
None
Serial_in4_out2[
_in4_out4
Serial_in4_out2[
Parallel_in3_out3[
Dense_1024
Dense_1024
Dense_1024
]
PureAttention_in4_out2
Dense_1024
]
_in2_out2
]
]
Add_in2
]
Select[0,2]_in3_out2
LSTM_1024
LSTM_1024
Dense_33300
LogSoftmax
].
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
LayerError Traceback (most recent call last)
<ipython-input-106-2a8688a3fbbc> in <module>
1 # Test the function above. Try varying the temperature setting with values from 0 to 1.
2 # Run it several times with each setting and see how often the output changes.
----> 3 sampling_decode("I love languages.", NMTAttn=model, temperature=0.0, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)
<ipython-input-105-8eacaff94bcf> in sampling_decode(input_sentence, NMTAttn, temperature, vocab_file, vocab_dir, next_symbol, tokenize, detokenize)
38
39 # update the current output token by getting the index of the next word (hint: use next_symbol)
---> 40 cur_output, log_prob = next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature=temperature)
41
42 # append the current output token to the list of output tokens
<ipython-input-103-d95edcfc7ac9> in next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature)
34
35 # get the model prediction
---> 36 output, _ = NMTAttn((input_tokens, padded_with_batch))
37
38 # get log probabilities slice for the next token
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
195 self.state = state # Needed if the model wasn't fully initialized.
196 state = self.state
--> 197 outputs, new_state = self.pure_fn(x, weights, state, rng)
198 self.state = new_state
199 return outputs
/opt/conda/lib/python3.7/site-packages/trax/layers/acceleration.py in pure_fn(self, x, weights, state, rng, use_cache)
75 remainder = x.shape[0] % self._n_devices
76 if remainder == 0: # If yes, run the accelerated sublayer.pure_fn.
---> 77 return self._jit_pure_fn(x, weights, state, rng)
78 # If not, pad first.
79 def pad(z):
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
604 name, trace = self._name, _short_traceback(skip=3)
605 raise LayerError(name, 'pure_fn',
--> 606 self._caller, signature(x), trace) from None
607
608 def output_signature(self, input_signature):
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/<ipython-input-39-86fefae4e242>, line 64
layer input shapes: (ShapeDtype{shape:(1, 5), dtype:int32}, ShapeDtype{shape:(1, 1), dtype:int32})
File [...]/trax/layers/base.py, line 707, in __setattr__
super().__setattr__(attr, value)
File [...]/trax/layers/base.py, line 454, in weights
f'Number of weight elements ({len(weights)}) does not equal the '
ValueError: Number of weight elements (1) does not equal the number of sublayers (9) in: Serial_in2_out2[
Select[0,1,0,1]_in2_out4
Parallel_in2_out2[
Serial[
Embedding_33300_1024
LSTM_1024
LSTM_1024
]
Serial[
Serial[
ShiftRight(1)
]
Embedding_33300_1024
LSTM_1024
]
]
PrepareAttentionInput_in3_out4
Serial_in4_out2[
Branch_in4_out3[
None
Serial_in4_out2[
_in4_out4
Serial_in4_out2[
Parallel_in3_out3[
Dense_1024
Dense_1024
Dense_1024
]
PureAttention_in4_out2
Dense_1024
]
_in2_out2
]
]
Add_in2
]
Select[0,2]_in3_out2
LSTM_1024
LSTM_1024
Dense_33300
LogSoftmax
].
According to the trax source code, this error is raised when a Serial layer receives a different number of weight elements than it has sublayers:
# Set sublayer weights.
n_layers = len(self.sublayers)
if len(weights) != n_layers:
raise ValueError(
f'Number of weight elements ({len(weights)}) does not equal the '
f'number of sublayers ({n_layers}) in: {str(self)}.')
for sublayer, sublayer_weights in zip(self.sublayers, weights):
sublayer.weights = sublayer_weights
First of all, I’m not clear on how this code path is being reached — and why would the list of weights for any given Serial layer suddenly be a different size during a forward pass than it was during training, and during the unit tests for Exercises 5, 6, and 7?