Can somebody explain why that is? I have the newest (GPU) versions of jaxlib, trax and tensorflow. Did they perhaps just change the data type compared to the versions used in Coursera?
tensorflow==2.12.0
trax==1.4.1
jaxlib==0.4.6+cuda11.cudnn86 (I guess the suffix is because I installed the GPU version?)
Since it works fine in the Coursera notebooks, it's not a big issue; I'm just curious.
Hey @JonasK,
Welcome! We're glad you could be a part of our community.
Looks like you've just discovered the answer to your own question. All the notebooks are built with the Coursera environment in mind, so we recommend running them there. When you run them on a local machine, differences in package versions can produce warnings like the one you got, and at times even errors, which you may well run into if you continue down this path. I hope this resolves your query.
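To see exactly where a local setup differs from the course environment, one option is to print the installed versions in both places and compare them. The sketch below uses Python's standard `importlib.metadata`; the package list is an assumption taken from the question, and the helper name `installed_versions` is hypothetical. The Coursera versions are not known here; you would run the same snippet inside a course notebook to get them.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package name: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # Package is not installed in this environment.
            found[name] = None
    return found


if __name__ == "__main__":
    # Assumed package list, taken from the versions quoted in the question.
    for name, ver in installed_versions(["tensorflow", "trax", "jax", "jaxlib"]).items():
        print(f"{name}=={ver}" if ver else f"{name}: not installed")
```

Running it locally and in the course notebook gives two lists in the same `name==version` form, which makes mismatches easy to spot.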
"We introduce jax.Array which is a unified array type that subsumes DeviceArray, ShardedDeviceArray, and GlobalDeviceArray types in JAX. … The jax.Array migration guide can help you migrate your codebase to jax.Array"
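In practice this means code that checks for the old `DeviceArray` type can behave differently on newer JAX releases such as the jaxlib 0.4.6 from the question. A minimal sketch, assuming a JAX release that ships `jax.Array`: checking against `jax.Array` instead of a concrete class keeps the code independent of which concrete array type a given version returns. The helper `to_numpy` is hypothetical, added only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_numpy(a):
    """Copy a JAX array to a host NumPy array; pass other values through unchanged.

    Checks against the unified jax.Array type rather than the older
    DeviceArray class, whose name differs across JAX versions.
    """
    if isinstance(a, jax.Array):
        return np.asarray(a)
    return a


if __name__ == "__main__":
    x = jnp.arange(4.0)
    print(isinstance(x, jax.Array))  # True on releases that use jax.Array
    print(type(x))                    # concrete class name depends on the JAX version
    print(type(to_numpy(x)))          # a plain numpy.ndarray
```

Printing `type(x)` in the course notebook and locally is a quick way to confirm whether the difference you saw comes from this type change.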