Hi,
I can't find the bug in Exercise 7 (sampling_decode).
It seems to be a problem when running NMTAttn (full traceback below):
UnfilteredStackTrace Traceback (most recent call last)
in
2 # Run it several times with each setting and see how often the output changes.
----> 3 sampling_decode("I love languages.", NMTAttn=model, temperature=0.0, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)
in sampling_decode(input_sentence, NMTAttn, temperature, vocab_file, vocab_dir, next_symbol, tokenize, detokenize)
38 # update the current output token by getting the index of the next word (hint: use next_symbol)
—> 39 cur_output, log_prob = next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature)
40 # append the current output token to the list of output tokens
in next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature)
30 # get the model prediction
—> 31 output, _ = NMTAttn((input_tokens,padded_with_batch))
32 # get log probabilities from the last token output
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
196 state = self.state
→ 197 outputs, new_state = self.pure_fn(x, weights, state, rng)
198 self.state = new_state
/opt/conda/lib/python3.7/site-packages/trax/layers/acceleration.py in pure_fn(self, x, weights, state, rng, use_cache)
76 if remainder == 0: # If yes, run the accelerated sublayer.pure_fn.
—> 77 return self._jit_pure_fn(x, weights, state, rng)
78 # If not, pad first.
/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
→ 162 return fun(*args, **kwargs)
163 except Exception as e:
/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
418 device=device, backend=backend, name=flat_fun.name,
→ 419 donated_invars=donated_invars, inline=inline)
420 out_pytree_def = out_tree()
/opt/conda/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
1631 def bind(self, fun, *args, **params):
→ 1632 return call_bind(self, fun, *args, **params)
1633
/opt/conda/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1622 tracers = map(top_trace.full_raise, args)
→ 1623 outs = primitive.process(top_trace, fun, tracers, params)
1624 return map(full_lower, apply_todos(env_trace_todo(), outs))
/opt/conda/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
1634 def process(self, trace, fun, tracers, params):
→ 1635 return trace.process_call(self, fun, tracers, params)
1636
/opt/conda/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
626 def process_call(self, primitive, f, tracers, params):
→ 627 return primitive.impl(f, *tracers, **params)
628 process_map = process_call
/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(failed resolving arguments)
687 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
→ 688 *unsafe_map(arg_spec, args))
689 try:
/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
262 else:
→ 263 ans = call(fun, *args)
264 cache[key] = (ans, fun.stores)
/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable_uncached(fun, device, backend, name, donated_invars, *arg_specs)
759 return lower_xla_callable(fun, device, backend, name, donated_invars,
→ 760 *arg_specs).compile().unsafe_call
761
/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in lower_xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
771 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
→ 772 fun, abstract_args, pe.debug_info_final(fun, "jit"))
773 if any(isinstance(c, core.Tracer) for c in consts):
/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, debug_info)
1541 with core.new_sublevel():
→ 1542 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
1543 del fun, main
/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1519 in_tracers = map(trace.new_arg, in_avals)
→ 1520 ans = fun.call_wrapped(*in_tracers)
1521 out_tracers = map(trace.full_raise, ans)
/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
165 try:
→ 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
605 raise LayerError(name, 'pure_fn',
→ 606 self._caller, signature(x), trace) from None
607
UnfilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file […]/, line 65
layer input shapes: (ShapeDtype{shape:(1, 5), dtype:int32}, ShapeDtype{shape:(1, 1), dtype:int32})
File […]/trax/layers/base.py, line 707, in __setattr__
super().__setattr__(attr, value)
File […]/trax/layers/base.py, line 454, in weights
f'Number of weight elements ({len(weights)}) does not equal the '
ValueError: Number of weight elements (1) does not equal the number of sublayers (9) in: Serial_in2_out2[
Select[0,1,0,1]_in2_out4
Parallel_in2_out2[
Serial[
Embedding_33300_1024
LSTM_1024
LSTM_1024
]
Serial[
Serial[
ShiftRight(1)
]
Embedding_33300_1024
LSTM_1024
]
]
PrepareAttentionInput_in3_out4
Serial_in4_out2[
Branch_in4_out3[
None
Serial_in4_out2[
_in4_out4
Serial_in4_out2[
Parallel_in3_out3[
Dense_1024
Dense_1024
Dense_1024
]
PureAttention_in4_out2
Dense_1024
]
_in2_out2
]
]
Add_in2
]
Select[0,2]_in3_out2
LSTM_1024
LSTM_1024
Dense_33300
LogSoftmax
].
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
LayerError Traceback (most recent call last)
in
1 # Test the function above. Try varying the temperature setting with values from 0 to 1.
2 # Run it several times with each setting and see how often the output changes.
----> 3 sampling_decode("I love languages.", NMTAttn=model, temperature=0.0, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)
in sampling_decode(input_sentence, NMTAttn, temperature, vocab_file, vocab_dir, next_symbol, tokenize, detokenize)
37
38 # update the current output token by getting the index of the next word (hint: use next_symbol)
—> 39 cur_output, log_prob = next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature)
40 # append the current output token to the list of output tokens
41 cur_output_tokens.append(cur_output)
in next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature)
29 padded_with_batch = np.array(padded).reshape((1,int(padded_length)))
30 # get the model prediction
—> 31 output, _ = NMTAttn((input_tokens,padded_with_batch))
32 # get log probabilities from the last token output
33 log_probs = output[0,-1,:]
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
195 self.state = state # Needed if the model wasn’t fully initialized.
196 state = self.state
→ 197 outputs, new_state = self.pure_fn(x, weights, state, rng)
198 self.state = new_state
199 return outputs
/opt/conda/lib/python3.7/site-packages/trax/layers/acceleration.py in pure_fn(self, x, weights, state, rng, use_cache)
75 remainder = x.shape[0] % self._n_devices
76 if remainder == 0: # If yes, run the accelerated sublayer.pure_fn.
—> 77 return self._jit_pure_fn(x, weights, state, rng)
78 # If not, pad first.
79 def pad(z):
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
604 name, trace = self._name, _short_traceback(skip=3)
605 raise LayerError(name, 'pure_fn',
→ 606 self._caller, signature(x), trace) from None
607
608 def output_signature(self, input_signature):
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file […]/, line 65
layer input shapes: (ShapeDtype{shape:(1, 5), dtype:int32}, ShapeDtype{shape:(1, 1), dtype:int32})
File […]/trax/layers/base.py, line 707, in __setattr__
super().__setattr__(attr, value)
File […]/trax/layers/base.py, line 454, in weights
f'Number of weight elements ({len(weights)}) does not equal the '
ValueError: Number of weight elements (1) does not equal the number of sublayers (9) in: Serial_in2_out2[
Select[0,1,0,1]_in2_out4
Parallel_in2_out2[
Serial[
Embedding_33300_1024
LSTM_1024
LSTM_1024
]
Serial[
Serial[
ShiftRight(1)
]
Embedding_33300_1024
LSTM_1024
]
]
PrepareAttentionInput_in3_out4
Serial_in4_out2[
Branch_in4_out3[
None
Serial_in4_out2[
_in4_out4
Serial_in4_out2[
Parallel_in3_out3[
Dense_1024
Dense_1024
Dense_1024
]
PureAttention_in4_out2
Dense_1024
]
_in2_out2
]
]
Add_in2
]
Select[0,2]_in3_out2
LSTM_1024
LSTM_1024
Dense_33300
LogSoftmax
].