C5_W4_A1 scaled_dot_product_attention dimension error

Hi,
When executing scaled_dot_product_attention_test(), I get the error below and cannot figure out the root cause. Is the problem in the multiply step or something else? Thanks.

InvalidArgumentError Traceback (most recent call last)
in
1 # UNIT TEST
----> 2 scaled_dot_product_attention_test(scaled_dot_product_attention)

~/work/W4A1/public_tests.py in scaled_dot_product_attention_test(target)
69
70 mask = np.array([[[1, 1, 0, 1], [1, 1, 0, 1], [1, 1, 0, 1]]])
---> 71 attention, weights = target(q, k, v, mask)
72
73 assert np.allclose(weights, [[0.30719590187072754, 0.5064803957939148, 0.0, 0.18632373213768005],

in scaled_dot_product_attention(q, k, v, mask)
28 # add the mask to the scaled tensor.
29 if mask is not None: # Don’t replace this None
---> 30 scaled_attention_logits += tf.matmul(1. - mask, -1e9)
31
32 # softmax is normalized on the last axis (seq_len_k) so that the scores

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
199 """Call target, and fall back on dispatchers if there is a TypeError."""
200 try:
--> 201 return target(*args, **kwargs)
202 except (TypeError, ValueError):
203 # Note: convert_to_eager_tensor currently raises a ValueError, not a

/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py in matmul(a, b, transpose_a, transpose_b, adjoint_a, adjoint_b, a_is_sparse, b_is_sparse, name)
3275 adjoint_b = True
3276 return gen_math_ops.batch_mat_mul_v2(
-> 3277 a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name)
3278
3279 # Neither matmul nor sparse_matmul support adjoint, so we conjugate

/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py in batch_mat_mul_v2(x, y, adj_x, adj_y, name)
1517 return _result
1518 except _core._NotOkStatusException as e:
-> 1519 _ops.raise_from_not_ok_status(e, name)
1520 except _core._FallbackException:
1521 pass

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
6860 message = e.message + (" name: " + name if name is not None else "")
6861 # pylint: disable=protected-access
-> 6862 six.raise_from(core._status_to_exception(e.code, message), None)
6863 # pylint: enable=protected-access
6864

/opt/conda/lib/python3.7/site-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: In[1] ndims must be >= 2: 0 [Op:BatchMatMulV2]

Try using the element-wise multiplication operator '*' instead of tf.matmul(). Here `-1e9` is a rank-0 scalar, and tf.matmul requires both operands to have rank >= 2, which is exactly what the error "In[1] ndims must be >= 2: 0" is telling you.
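A minimal sketch of that masking step, using illustrative logits and the mask from the unit test (the tensor values here are made up; only the `(1. - mask) * -1e9` line comes from the fix):

```python
import numpy as np
import tensorflow as tf

# Hypothetical logits of shape (batch, seq_len_q, seq_len_k).
scaled_attention_logits = tf.constant(
    [[[0.1, 0.2, 0.3, 0.4],
      [0.5, 0.6, 0.7, 0.8],
      [0.9, 1.0, 1.1, 1.2]]], dtype=tf.float32)

# Mask from the unit test: 1 = attend, 0 = block.
mask = tf.constant([[[1., 1., 0., 1.],
                     [1., 1., 0., 1.],
                     [1., 1., 0., 1.]]])

# '*' broadcasts the scalar -1e9 over the tensor; tf.matmul cannot,
# because BatchMatMulV2 requires both inputs to have rank >= 2.
scaled_attention_logits += (1. - mask) * -1e9

# Softmax over the last axis: blocked positions get ~0 weight.
weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
```

After the softmax, every masked column collapses to zero attention weight and each row still sums to 1, which is what the assertion in public_tests.py checks.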

Hi TMosh,
It works now. Thanks for your advice!