Different Data Types in Local Notebook

Hi there,

When running the notebooks locally, I always get the following feedback from the unit tests:

Expected <class 'jaxlib.xla_extension.DeviceArrayBase'>.
Got <class 'jaxlib.xla_extension.ArrayImpl'>.

Can somebody explain why that is? I have the newest (GPU) versions of jaxlib, trax, and TensorFlow installed. Did they perhaps change the data type compared to the versions used on Coursera?
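The message suggests the test compares the exact class of the returned array rather than a shared base type. As a rough illustration, here is a pure-Python sketch with stand-in classes. They are hypothetical and only borrow the class names from the message above; they are not the real jaxlib types, and how the course's unit tests actually compare types is an assumption. It shows why an exact-type check breaks when a library swaps the concrete class, while an `isinstance` check against a common base keeps passing.

```python
class ArrayBase:
    """Hypothetical stand-in for a shared base array type."""


class DeviceArrayBase(ArrayBase):
    """Hypothetical stand-in named after the class the test expected."""


class ArrayImpl(ArrayBase):
    """Hypothetical stand-in named after the class the newer library returns."""


def exact_type_check(value, expected_type):
    # Passes only if the concrete class matches exactly.
    return type(value) is expected_type


def base_type_check(value, base_type):
    # Passes for any instance of the base type or one of its subclasses.
    return isinstance(value, base_type)


result = ArrayImpl()  # what the "newer library" hands back in this sketch
print(exact_type_check(result, DeviceArrayBase))  # False: the class was swapped
print(base_type_check(result, ArrayBase))         # True: still an ArrayBase
```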

tensorflow==2.12.0
trax==1.4.1
jaxlib==0.4.6+cuda11.cudnn86 (I guess the suffix is there because I installed the GPU version?)

Since it works just fine in the Coursera notebooks, it’s not a big issue… Just curious :slight_smile:

Thanks and best regards :slight_smile:

Hey @JonasK,
Welcome, and we are glad that you could be a part of our community :partying_face:

Looks like you have already discovered the answer to your own question. All the notebooks are built with the Coursera environment in mind, so we recommend running them there. On a local machine, differences in package versions can produce warnings like the one you got, and at times even errors if you keep going down this path. I hope this resolves your query :nerd_face:

Cheers,
Elemento


Thanks for the quick response! :slight_smile:

Whenever you run into issues like that, it’s a good idea to check the release notes. For example:

(Change log — JAX documentation)

We introduce jax.Array which is a unified array type that subsumes DeviceArray, ShardedDeviceArray, and GlobalDeviceArray types in JAX. The jax.Array migration guide can help you migrate your codebase to jax.Array.

https://jax.readthedocs.io/en/latest/jax_array_migration.html
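To see which array type your own runtime produces, a quick check along these lines may help. It is a sketch that assumes JAX 0.4.1 or later is installed; the exact module path of the printed class differs between JAX releases. Checking against the public `jax.Array` type instead of a concrete class name avoids depending on internals that can be renamed.

```python
import jax
import jax.numpy as jnp

# Print the installed version so it can be compared against 0.4.1.
print("jax", jax.__version__)

x = jnp.arange(3.0)

# Concrete class of the array; its name and module path vary between releases.
print(type(x))

# Version-independent check against the public array type.
print(isinstance(x, jax.Array))  # True on JAX 0.4.1 or later
```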

I’m confident that if you compare JAX versions, you’ll find the Coursera runtime is pre-0.4.1 (and yours is post-0.4.1).
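One way to do that comparison in code is a small version parser like the one below. It is a sketch under simplifying assumptions: the helper names are hypothetical, the 0.4.1 threshold comes from the changelog entry quoted above, any `+local` suffix (such as `+cuda11.cudnn86` on the GPU jaxlib build) is dropped, and pre-release tags like `rc1` are ignored. Apply it to `jax.__version__`, since the changelog entry refers to JAX releases.

```python
import re

# JAX release that introduced jax.Array, per the changelog quoted above.
JAX_ARRAY_SINCE = (0, 4, 1)


def release_tuple(version: str) -> tuple:
    """Return the numeric release part of a version string as a tuple of ints.

    '0.4.6+cuda11.cudnn86' -> (0, 4, 6). The '+local' suffix is dropped and
    anything after the leading dotted numbers (e.g. 'rc1', '.dev0') is ignored.
    """
    public = version.split("+", 1)[0]
    match = re.match(r"\d+(?:\.\d+)*", public)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def uses_jax_array(version: str) -> bool:
    """True if the given JAX version is at or past the jax.Array release."""
    return release_tuple(version) >= JAX_ARRAY_SINCE


for v in ["0.3.25", "0.4.1", "0.4.6+cuda11.cudnn86"]:
    print(v, uses_jax_array(v))
```

If the third-party `packaging` library is available, `packaging.version.Version` handles these comparisons, including pre-release tags, more completely.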
