I'm trying to pass a data point to a JAX grad function to compute a derivative. I have a function, but when I index a point in an array I get a series of index, shape, or other errors. I am using grad(function, array(index)) and also tried chaining it as grad(function(array(index))), but the command isn't right. So what is the proper way to separate the function from its argument?
Hi @Allen_Susie
Are you doing the W1 graded lab? You should see this guide to using grad() in Exercise 4, unless you deleted it.
dLdOmega = None(None)(None[None])
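The pattern in that template is: `jax.grad` takes only the function and returns a new derivative function, which you then call on the single point you picked out with square brackets. Passing the point as the second argument of `jax.grad` fails because that slot is `argnums` (which argument to differentiate), and `array(index)` with parentheses tries to call the array instead of indexing it. A minimal sketch, where `loss` and `omegas` are hypothetical stand-ins and not the lab's actual function or data:

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar-to-scalar function (a stand-in, not the lab's function).
def loss(omega):
    return (omega - 2.0) ** 2

# Hypothetical array of data points; grad needs floating-point inputs.
omegas = jnp.array([0.0, 1.0, 3.0])

# Step 1: jax.grad(loss) builds the derivative function d(loss)/d(omega).
# Step 2: call it on one point, indexed with square brackets.
dLdOmega = jax.grad(loss)(omegas[1])
print(dLdOmega)  # → -2.0, since 2 * (1.0 - 2.0) = -2.0

# To get the derivative at every point at once, vectorize the derivative function.
all_grads = jax.vmap(jax.grad(loss))(omegas)
print(all_grads)
```

Note that `jax.grad` expects the function to return a scalar, so indexing a single point (or using `jax.vmap` over the array) keeps the output shape valid.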