Help with ex7 (assignment 1)

Hi,
I can’t find the bug in my ex7 (sampling_decode) code.
The problem seems to occur when running NMTAttn:


UnfilteredStackTrace Traceback (most recent call last)
in
2 # Run it several times with each setting and see how often the output changes.
----> 3 sampling_decode(“I love languages.”, NMTAttn=model, temperature=0.0, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)

in sampling_decode(input_sentence, NMTAttn, temperature, vocab_file, vocab_dir, next_symbol, tokenize, detokenize)
38 # update the current output token by getting the index of the next word (hint: use next_symbol)
—> 39 cur_output, log_prob = next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature)
40 # append the current output token to the list of output tokens

in next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature)
30 # get the model prediction
—> 31 output, _ = NMTAttn((input_tokens,padded_with_batch))
32 # get log probabilities from the last token output

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
196 state = self.state
→ 197 outputs, new_state = self.pure_fn(x, weights, state, rng)
198 self.state = new_state

/opt/conda/lib/python3.7/site-packages/trax/layers/acceleration.py in pure_fn(self, x, weights, state, rng, use_cache)
76 if remainder == 0: # If yes, run the accelerated sublayer.pure_fn.
—> 77 return self._jit_pure_fn(x, weights, state, rng)
78 # If not, pad first.

/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
→ 162 return fun(*args, **kwargs)
163 except Exception as e:

/opt/conda/lib/python3.7/site-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
418 device=device, backend=backend, name=flat_fun.name,
→ 419 donated_invars=donated_invars, inline=inline)
420 out_pytree_def = out_tree()

/opt/conda/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
1631 def bind(self, fun, *args, **params):
→ 1632 return call_bind(self, fun, *args, **params)
1633

/opt/conda/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1622 tracers = map(top_trace.full_raise, args)
→ 1623 outs = primitive.process(top_trace, fun, tracers, params)
1624 return map(full_lower, apply_todos(env_trace_todo(), outs))

/opt/conda/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
1634 def process(self, trace, fun, tracers, params):
→ 1635 return trace.process_call(self, fun, tracers, params)
1636

/opt/conda/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
626 def process_call(self, primitive, f, tracers, params):
→ 627 return primitive.impl(f, *tracers, **params)
628 process_map = process_call

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(failed resolving arguments)
687 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
→ 688 *unsafe_map(arg_spec, args))
689 try:

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
262 else:
→ 263 ans = call(fun, *args)
264 cache[key] = (ans, fun.stores)

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable_uncached(fun, device, backend, name, donated_invars, *arg_specs)
759 return lower_xla_callable(fun, device, backend, name, donated_invars,
→ 760 *arg_specs).compile().unsafe_call
761

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in lower_xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
771 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
→ 772 fun, abstract_args, pe.debug_info_final(fun, “jit”))
773 if any(isinstance(c, core.Tracer) for c in consts):

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, debug_info)
1541 with core.new_sublevel():
→ 1542 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
1543 del fun, main

/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1519 in_tracers = map(trace.new_arg, in_avals)
→ 1520 ans = fun.call_wrapped(*in_tracers)
1521 out_tracers = map(trace.full_raise, ans)

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
165 try:
→ 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
605 raise LayerError(name, ‘pure_fn’,
→ 606 self._caller, signature(x), trace) from None
607

UnfilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file […]/, line 65
layer input shapes: (ShapeDtype{shape:(1, 5), dtype:int32}, ShapeDtype{shape:(1, 1), dtype:int32})

File […]/trax/layers/base.py, line 707, in __setattr__
super().__setattr__(attr, value)

File […]/trax/layers/base.py, line 454, in weights
f’Number of weight elements ({len(weights)}) does not equal the ’

ValueError: Number of weight elements (1) does not equal the number of sublayers (9) in: Serial_in2_out2[
Select[0,1,0,1]_in2_out4

Parallel_in2_out2[
Serial[

  Embedding_33300_1024
  LSTM_1024

  LSTM_1024
]

Serial[
  Serial[

    ShiftRight(1)
  ]

  Embedding_33300_1024
  LSTM_1024

]

]

PrepareAttentionInput_in3_out4
Serial_in4_out2[

Branch_in4_out3[
  None

  Serial_in4_out2[
    _in4_out4

    Serial_in4_out2[
      Parallel_in3_out3[

        Dense_1024
        Dense_1024

        Dense_1024
      ]

      PureAttention_in4_out2
      Dense_1024

    ]
    _in2_out2

  ]
]

Add_in2

]

Select[0,2]_in3_out2
LSTM_1024

LSTM_1024
Dense_33300

LogSoftmax
].

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

LayerError Traceback (most recent call last)
in
1 # Test the function above. Try varying the temperature setting with values from 0 to 1.
2 # Run it several times with each setting and see how often the output changes.
----> 3 sampling_decode(“I love languages.”, NMTAttn=model, temperature=0.0, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)

in sampling_decode(input_sentence, NMTAttn, temperature, vocab_file, vocab_dir, next_symbol, tokenize, detokenize)
37
38 # update the current output token by getting the index of the next word (hint: use next_symbol)
—> 39 cur_output, log_prob = next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature)
40 # append the current output token to the list of output tokens
41 cur_output_tokens.append(cur_output)

in next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature)
29 padded_with_batch = np.array(padded).reshape((1,int(padded_length)))
30 # get the model prediction
—> 31 output, _ = NMTAttn((input_tokens,padded_with_batch))
32 # get log probabilities from the last token output
33 log_probs = output[0,-1,:]

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
195 self.state = state # Needed if the model wasn’t fully initialized.
196 state = self.state
→ 197 outputs, new_state = self.pure_fn(x, weights, state, rng)
198 self.state = new_state
199 return outputs

/opt/conda/lib/python3.7/site-packages/trax/layers/acceleration.py in pure_fn(self, x, weights, state, rng, use_cache)
75 remainder = x.shape[0] % self._n_devices
76 if remainder == 0: # If yes, run the accelerated sublayer.pure_fn.
—> 77 return self._jit_pure_fn(x, weights, state, rng)
78 # If not, pad first.
79 def pad(z):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
604 name, trace = self._name, _short_traceback(skip=3)
605 raise LayerError(name, ‘pure_fn’,
→ 606 self._caller, signature(x), trace) from None
607
608 def output_signature(self, input_signature):

LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file […]/, line 65
layer input shapes: (ShapeDtype{shape:(1, 5), dtype:int32}, ShapeDtype{shape:(1, 1), dtype:int32})

File […]/trax/layers/base.py, line 707, in __setattr__
super().__setattr__(attr, value)

File […]/trax/layers/base.py, line 454, in weights
f’Number of weight elements ({len(weights)}) does not equal the ’

ValueError: Number of weight elements (1) does not equal the number of sublayers (9) in: Serial_in2_out2[
Select[0,1,0,1]_in2_out4

Parallel_in2_out2[
Serial[

  Embedding_33300_1024
  LSTM_1024

  LSTM_1024
]

Serial[
  Serial[

    ShiftRight(1)
  ]

  Embedding_33300_1024
  LSTM_1024

]

]

PrepareAttentionInput_in3_out4
Serial_in4_out2[

Branch_in4_out3[
  None

  Serial_in4_out2[
    _in4_out4

    Serial_in4_out2[
      Parallel_in3_out3[

        Dense_1024
        Dense_1024

        Dense_1024
      ]

      PureAttention_in4_out2
      Dense_1024

    ]
    _in2_out2

  ]
]

Add_in2

]

Select[0,2]_in3_out2
LSTM_1024

LSTM_1024
Dense_33300

LogSoftmax
].

UNIT TEST

Solved by restarting the kernel.

2 Likes