C4W2 UNQ_C8 training_loop error

All the previous tests passed, but I am not sure why I cannot reshape the data into (n_batch, seqlen, n_heads, d_head). I tried refreshing the workspace, thinking it was a problem with the dataset, but to no avail. The error traceback is as follows:

---------------------------------------------------------------------------
LayerError                                Traceback (most recent call last)
<ipython-input-45-d1e9ffd2f1b1> in <module>
      1 # UNIT TEST
      2 # test training_loop
----> 3 w2_tests.test_training_loop(training_loop, TransformerLM)

~/work/w2_tests.py in test_training_loop(target, TransformerLM)
    813         os.remove("~/model/model.pkl.gz")
    814 
--> 815     output_loop = target(TransformerLM, my_gen(), my_gen())
    816 
    817     try:

<ipython-input-44-9cb08261c0b0> in training_loop(TransformerLM, train_gen, eval_gen, output_dir)
     40                          train_task,
     41                          eval_tasks=[eval_task],
---> 42                          output_dir=output_dir)
     43 
     44     return loop

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    246     if not use_memory_efficient_trainer:
    247       if _is_uninitialized(self._model):
--> 248         self._model.init(self._batch_signature)
    249       self._eval_model.rng = self.new_rng()
    250       if _is_uninitialized(self._eval_model):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
    309       name, trace = self._name, _short_traceback(skip=3)
    310       raise LayerError(name, 'init', self._caller,
--> 311                        input_signature, trace) from None
    312 
    313   def init_from_file(self, file_name, weights_only=False, input_signature=None):

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/<ipython-input-40-3c2b89f9c7cb>, line 63
  layer input shapes: (ShapeDtype{shape:(1, 1250), dtype:int64}, ShapeDtype{shape:(1, 1250), dtype:int64}, ShapeDtype{shape:(1, 1250), dtype:int64})

  File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
    sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/<ipython-input-37-5b7ff822d20f>, line 54
  layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

  File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
    sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Branch (in init):
  layer created in file [...]/<ipython-input-37-5b7ff822d20f>, line 54
  layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

  File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
    sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Parallel (in init):
  layer created in file [...]/<ipython-input-37-5b7ff822d20f>, line 54
  layer input shapes: (ShapeDtype{shape:(1, 1250, 4), dtype:float32}, ShapeDtype{shape:(1, 1250, 4), dtype:float32})

  File [...]/trax/layers/combinators.py, line 226, in init_weights_and_state
    in zip(self.sublayers, sublayer_signatures)]

  File [...]/trax/layers/combinators.py, line 225, in <listcomp>
    for layer, signature

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/<ipython-input-37-5b7ff822d20f>, line 54
  layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

  File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
    sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/<ipython-input-34-c924fc487e9d>, line 47
  layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

  File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
    sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Branch (in init):
  layer created in file [...]/<ipython-input-34-c924fc487e9d>, line 39
  layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

  File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
    sublayer.init(inputs, use_cache=True))

LayerError: Exception passing through layer Parallel (in init):
  layer created in file [...]/<ipython-input-34-c924fc487e9d>, line 39
  layer input shapes: (ShapeDtype{shape:(1, 1250, 4), dtype:float32}, ShapeDtype{shape:(1, 1250, 4), dtype:float32}, ShapeDtype{shape:(1, 1250, 4), dtype:float32})

  File [...]/trax/layers/combinators.py, line 226, in init_weights_and_state
    in zip(self.sublayers, sublayer_signatures)]

  File [...]/trax/layers/combinators.py, line 225, in <listcomp>
    for layer, signature

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/<ipython-input-34-c924fc487e9d>, line 39
  layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

  File [...]/trax/layers/combinators.py, line 106, in init_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer AttnHeads (in _forward_abstract):
  layer created in file [...]/<ipython-input-34-c924fc487e9d>, line 32
  layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

  File [...]/jax/interpreters/partial_eval.py, line 419, in abstract_eval_fun
    lu.wrap_init(fun, params), avals, debug_info)

  File [...]/jax/interpreters/partial_eval.py, line 1510, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

  File [...]/jax/interpreters/partial_eval.py, line 1520, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)

  File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer AttnHeads (in pure_fn):
  layer created in file [...]/<ipython-input-34-c924fc487e9d>, line 32
  layer input shapes: ShapeDtype{shape:(1, 1250, 4), dtype:float32}

  File [...]/trax/layers/base.py, line 743, in forward
    raw_output = self._forward_fn(inputs)

  File [...]/trax/layers/base.py, line 784, in _forward
    return f(*xs)

  File [...]/<ipython-input-25-84a31e493485>, line 29, in compute_attention_heads
    x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))

  File [...]/_src/numpy/lax_numpy.py, line 1711, in reshape
    return a.reshape(newshape, order=order)  # forward to method for ndarrays

  File [...]/_src/numpy/lax_numpy.py, line 1731, in _reshape
    return lax.reshape(a, newshape, None)

  File [...]/_src/lax/lax.py, line 832, in reshape
    dimensions=None if dimensions is None or same_dims else tuple(dimensions))

  File [...]/site-packages/jax/core.py, line 272, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 1317, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 2274, in standard_abstract_eval
    return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

  File [...]/_src/lax/lax.py, line 4096, in _reshape_shape_rule
    if not core.same_shape_sizes(np.shape(operand), new_sizes):

  File [...]/site-packages/jax/core.py, line 1412, in same_shape_sizes
    return 1 == divide_shape_sizes(s1, s2)

  File [...]/site-packages/jax/core.py, line 1409, in divide_shape_sizes
    return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):])

  File [...]/site-packages/jax/core.py, line 1324, in divide_shape_sizes
    raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")

jax.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (1, 1250, 4) and (1, 1250, 2, 4)

Please help!


Never mind, I have solved this problem. It originated in UNQ_C5, where I had mistakenly set d_head to d_feature in both the compute_attention_heads_closure and compute_attention_output_closure functions.
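
For anyone who hits the same error: in standard multi-head attention, d_head is d_feature // n_heads, so the feature axis can be split evenly across the heads. Below is a minimal sketch of that head-splitting reshape (an illustration of the idea under that assumption, not the graded solution). With the shapes from the traceback (d_feature = 4, n_heads = 2), the reshape only works when d_head = 2; passing d_head = d_feature = 4 asks for 1 * 1250 * 2 * 4 = 10000 elements from a (1, 1250, 4) array with only 5000.

```python
import jax.numpy as jnp

def compute_attention_heads_closure(n_heads, d_head):
    """Returns a function that splits the feature axis into attention heads."""
    def compute_attention_heads(x):
        # x has shape (n_batch, seqlen, n_heads * d_head)
        batch_size, seqlen = x.shape[0], x.shape[1]
        # This reshape only succeeds if the last axis equals n_heads * d_head,
        # i.e. d_head = d_feature // n_heads (not d_feature itself).
        x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))
        x = jnp.transpose(x, (0, 2, 1, 3))           # (n_batch, n_heads, seqlen, d_head)
        return jnp.reshape(x, (-1, seqlen, d_head))  # (n_batch * n_heads, seqlen, d_head)
    return compute_attention_heads
```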


Thanks! I had the same problem.