I’m having issues with exercise 6 (train_model) in the assignment.
I’m passing all tests until exercise 6, but I’m getting the error below regarding data shape.
This is after implementing tl.Mean(axis=1). Has anyone figured this out?
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file […]/trax/supervised/training.py, line 1033
layer input shapes: (ShapeDtype{shape:(8, 13), dtype:int32}, ShapeDtype{shape:(40,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})
File […]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer CrossEntropyLoss (in pure_fn):
layer created in file […]/, line 12
layer input shapes: (ShapeDtype{shape:(8, 2), dtype:float32}, ShapeDtype{shape:(40,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})
File […]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer _CrossEntropy (in pure_fn):
layer created in file […]/, line 12
layer input shapes: (ShapeDtype{shape:(8, 2), dtype:float32}, ShapeDtype{shape:(40,), dtype:int32})
File […]/trax/layers/base.py, line 743, in forward
raw_output = self._forward_fn(inputs)
File […]/trax/layers/base.py, line 784, in _forward
return f(*xs)
File […]/trax/layers/metrics.py, line 581, in f
return -1.0 * jnp.sum(model_output * target_distribution, axis=-1)
File […]/site-packages/jax/core.py, line 506, in mul
def mul(self, other): return self.aval._mul(self, other)
File […]/_src/numpy/lax_numpy.py, line 5819, in deferring_binary_op
return binary_op(self, other)
File […]/src/numpy/lax_numpy.py, line 431, in fn
return lax_fn(x1, x2) if x1.dtype != bool else bool_lax_fn(x1, x2)
File […]/_src/lax/lax.py, line 348, in mul
return mul_p.bind(x, y)
File […]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File […]/jax/interpreters/ad.py, line 274, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
File […]/jax/interpreters/ad.py, line 449, in standard_jvp
val_out = primitive.bind(*primals, **params)
File […]/site-packages/jax/core.py, line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File […]/jax/interpreters/partial_eval.py, line 1059, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File […]/_src/lax/lax.py, line 2125, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
File […]/_src/lax/lax.py, line 2221, in _broadcasting_shape_rule
raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
TypeError: mul got incompatible shapes for broadcasting: (8, 2), (40, 2).
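
To narrow this down, the final error can be reproduced outside Trax with a rough NumPy sketch. It reuses the shapes reported in the traceback and assumes, as the (40, 2) operand suggests, that the loss one-hot encodes the targets before multiplying. The helper names `check_batch_shapes` and `try_multiply` are hypothetical and not part of Trax or the assignment.

```python
import numpy as np

def check_batch_shapes(inputs, targets, weights):
    """Return True when all three arrays share the same leading (batch) dimension."""
    return inputs.shape[0] == targets.shape[0] == weights.shape[0]

def try_multiply(a, b):
    """Return the broadcast result shape of a * b, or None if the shapes are incompatible."""
    try:
        return (a * b).shape
    except ValueError:
        return None

# Dummy arrays with the shapes from the traceback (the values are placeholders).
inputs = np.zeros((8, 13), dtype=np.int32)
targets = np.zeros((40,), dtype=np.int32)
weights = np.zeros((16,), dtype=np.int32)

# The model maps (8, 13) inputs to (8, 2) outputs, as reported for CrossEntropyLoss.
model_output = np.zeros((8, 2), dtype=np.float32)

# One-hot encoding 40 targets over 2 classes gives a (40, 2) array.
target_distribution = np.eye(2, dtype=np.float32)[targets]

print(check_batch_shapes(inputs, targets, weights))
print(target_distribution.shape)
print(try_multiply(model_output, target_distribution))
```

In this sketch the model output already has the expected (batch, classes) shape, while the inputs, targets and weights have 8, 40 and 16 rows respectively. That would point at the batches fed into training rather than at tl.Mean(axis=1), though that is only a reading of the shapes and not a confirmed diagnosis.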