Would appreciate some help on the grad function in JAX

I'm trying to pass a data point to a JAX grad function to compute a derivative. I have a function, but when I index a point in an array I get a series of index, shape, or other errors. I am using grad(function, array(index)) and also tried chaining it as grad(function(array(index)), but neither command is right. So what is the proper way to separate the function from its argument?

Hi @Allen_Susie

Are you doing the W1 graded lab? You should see this guide on using grad() in Exercise 4, unless you deleted it:

dLdOmega = None(None)(None[None])
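The pattern in that template is: pass only the function to `grad`, then call the function it returns on a single element of the array. Here is a minimal sketch of that general pattern. It does not use the lab's own code: `f` and `xs` are hypothetical stand-ins. Note that `jax.grad`'s second positional parameter is `argnums`, so `grad(function, array[index])` hands the data point to `argnums` instead of to the function. That is likely the source of the errors.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar-valued function standing in for the lab's function
def f(x):
    return x ** 2 + 3.0 * x

# Hypothetical data points; grad needs floating-point inputs
xs = jnp.array([0.0, 1.0, 2.0])

# Step 1: grad takes only the function and returns a new function, df/dx
df = jax.grad(f)

# Step 2: call that returned function on one scalar element of the array
d_at_1 = df(xs[1])  # 2*1 + 3 = 5.0

# Optional: derivative at every point at once, by vectorizing with vmap
d_all = jax.vmap(df)(xs)  # [3.0, 5.0, 7.0]

print(d_at_1, d_all)
```

`grad` only works on functions that return a scalar. To get the derivative at every point in the array, wrap the returned function in `jax.vmap` rather than passing the whole array to it.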