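I’ve been trying to complete the function `scaled_dot_product_attention(q, k, v, mask)`, but I can’t get it right. I think the problem is related to the value of `dk`: the error below says `tf.math.sqrt` received an `int32` tensor, and it only accepts floating-point or complex types. I don’t know how to fix it, though.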
Here’s the error:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-67-00665b20febb> in <module>
1 # UNIT TEST
----> 2 scaled_dot_product_attention_test(scaled_dot_product_attention)
[snippet removed by mentor]
/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
199 """Call target, and fall back on dispatchers if there is a TypeError."""
200 try:
--> 201 return target(*args, **kwargs)
202 except (TypeError, ValueError):
203 # Note: convert_to_eager_tensor currently raises a ValueError, not a
/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py in sqrt(x, name)
4901 A `tf.Tensor` of same size, type and sparsity as `x`.
4902 """
-> 4903 return gen_math_ops.sqrt(x, name)
4904
4905
/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py in sqrt(x, name)
10040 try:
10041 return sqrt_eager_fallback(
> 10042 x, name=name, ctx=_ctx)
10043 except _core._SymbolicException:
10044 pass # Add nodes to the TensorFlow graph.
/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py in sqrt_eager_fallback(x, name, ctx)
10063 _attrs = ("T", _attr_T)
10064 _result = _execute.execute(b"Sqrt", 1, inputs=_inputs_flat, attrs=_attrs,
> 10065 ctx=ctx, name=name)
10066 if _execute.must_record_gradient():
10067 _execute.record_gradient(
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
InvalidArgumentError: Value for attr 'T' of int32 is not in the list of allowed values: bfloat16, half, float, double, complex64, complex128
; NodeDef: {{node Sqrt}}; Op<name=Sqrt; signature=x:T -> y:T; attr=T:type,allowed=[DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128]> [Op:Sqrt]