Getting the error below for # UNQ_C9: next_symbol(cur_output_tokens, model)

I am getting the error below for UNQ_C9, def next_symbol(cur_output_tokens, model), and I can't figure out what to do.

```
LayerError                                Traceback (most recent call last)
<ipython-input-46-9431ff725bfd> in <module>
      1 # Test it out!
      2 sentence_test_nxt_symbl = "I want to fly in the sky."
----> 3 detokenize([next_symbol(tokenize(sentence_test_nxt_symbl)+[0], model)])

<ipython-input-45-8a42ea458228> in next_symbol(cur_output_tokens, model)
     26 
     27     # model expects a tuple containing two padded tensors (with batch)
---> 28     output, _ = model((padded_with_batch, padded_with_batch))
     29     # HINT: output has shape (1, padded_length, vocab_size)
     30     # To get log_probs you need to index output wih 0 in the first dim

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
    195       self.state = state  # Needed if the model wasn't fully initialized.
    196     state = self.state
--> 197     outputs, new_state = self.pure_fn(x, weights, state, rng)
    198     self.state = new_state
    199     return outputs

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
    604       name, trace = self._name, _short_traceback(skip=3)
    605       raise LayerError(name, 'pure_fn',
--> 606                        self._caller, signature(x), trace) from None
    607 
    608   def output_signature(self, input_signature):

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/<ipython-input-37-f2875c2dc46e>, line 64
  layer input shapes: (ShapeDtype{shape:(1, 16), dtype:int64}, ShapeDtype{shape:(1, 16), dtype:int64})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Embedding_33300_512 (in pure_fn):
  layer created in file [...]/<ipython-input-37-f2875c2dc46e>, line 38
  layer input shapes: ShapeDtype{shape:(1, 16), dtype:int32}

  File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
    y = forward(self, x, *args, **kwargs)

  File [...]/trax/layers/core.py, line 182, in forward
    embedded = jnp.take(self.weights, x, axis=0, mode='clip')

  File [...]/_src/numpy/lax_numpy.py, line 4621, in take
    slice_sizes=tuple(slice_sizes))

  File [...]/_src/lax/lax.py, line 978, in gather
    slice_sizes=canonicalize_shape(slice_sizes))

  File [...]/site-packages/jax/core.py, line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/site-packages/jax/core.py, line 606, in process_primitive
    return primitive.impl(*tracers, **params)

  File [...]/jax/interpreters/xla.py, line 231, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)

  File [...]/jax/_src/util.py, line 186, in wrapper
    return cached(config._trace_context(), *args, **kwargs)

  File [...]/jax/_src/util.py, line 179, in cached
    return f(*args, **kwargs)

  File [...]/jax/interpreters/xla.py, line 256, in xla_primitive_callable
    aval_out = prim.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 2115, in standard_abstract_eval
    return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

  File [...]/_src/lax/lax.py, line 4455, in _gather_shape_rule
    raise TypeError(f"Slice size at index {i} in gather op is out of range, "

TypeError: Slice size at index 0 in gather op is out of range, must be within [0, 0 + 1), got 1.
```
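The last line of the traceback is the most useful clue. JAX's gather shape rule prints the allowed range as `[0, N + 1)`, where N is the size of the looked-up array along that axis, so N = 0 here: the `Embedding_33300_512` layer appears to be looking tokens up in an empty weight table rather than a 33300 × 512 one. A plausible cause (an assumption, not confirmed in this thread) is that `model` was re-created, for example by re-running the cell that builds it, without loading the pretrained weights into it again before calling `next_symbol`. The sketch below assumes only that `jax` is installed; `embed` and `lookup_fails` are hypothetical helpers, and the small tables are stand-ins. It repeats the lookup shown in the traceback (`jnp.take(self.weights, x, axis=0, mode='clip')`) once with a filled table and once with an empty one.

```python
import jax.numpy as jnp


def embed(weights, token_ids):
    # Same call as trax/layers/core.py in the traceback: take one row of the
    # (vocab_size, d_model) table for every token id.
    return jnp.take(weights, token_ids, axis=0, mode='clip')


def lookup_fails(weights, token_ids):
    # Hypothetical helper: returns True if the embedding lookup raises.
    try:
        embed(weights, token_ids)
    except (TypeError, IndexError) as err:
        # Older JAX (as in the traceback) raises TypeError from the gather
        # shape rule; newer JAX raises IndexError before reaching gather.
        print(f"{type(err).__name__}: {err}")
        return True
    return False


# A (batch=1, length=16) array of token ids, shaped like padded_with_batch.
token_ids = jnp.zeros((1, 16), dtype=jnp.int32)

# Stand-in for a loaded embedding table (the real one is 33300 x 512).
loaded = jnp.ones((100, 8))
print(embed(loaded, token_ids).shape)  # (1, 16, 8)

# Stand-in for a layer whose weights were never set: empty weights (assumed
# here to be an empty tuple) become an array of shape (0,), a table with 0 rows.
empty = jnp.asarray(())
print(lookup_fails(loaded, token_ids))  # False
print(lookup_fails(empty, token_ids))   # prints the error, then True
```

If the empty-table case matches what you see, a reasonable first step is to re-run the cell that loads the pretrained weights into `model` (and not re-create `model` afterwards), or restart the kernel and run the notebook from the top.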

Did you figure it out? I have a similar problem on another question.