Evaluate the function L_of_omega for each element of the array omega_array and store the result in the corresponding element of the array L_array using .at[<index>].set(<value>).
We already calculated L_of_omega in the previous step, in Exercise 2.
I don't understand what we are supposed to do here in Exercise 3.
L = None(None[None])
L_array = L_array.at[None].set(None)
And which function from jax.numpy needs to be used? grad?
You are correct that the L_of_omega function was already created in the previous step. However, that function accepts only one value of omega at a time. If you pass it an array containing several values of omega at once, the function will fail.
The exercise is to develop a function that accepts an array of values. You can do this by calling the existing L_of_omega inside a "for" loop that goes through each value in omega_array.
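In case it helps to see the shape of that loop, here is a minimal sketch. The body of L_of_omega below is only a hypothetical stand-in (your Exercise 2 function will compute something different), and pA, pB, the wrapper name L_of_omega_array, and the sample omega values are made up for illustration. The part that matters is the loop: one omega in, one value written back with .at[i].set(...).

```python
import jax.numpy as jnp

# Hypothetical stand-in for the Exercise 2 function: it takes ONE scalar
# omega plus the arrays pA and pB. Your real L_of_omega will differ.
def L_of_omega(omega, pA, pB):
    return jnp.mean((omega * pA + (1 - omega) * pB) ** 2)

# Made-up sample data, for illustration only.
pA = jnp.array([1.0, 2.0, 3.0])
pB = jnp.array([3.0, 2.0, 1.0])
omega_array = jnp.linspace(0.0, 1.0, 5)

def L_of_omega_array(omega_array, pA, pB):
    # Start with an array of zeros, one slot per omega value.
    L_array = jnp.zeros(len(omega_array))
    for i in range(len(omega_array)):
        # Evaluate the scalar function at the i-th omega only.
        L = L_of_omega(omega_array[i], pA, pB)
        # JAX arrays are immutable, so .at[i].set(L) returns a new array
        # that has to be assigned back to L_array.
        L_array = L_array.at[i].set(L)
    return L_array

L_array = L_of_omega_array(omega_array, pA, pB)
```

Note the reassignment `L_array = L_array.at[i].set(L)`: without it, the updated array is thrown away and L_array stays all zeros.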
The jax library will be used in Exercise 4.
I hope this gives you a head start. Try it, then let me know if you need further help. Thank you!
The issue is that you're calling L_of_omega with only L_array, but the function takes three parameters, including pA and pB. Also, you don't want to pass in L_array at all: the first argument should be omega_array[i].
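To make the difference concrete, here is a small sketch of the failing call next to the corrected one. The L_of_omega body and the pA, pB, and omega values are hypothetical placeholders, not the ones from the assignment; only the argument pattern is the point.

```python
import jax.numpy as jnp

# Hypothetical stand-in with the three-parameter signature described above.
def L_of_omega(omega, pA, pB):
    return jnp.mean((omega * pA + (1 - omega) * pB) ** 2)

# Made-up sample data, for illustration only.
pA = jnp.array([1.0, 2.0, 3.0])
pB = jnp.array([3.0, 2.0, 1.0])
omega_array = jnp.array([0.0, 0.5, 1.0])
L_array = jnp.zeros(len(omega_array))

# Wrong: a single argument (and it is the output array, not an omega).
try:
    L_of_omega(L_array)
    call_failed = False
except TypeError:
    # Python complains that the arguments pA and pB are missing.
    call_failed = True

# Right: one omega value plus pA and pB, then store the result at index i.
i = 1
L = L_of_omega(omega_array[i], pA, pB)
L_array = L_array.at[i].set(L)
```

In the exercise, the last three lines go inside the loop so that i runs over every index of omega_array.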