CBOW using trax

I am trying to create a CBOW model using trax, but I am getting an error.

My model:

from trax import layers as tl

def cbow_model(vocab_size, embedding_size, mode="train"):
    return tl.Serial(
        tl.Dense(vocab_size),
        tl.Dense(embedding_size),
        tl.Relu(),
        tl.Dense(vocab_size),
        tl.Softmax(axis=-1),
    )
The input and output are the same as the batch input in the assignment.

Error:
LayerError: Exception passing through layer CrossEntropyLoss (in init):
layer created in file […]/, line 6
layer input shapes: (ShapeDtype{shape:(128, 5775), dtype:float32}, ShapeDtype{shape:(128, 5775), dtype:float32})

File […]/trax/layers/combinators.py, line 108, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer _CrossEntropy (in _forward_abstract):
layer created in file […]/, line 6
layer input shapes: (ShapeDtype{shape:(128, 5775), dtype:float32}, ShapeDtype{shape:(128, 5775), dtype:float32})

File […]/jax/interpreters/partial_eval.py, line 662, in abstract_eval_fun
lu.wrap_init(fun, params), avals, debug_info)

File […]/jax/_src/profiler.py, line 313, in wrapper
return func(*args, **kwargs)

File […]/jax/interpreters/partial_eval.py, line 1985, in trace_to_jaxpr_dynamic
fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)

File […]/jax/interpreters/partial_eval.py, line 2001, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)

File […]/dist-packages/jax/linear_util.py, line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

File […]/dist-packages/jax/linear_util.py, line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer _CrossEntropy (in pure_fn):
layer created in file […]/, line 6
layer input shapes: (ShapeDtype{shape:(128, 5775), dtype:float32}, ShapeDtype{shape:(128, 5775), dtype:float32})

File […]/trax/layers/base.py, line 743, in forward
raw_output = self._forward_fn(inputs)

File […]/trax/layers/base.py, line 784, in _forward
return f(*xs)

File […]/trax/layers/metrics.py, line 582, in f
return -1.0 * jnp.sum(model_output * target_distribution, axis=-1)

File […]/dist-packages/jax/core.py, line 591, in mul
def mul(self, other): return self.aval._mul(self, other)

File […]/_src/numpy/lax_numpy.py, line 4710, in deferring_binary_op
return binary_op(*args)

File […]/_src/numpy/ufuncs.py, line 83, in fn
x1, x2 = _promote_args(numpy_fn.name, x1, x2)

File […]/_src/numpy/util.py, line 356, in _promote_args
return _promote_shapes(fun_name, *_promote_dtypes(*args))

File […]/_src/numpy/util.py, line 249, in _promote_shapes
result_rank = len(lax.broadcast_shapes(*shapes))

ValueError: Incompatible shapes for broadcasting: shapes=[(128, 5775), (128, 5775, 5775)]
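For context on the final ValueError: NumPy/JAX broadcasting aligns shapes from the trailing dimension, so (128, 5775) and (128, 5775, 5775) cannot be multiplied because after the last two dimensions line up, 128 would have to match 5775. A minimal sketch with plain NumPy standing in for jnp (the small shapes here are placeholders for the real ones):

```python
import numpy as np

# Stand-ins for (128, 5775) and (128, 5775, 5775) from the traceback.
a = np.zeros((4, 7))
b = np.zeros((4, 7, 7))

# Trailing dims align: (4, 7) vs (7, 7) is fine for the last axis,
# but then 4 must match 7 on the next axis, so the product fails.
try:
    _ = a * b
except ValueError as e:
    print("broadcast error:", e)
```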

Hi @Abhishek_Kapoor1!
This error is raised because the shapes of the inputs are incompatible with the shapes the loss layer expects. How do you prepare and feed data to your model? Could you please send me the complete code for preparing the input data and for calling this function with the input batch?
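One detail worth checking: trax's CrossEntropyLoss expects integer class indices as targets and one-hot-encodes them internally, so feeding it targets that are already one-hot vectors of shape (128, 5775) would produce a (128, 5775, 5775) distribution that cannot broadcast against the (128, 5775) model output, matching the traceback above. A minimal NumPy sketch of that computation (hypothetical small shapes, not the course's actual pipeline):

```python
import numpy as np

def cross_entropy(log_probs, targets):
    # log_probs: (batch, vocab) log-probabilities from the model
    # targets:   (batch,) integer word indices
    one_hot = np.eye(log_probs.shape[-1])[targets]   # (batch, vocab)
    return -np.sum(log_probs * one_hot, axis=-1)     # (batch,)

batch, vocab = 2, 5
logits = np.random.randn(batch, vocab)
log_probs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))

targets = np.array([1, 3])                       # correct: class indices
print(cross_entropy(log_probs, targets).shape)   # (2,)

# If targets are already one-hot (batch, vocab), encoding them again
# yields (batch, vocab, vocab) and the product cannot broadcast:
one_hot_targets = np.eye(vocab)[targets].astype(int)  # (2, 5)
try:
    cross_entropy(log_probs, one_hot_targets)
except ValueError as e:
    print("error:", e)
```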