Jupyter - problems installing Jax

Hi all.

I’m getting errors when trying to run the jax cells in the C2 W1 Differentiation notebook.

I’m running Anaconda on Windows 10.

Anyone have any fixes?

Thanks in hope,

When I tried running the NLP code locally, I had to pay very close attention to the versions of the many interdependent packages: TF, jax, python, numpy, and more. All need to be at mutually compatible releases. I had better luck controlling installation with the smaller footprint of conda than with Anaconda, which is inherently more complex due to its size. Have you searched this forum already?

Thanks for your reply.

I’ve searched the forum for Jax, but found nothing.

I’ll look into Conda vs Anaconda.


You can also help us help you by sharing the errors you’re seeing. They happen at runtime, right? Not during the jax install?

Ok, will do tomorrow … not in front of the PC right now.

I tried several approaches before Jax seemingly installed ok, but I’m still getting errors when running the cell.

Upon executing the following cell…

from jax import grad, vmap
#from jax import *
import jax.numpy as jnp

I get the following error…

ModuleNotFoundError Traceback (most recent call last)
Cell In[31], line 1
----> 1 from jax import grad, vmap
2 #from jax import *
3 import jax.numpy as jnp

File ~\anaconda3\lib\site-packages\jax\__init__.py:21
18 del _os
20 # flake8: noqa: F401
---> 21 from .config import config
22 from .api import (
23 ad, # TODO(phawkins): update users to avoid this.
24 argnums_partial, # TODO(phawkins): update Haiku to not use this.
87 xla_computation,
88 )
89 from .experimental.maps import soft_pmap

File ~\anaconda3\lib\site-packages\jax\config.py:19
17 import threading
18 from typing import Optional
---> 19 from jax import lib
21 def bool_env(varname: str, default: bool) -> bool:
22 """Read an environment variable and interpret it as a boolean.
24 True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1';
30 Raises: ValueError if the environment variable is anything else.
31 """

File ~\anaconda3\lib\site-packages\jax\lib\__init__.py:23
1 # Copyright 2018 Google LLC
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
15 # This module is largely a wrapper around jaxlib that performs version
16 # checking on import.
18 __all__ = [
19 'cuda_prng', 'cusolver', 'rocsolver', 'jaxlib', 'lapack',
20 'pocketfft', 'pytree', 'tpu_client', 'version', 'xla_client'
21 ]
---> 23 import jaxlib
25 # Must be kept in sync with the jaxlib version in build/test-requirements.txt
26 _minimum_jaxlib_version = (0, 1, 60)

ModuleNotFoundError: No module named 'jaxlib'
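
The traceback shows that jax itself is found and that the failure only happens when it tries to import jaxlib. A minimal sketch for confirming this from inside the notebook, using only the standard library (the helper name missing_modules is made up for illustration):

```python
import importlib.util
import sys


def missing_modules(names):
    """Return the top-level module names that this interpreter cannot locate."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# sys.executable is the interpreter the notebook kernel is actually running,
# which is not necessarily the one a separate terminal would use.
print("Kernel interpreter:", sys.executable)
print("Not importable:", missing_modules(["jax", "jaxlib"]))
```

If jaxlib shows up in the second line, the kernel's environment has no jaxlib, regardless of what was installed elsewhere.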

I tried to install Jax as follows…

pip install jax

… and got the following response…

Requirement already satisfied: jax in c:\users\kevdo\anaconda3\lib\site-packages (0.2.10)
Requirement already satisfied: numpy>=1.12 in c:\users\kevdo\anaconda3\lib\site-packages (from jax) (1.23.5)
Requirement already satisfied: absl-py in c:\users\kevdo\anaconda3\lib\site-packages (from jax) (1.4.0)
Requirement already satisfied: opt-einsum in c:\users\kevdo\anaconda3\lib\site-packages (from jax) (3.3.0)
Note: you may need to restart the kernel to use updated packages.
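
Note that the pip output above lists jax, numpy, absl-py and opt-einsum, but says nothing about jaxlib. One way to see which of these distributions are really installed for the running kernel is sketched below. It assumes Python 3.8 or later (for importlib.metadata), and installed_version is a hypothetical helper name:

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# A value of None means pip has no record of that distribution
# in the environment this kernel is using.
for name in ("jax", "jaxlib", "numpy"):
    print(f"{name}: {installed_version(name)}")
```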

After restarting the kernel, I tried running the Jax import again, but got the same errors.

Also tried installing from Conda Forge (Jax :: Anaconda.org), which says to install the package by running one of the following:

conda install -c conda-forge jax
conda install -c "conda-forge/label/broken" jax
conda install -c "conda-forge/label/cf202003" jax

… however, none of the above 3 methods worked (they either failed to resolve the environment or seemingly hung).

Final attempt was following the instructions at google/jax: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more (github.com):

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

… after which I restarted the kernel, and then tried re-running the Jax import, only to get the same errors as previously described!

I’ve therefore given up on Jax, and am moving on to Wk2!

Some interesting tidbits on that page, such as…

Windows users can use JAX on CPU and GPU via the Windows Subsystem for Linux. In addition, there is some initial community-driven native Windows support, but since it is still somewhat immature, there are no official binary releases and it must be built from source for Windows.


These pip installations do not work with Windows, and may fail silently
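
Since the quoted notes treat native Windows and WSL differently, it can help to check which one a given kernel is running on. The sketch below relies on a common heuristic rather than an official API: under WSL the Linux kernel release string usually contains "microsoft". The function name describe_platform is made up for illustration:

```python
import platform


def describe_platform(system, release):
    """Classify the platform from platform.system() and platform.release() values."""
    if system == "Windows":
        return "native Windows"
    # Heuristic assumption: WSL kernels report a release string containing "microsoft".
    if system == "Linux" and "microsoft" in release.lower():
        return "WSL"
    return system


print(describe_platform(platform.system(), platform.release()))
```

A kernel started from Anaconda on Windows should report "native Windows", while one started from a WSL shell should report "WSL".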

Ah, well spotted!

I have WSL installed, so I’ll give it a go.

Thanks :+1:

I’ve got Jupyter working via WSL and can open the differentiation notebook.

I installed Jax from the linux cmd-line using:

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

Jax now works in the Jupyter notebook.
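
For anyone wanting a quick check beyond the import itself, here is a small smoke test that uses the same grad and vmap imports as the failing cell. It is only a sketch: the function f is an arbitrary example, and the helper returns None instead of raising if JAX still cannot be imported.

```python
def jax_smoke_test():
    """Differentiate x**2 at three points, or return None if JAX is unavailable."""
    try:
        from jax import grad, vmap
        import jax.numpy as jnp
    except ImportError as err:  # also covers ModuleNotFoundError for jaxlib
        print("JAX is still not importable:", err)
        return None

    def f(x):
        return x ** 2

    xs = jnp.array([1.0, 2.0, 3.0])
    # grad(f) differentiates f for a scalar input; vmap maps it over the array.
    return vmap(grad(f))(xs)


print(jax_smoke_test())
```

With a working install, the result holds the derivative 2x at each point, i.e. the values 2, 4 and 6.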


Oh yes!

I am a linux guy so I do not have those issues. But congrats on getting it working on Windows!

If you find a way to get it (and trax!) to play nice on macOS, let me know!

I had similar issues getting JAX to run on macOS Ventura 13.3.1 while attempting to complete the ungraded lab in week 1.

I had initially installed it using pip, ignoring the warnings about the install path. When trying to import JAX into the workbook (I’m running a downloaded copy on my local machine), it had issues with the version of jaxlib, similar to what @K_Docherty posted. I focused on the pip warning that I had ignored during the initial installation of JAX:

"Defaulting to user installation because normal site-packages is not writeable"

Trying to upgrade pip, I received the same warning. I fixed both by using sudo with the -H flag:

sudo -H pip install --upgrade pip
sudo -H pip install -U "jax[cpu]"

This has resolved my issues.
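
The "Defaulting to user installation" warning means pip could not write to the interpreter's main site-packages directory and fell back to the per-user one. A rough way to see both locations and whether they are writable is sketched below. The helper name site_packages_report is made up, and the check assumes a standard Python install where site.getsitepackages() is available:

```python
import os
import site


def site_packages_report():
    """Return (path, is_writable) pairs for the global and per-user site-packages."""
    paths = list(site.getsitepackages()) + [site.getusersitepackages()]
    # os.access returns False for directories that do not exist or are read-only.
    return [(path, os.access(path, os.W_OK)) for path in paths]


for path, writable in site_packages_report():
    print(f"{'writable' if writable else 'not writable'}: {path}")
```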
