Why use `jnp.matmul` and `jnp.swapaxes` instead of `@` and `.T`, respectively?

The Week 2 assignment says the following:

> Also take into account that the real tensors are far more complex than the toy ones you just played with. Because of this, avoid using shortened operations such as `@` for dot product or `.T` for transposing. Use `jnp.matmul()` and `jnp.swapaxes()` instead.

However, using the `jnp` functions doesn't improve execution time in the following experiment:
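(The screenshot isn't reproduced here, but the cell was roughly along these lines; this is a sketch, not the exact code from the notebook.)

```python
import time
import jax
import jax.numpy as jnp

# Build two large random matrices on the default device.
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
a = jax.random.normal(k1, (10_000, 10_000))
b = jax.random.normal(k2, (10_000, 10_000))

def timed(fn):
    # block_until_ready() forces JAX's asynchronous dispatch to finish,
    # so we time the actual computation rather than just the dispatch.
    start = time.perf_counter()
    fn().block_until_ready()
    return time.perf_counter() - start

print("@ operator:   ", timed(lambda: a @ b))
print("jnp.matmul(): ", timed(lambda: jnp.matmul(a, b)))
```

The `block_until_ready()` call matters because JAX dispatches work asynchronously; without it the timing would mostly measure dispatch rather than the matmul itself.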

Is this experiment misleading in some way? Is there a reason other than execution time to favor the `jnp` functions? Or is the advice to use those functions misguided?

It seems to me that the operations shown here are very basic, and because they are so simple I would guess there would not be much difference in execution speed or memory access.

But it has been shown, for much more complicated operations, that JAX is faster than NumPy, especially when training JAX-based ML models.
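For instance (just a rough sketch to illustrate the point, not code from the assignment), the speedups typically show up when a chain of operations is compiled together with `jax.jit`, which is a different situation from a single matmul:

```python
import jax
import jax.numpy as jnp

# A composite operation: the kind of chained math that appears in a
# training step, where XLA can fuse the intermediate computations.
def loss_like(w, x):
    h = jnp.tanh(x @ w)
    return jnp.mean(h ** 2)

loss_fast = jax.jit(loss_like)  # compiled and fused by XLA

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4096, 512))
w = jax.random.normal(key, (512, 512))

# The first call compiles; subsequent calls reuse the compiled kernel.
loss_fast(w, x).block_until_ready()
```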

In the screenshot I show an example of multiplying 10k × 10k matrices, and it takes ~14.5 s either way.

I believe that JAX can be faster than NumPy, especially when it uses a GPU. I'm pretty sure that's not what's at issue here, though: surely `@` on JAX arrays will still use JAX rather than somehow reverting to NumPy. I suspect it's not actually doing anything different from `jnp.matmul`, so we might as well use the notationally simpler operator, but I'm not sure.
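One quick sanity check (a small snippet I'd run, nothing from the assignment) is to confirm that `@` on JAX arrays returns a JAX array and matches `jnp.matmul`:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (3, 4))
b = jax.random.normal(key, (4, 5))

# Both expressions return JAX array types, not numpy.ndarray,
# and the results agree.
print(type(a @ b))                            # a JAX array type
print(type(jnp.matmul(a, b)))                 # same JAX array type
print(jnp.allclose(a @ b, jnp.matmul(a, b)))  # True
```

Whether the two take identical code paths internally I can't say for certain, but at least `@` clearly isn't falling back to NumPy.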