Assignment - UNQ_C4 - JAX Error

When I execute the cell below UNQ_C4, I get the following error:

```
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
 in
      2 training_loop = train_model(Siamese, TripletLoss, train_generator, val_generator)
----> 3

/opt/conda/lib/python3.7/site-packages/trax/supervised/ in run(self, n_steps)
    434
--> 435 loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
    436

/opt/conda/lib/python3.7/site-packages/trax/supervised/ in _run_one_step(self, task_index, task_changed)
    632 (loss, stats) = trainer.one_step(
--> 633 batch, rng, step=step, learning_rate=learning_rate
    634 )

/opt/conda/lib/python3.7/site-packages/trax/optimizers/ in one_step(self, batch, rng, step, learning_rate)
    147 (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 148 (weights, self._slots), step, self._opt_params, batch, state, rng)
    149

/opt/conda/lib/python3.7/site-packages/jax/_src/ in reraise_with_filtered_traceback(*args, **kwargs)
    182 try:
--> 183 return fun(*args, **kwargs)
    184 except Exception as e:

/opt/conda/lib/python3.7/site-packages/jax/_src/ in cache_miss(*args, **kwargs)
    426 device=device, backend=backend,,
--> 427 donated_invars=donated_invars, inline=inline)
    428 out_pytree_def = out_tree()

/opt/conda/lib/python3.7/site-packages/jax/ in bind(self, fun, *args, **params)
   1559 def bind(self, fun, *args, **params):
-> 1560 return call_bind(self, fun, *args, **params)
   1561

/opt/conda/lib/python3.7/site-packages/jax/ in call_bind(primitive, fun, *args, **params)
   1550 tracers = map(top_trace.full_raise, args)
-> 1551 outs = primitive.process(top_trace, fun, tracers, params)
   1552 return map(full_lower, apply_todos(env_trace_todo(), outs))

/opt/conda/lib/python3.7/site-packages/jax/ in process(self, trace, fun, tracers, params)
   1562 def process(self, trace, fun, tracers, params):
-> 1563 return trace.process_call(self, fun, tracers, params)
   1564

/opt/conda/lib/python3.7/site-packages/jax/ in process_call(self, primitive, f, tracers, params)
    605 def process_call(self, primitive, f, tracers, params):
--> 606 return primitive.impl(f, *tracers, **params)
    607 process_map = process_call

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ in _xla_call_impl(failed resolving arguments)
    592 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 593 *unsafe_map(arg_spec, args))
    594 try:

/opt/conda/lib/python3.7/site-packages/jax/ in memoized_fun(fun, *args)
    261 else:
--> 262 ans = call(fun, *args)
    263 cache[key] = (ans, fun.stores)

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    667 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
--> 668 fun, abstract_args, pe.debug_info_final(fun, "jit"))
    669 if any(isinstance(c, core.Tracer) for c in consts):

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ in trace_to_jaxpr_final(fun, in_avals, debug_info)
   1283 with core.new_sublevel():
-> 1284 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1285 del fun, main

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1261 in_tracers = map(trace.new_arg, in_avals)
-> 1262 ans = fun.call_wrapped(*in_tracers)
   1263 out_tracers = map(trace.full_raise, ans)

/opt/conda/lib/python3.7/site-packages/jax/ in call_wrapped(self, *args, **kwargs)
    165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
    167 except:

/opt/conda/lib/python3.7/site-packages/trax/optimizers/ in single_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng)
    217 (loss, state), gradients = forward_and_backward_fn(
--> 218 batch, weights, state, rng)
    219 weights, slots, stats = optimizer.tree_update(

/opt/conda/lib/python3.7/site-packages/jax/_src/ in reraise_with_filtered_traceback(*args, **kwargs)
    182 try:
--> 183 return fun(*args, **kwargs)
    184 except Exception as e:

/opt/conda/lib/python3.7/site-packages/jax/_src/ in value_and_grad_f(*args, **kwargs)
    883 else:
--> 884 ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
    885 _check_scalar(ans)

/opt/conda/lib/python3.7/site-packages/jax/_src/ in _vjp(fun, has_aux, *primals)
   1964 flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree)
-> 1965 out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
   1966 out_tree, aux_tree = out_aux_trees()

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ in vjp(traceable, primals, has_aux)
    115 else:
--> 116 out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
    117

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ in linearize(traceable, *primals, **kwargs)
    100 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 101 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
    102 out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)

/opt/conda/lib/python3.7/site-packages/jax/interpreters/ in trace_to_jaxpr(fun, pvals, instantiate)
    504 fun = trace_to_subjaxpr(fun, main, instantiate)
--> 505 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    506 assert not env

/opt/conda/lib/python3.7/site-packages/jax/ in call_wrapped(self, *args, **kwargs)
    165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
    167 except:

/opt/conda/lib/python3.7/site-packages/trax/layers/ in pure_fn(self, x, weights, state, rng, use_cache)
    605 raise LayerError(name, 'pure_fn',
--> 606 self._caller, signature(x), trace) from None
    607

UnfilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file […]/trax/supervised/, line 1033
  layer input shapes: (ShapeDtype{shape:(256, 64), dtype:int32}, ShapeDtype{shape:(256, 64), dtype:int32})

  File […]/trax/layers/, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer TripletLoss (in pure_fn):
  layer created in file […]/, line 4
  layer input shapes: (ShapeDtype{shape:(256, 128), dtype:float32}, ShapeDtype{shape:(256, 128), dtype:float32})

  File […]/trax/layers/, line 743, in forward
    raw_output = self._forward_fn(inputs)

  File […]/trax/layers/, line 784, in _forward
    return f(*xs)

  File […]/, line 32, in TripletLossFn
    if element > positive[iRow]:

  File […]/site-packages/jax/, line 535, in __bool__
    def __bool__(self): return self.aval._bool(self)

  File […]/site-packages/jax/, line 954, in error
    raise ConcretizationTypeError(arg, fname_context)

jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the bool function.
While tracing the function single_device_update_fn at /opt/conda/lib/python3.7/site-packages/trax/optimizers/ for jit, this concrete value was not available in Python because it depends on the values of the arguments 'weights_and_slots' and 'batch'.

See JAX Errors - JAX documentation

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:
```

I excluded the second exception to not exceed the character limit.

Does anyone have an idea about how to fix this error?


I solved the problem.

I found that the error was caused by how I implemented the second mask in UNQ_C4. In my previous implementation, I looped over each element of the negative_zero_on_duplicate array and compared it with the corresponding positive value inside an if statement. I changed the implementation so that I subtract each positive value from the corresponding row of negative_zero_on_duplicate and then compare the resulting array with 0. With the new implementation, the error disappeared.
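For anyone hitting the same error, here is a minimal sketch of the difference. It assumes `positive` holds each row's positive score with shape (batch,) and `negative_zero_on_duplicate` has shape (batch, batch); the function names `mask_with_python_loop` and `mask_vectorized`, the toy 3x3 scores, and the setup around them are hypothetical and are not the assignment's graded code. The loop version mirrors the `if element > positive[iRow]:` line from the traceback: Trax runs the loss under `jax.jit`, where the arrays are abstract tracers, so Python's `if` cannot obtain a concrete boolean and JAX raises `ConcretizationTypeError`. The vectorized version keeps the comparison as an array operation, which JAX can trace.

```python
import jax
import jax.numpy as jnp


def mask_with_python_loop(negative_zero_on_duplicate, positive):
    """Hypothetical loop version: works eagerly, fails under jax.jit.

    `if element > positive[i_row]` needs a concrete Python bool. Inside jit
    the arrays are abstract tracers, so JAX raises ConcretizationTypeError.
    """
    rows = []
    for i_row, row in enumerate(negative_zero_on_duplicate):
        flags = []
        for element in row:
            # Python-level branch on a (possibly traced) JAX value
            if element > positive[i_row]:
                flags.append(True)
            else:
                flags.append(False)
        rows.append(flags)
    return jnp.array(rows)


def mask_vectorized(negative_zero_on_duplicate, positive):
    """Hypothetical vectorized version: pure array ops, so it traces under jit.

    positive has shape (batch,); reshaping it to (batch, 1) broadcasts each
    row's positive score across that row before the comparison with 0.
    """
    return (negative_zero_on_duplicate - positive.reshape(-1, 1)) > 0


# Toy 3x3 similarity scores (made-up numbers, for illustration only)
scores = jnp.array([[0.9, 0.95, 0.1],
                    [0.2, 0.8, 0.85],
                    [0.3, 0.1, 0.7]])
positive = jnp.diag(scores)                                # each row's positive score
negative_zero_on_duplicate = scores * (1.0 - jnp.eye(3))   # zero out the diagonal

print(jax.jit(mask_vectorized)(negative_zero_on_duplicate, positive))

try:
    jax.jit(mask_with_python_loop)(negative_zero_on_duplicate, positive)
except jax.errors.ConcretizationTypeError as err:
    print("loop version fails under jit:", type(err).__name__)
```

Subtracting and comparing with 0 gives the same mask as comparing directly, `negative_zero_on_duplicate > positive.reshape(-1, 1)`; either form avoids the Python-level branch, which is what matters under `jit`.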

Hi @ceyhunemreozturk,

Glad you were able to figure it out on your own.

Happy learning,
