I am getting the error below for UNQ_C9, `def next_symbol(cur_output_tokens, model)`, and I can't figure out how to fix it. What should I do?
LayerError Traceback (most recent call last)
<ipython-input-46-9431ff725bfd> in <module>
1 # Test it out!
2 sentence_test_nxt_symbl = "I want to fly in the sky."
----> 3 detokenize([next_symbol(tokenize(sentence_test_nxt_symbl)+[0], model)])
<ipython-input-45-8a42ea458228> in next_symbol(cur_output_tokens, model)
26
27 # model expects a tuple containing two padded tensors (with batch)
---> 28 output, _ = model((padded_with_batch, padded_with_batch))
29 # HINT: output has shape (1, padded_length, vocab_size)
30 # To get log_probs you need to index output wih 0 in the first dim
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
195 self.state = state # Needed if the model wasn't fully initialized.
196 state = self.state
--> 197 outputs, new_state = self.pure_fn(x, weights, state, rng)
198 self.state = new_state
199 return outputs
/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
604 name, trace = self._name, _short_traceback(skip=3)
605 raise LayerError(name, 'pure_fn',
--> 606 self._caller, signature(x), trace) from None
607
608 def output_signature(self, input_signature):
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/<ipython-input-37-f2875c2dc46e>, line 64
layer input shapes: (ShapeDtype{shape:(1, 16), dtype:int64}, ShapeDtype{shape:(1, 16), dtype:int64})
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Embedding_33300_512 (in pure_fn):
layer created in file [...]/<ipython-input-37-f2875c2dc46e>, line 38
layer input shapes: ShapeDtype{shape:(1, 16), dtype:int32}
File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
y = forward(self, x, *args, **kwargs)
File [...]/trax/layers/core.py, line 182, in forward
embedded = jnp.take(self.weights, x, axis=0, mode='clip')
File [...]/_src/numpy/lax_numpy.py, line 4621, in take
slice_sizes=tuple(slice_sizes))
File [...]/_src/lax/lax.py, line 978, in gather
slice_sizes=canonicalize_shape(slice_sizes))
File [...]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/site-packages/jax/core.py, line 606, in process_primitive
return primitive.impl(*tracers, **params)
File [...]/jax/interpreters/xla.py, line 231, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
File [...]/jax/_src/util.py, line 186, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File [...]/jax/_src/util.py, line 179, in cached
return f(*args, **kwargs)
File [...]/jax/interpreters/xla.py, line 256, in xla_primitive_callable
aval_out = prim.abstract_eval(*avals, **params)
File [...]/_src/lax/lax.py, line 2115, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
File [...]/_src/lax/lax.py, line 4455, in _gather_shape_rule
raise TypeError(f"Slice size at index {i} in gather op is out of range, "
TypeError: Slice size at index 0 in gather op is out of range, must be within [0, 0 + 1), got 1.