Calculating gradient of softmax function

For multi-class classification with deep neural networks, we use softmax as the output-layer activation function.
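$$Z^{[L]} = W^{[L]} A^{[L-1]} + b^{[L]}$$
$$A^{[L]} = \mathrm{softmax}(Z^{[L]})$$
$$\text{cost} = J(W^{[1]}, b^{[1]}, W^{[2]}, b^{[2]}, \dots, W^{[L]}, b^{[L]}) = -\frac{1}{m} \sum_{i=1}^{m} \sum_{k=1}^{n} y_k^{(i)} \log a_k^{[L](i)}$$
$$dA^{[L]} = \frac{\partial J}{\partial A^{[L]}} = -\frac{Y}{A^{[L]}}$$
$$dZ^{[L]} = \frac{\partial J}{\partial Z^{[L]}} = \frac{\partial J}{\partial A^{[L]}}\,\frac{\partial A^{[L]}}{\partial Z^{[L]}} = dA^{[L]} \cdot \nabla\mathrm{softmax}(Z^{[L]}) = -\frac{Y}{A^{[L]}} \cdot \nabla\mathrm{softmax}(Z^{[L]})$$
(Here $Y / A^{[L]}$ is element-wise division, and the $1/m$ factor is left out of $dA^{[L]}$.)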
→ $\nabla\mathrm{softmax}(Z^{[L]})$, the gradient of softmax, is the derivative of the softmax function with respect to $Z^{[L]}$.
→ From an internet search I found that, for a vector z of shape (n, 1), the derivative of softmax with respect to z is an (n, n) matrix.
→ So how does multiplying this derivative with "dAL" (element-wise or as a dot product) lead to the result presented in the lectures, which was
$$dZ^{[L]} = -\frac{Y}{A^{[L]}} \cdot \nabla\mathrm{softmax}(Z^{[L]}) = A^{[L]} - Y$$
How did we get $A^{[L]} - Y$ as "dZL"?


Hi, @Vikasjaat.

Softmax has n inputs and n outputs. If you compute the partial derivative of each output with respect to each input, you end up with an (n, n) matrix.
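For illustration, here is a small NumPy sketch that builds this (n, n) matrix for a single 1-D logit vector. The function names `softmax` and `softmax_jacobian` are my own (not from the course code), and it uses the standard closed form $\partial a_i / \partial z_j = a_i(\delta_{ij} - a_j)$, which in matrix form is $\mathrm{diag}(a) - a a^\top$:

```python
import numpy as np

def softmax(z):
    """Softmax of a 1-D array z, shifted by max(z) for numerical stability."""
    e = np.exp(z - np.max(z))
    return e / e.sum()

def softmax_jacobian(z):
    """(n, n) matrix with entry [i, j] = d softmax(z)_i / d z_j."""
    a = softmax(z)
    # a_i * (delta_ij - a_j): diagonal is a_i(1 - a_i), off-diagonal is -a_i a_j
    return np.diag(a) - np.outer(a, a)

z = np.array([2.0, 1.0, 0.1])   # arbitrary logits, n = 3
jac = softmax_jacobian(z)
print(jac.shape)  # (3, 3)
```

Each column sums to zero, because the softmax outputs always sum to 1, so nudging any single input cannot change their total.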

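I think the expression you got is a matrix multiplication of a (1, n) vector with the (n, n) matrix of partial derivatives. The dot product of the vector with each column of the matrix would correspond to this term (source):

$$\frac{\partial J}{\partial z_j} = \sum_{i=1}^{n} \frac{\partial J}{\partial a_i}\,\frac{\partial a_i}{\partial z_j}$$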

Intuitively, the summation makes sense, since a change to any given input affects all of the outputs.
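A minimal numerical sketch of that vector-times-matrix product for one example, assuming a one-hot label and the cross-entropy derivative $-y/a$ from the question (the logits are arbitrary):

```python
import numpy as np

def softmax(z):
    """Softmax of a 1-D array z, shifted by max(z) for numerical stability."""
    e = np.exp(z - np.max(z))
    return e / e.sum()

z = np.array([2.0, 1.0, 0.1])   # logits for one example (n = 3)
y = np.array([0.0, 1.0, 0.0])   # one-hot label
a = softmax(z)

jac = np.diag(a) - np.outer(a, a)   # jac[i, j] = d a_i / d z_j
dA = -y / a                          # dJ/dA for cross-entropy, shape (n,)

# (1, n) row vector times (n, n) matrix: entry j is sum_i dA[i] * jac[i, j]
dZ = dA @ jac
print(np.allclose(dZ, a - y))  # True
```

The summation over i collapses to $a_j - y_j$ for every j, which is the "dZL" from the lectures.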

I hope I got it right :slight_smile:

P.S. If you’re going to implement it, you may have to reshape your vectors, depending on the representation you choose.
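As an example of the representation issue, here is a sketch assuming each column of ZL is one example (shape (n, m)). It builds one (n, n) Jacobian per example with `np.einsum`, applies it to the matching column of dAL, and checks that the result matches the shortcut AL - Y. The helper `softmax_columns` and the random data are hypothetical, for illustration only:

```python
import numpy as np

def softmax_columns(Z):
    """Column-wise softmax: each column of Z (shape (n, m)) is one example."""
    E = np.exp(Z - Z.max(axis=0, keepdims=True))
    return E / E.sum(axis=0, keepdims=True)

rng = np.random.default_rng(0)
n, m = 4, 5                                     # 4 classes, 5 examples
ZL = rng.normal(size=(n, m))
Y = np.eye(n)[:, rng.integers(0, n, size=m)]    # one-hot labels, shape (n, m)
AL = softmax_columns(ZL)

dAL = -Y / AL                                   # shape (n, m)

# jacs[k, i, j] = a_i * (delta_ij - a_j) for example k, shape (m, n, n)
jacs = np.einsum('ik,ij->kij', AL, np.eye(n)) - np.einsum('ik,jk->kij', AL, AL)
# dZL[j, k] = sum_i dAL[i, k] * jacs[k, i, j]
dZL = np.einsum('kij,ik->jk', jacs, dAL)

print(np.allclose(dZL, AL - Y))  # True
```

In practice the explicit Jacobians are not needed: AL - Y gives the same (n, m) result directly and avoids building m matrices of size (n, n).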


I decided to post a more complete derivation in case anyone needs extra clarification in the future. By the way, @nramon, thanks for the solution; it helped me a lot!

Derivative of softmax activation:
Case 1:
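Assuming Case 1 is the diagonal case i = j, and writing $a_i = e^{z_i} / \sum_k e^{z_k}$ for the i-th softmax output, the quotient rule gives the standard result:

$$\frac{\partial a_i}{\partial z_i} = \frac{e^{z_i}\sum_k e^{z_k} - e^{z_i}\,e^{z_i}}{\left(\sum_k e^{z_k}\right)^2} = a_i - a_i^2 = a_i(1 - a_i)$$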

Case 2:
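Assuming Case 2 is the off-diagonal case i ≠ j, with $a_i = e^{z_i} / \sum_k e^{z_k}$, the numerator $e^{z_i}$ does not depend on $z_j$, so only the denominator contributes:

$$\frac{\partial a_i}{\partial z_j} = \frac{0 \cdot \sum_k e^{z_k} - e^{z_i}\,e^{z_j}}{\left(\sum_k e^{z_k}\right)^2} = -a_i a_j$$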

Conclusion on the derivative of softmax:
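With $a = \mathrm{softmax}(z)$, the diagonal result $a_i(1 - a_i)$ and the off-diagonal result $-a_i a_j$ can be combined using the Kronecker delta ($\delta_{ij} = 1$ if i = j, else 0). In matrix form this is the (n, n) Jacobian discussed earlier in the thread:

$$\frac{\partial a_i}{\partial z_j} = a_i(\delta_{ij} - a_j), \qquad \nabla\mathrm{softmax}(z) = \mathrm{diag}(a) - a\,a^\top$$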

Finally, we apply the derivative of the softmax activation to the derivative of the loss function:
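A sketch of that last step for a single example, assuming the cross-entropy loss $\mathcal{L} = -\sum_i y_i \log a_i$ (so $\partial\mathcal{L}/\partial a_i = -y_i/a_i$), the softmax derivative $\partial a_i/\partial z_j = a_i(\delta_{ij} - a_j)$, and a one-hot label so that $\sum_i y_i = 1$:

$$\frac{\partial \mathcal{L}}{\partial z_j} = \sum_{i=1}^{n} \frac{\partial \mathcal{L}}{\partial a_i}\,\frac{\partial a_i}{\partial z_j} = \sum_{i=1}^{n} \left(-\frac{y_i}{a_i}\right) a_i(\delta_{ij} - a_j) = -y_j + a_j \sum_{i=1}^{n} y_i = a_j - y_j$$

Stacking this over all j and all examples gives $dZ^{[L]} = A^{[L]} - Y$.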


Awesome, @hongjiaherng. Thank you for sharing this :slight_smile:

Sorry for the dumb question. As far as I understand, y is the label for the class we are predicting, e.g. dog, cat, cow. So for dog, the actual output y will be [1, 0, 0]. When computing the loss or gradient, why do we even consider the case i ≠ j in the summation, since we know y_i is 0 in that case?

I admire your optimism in adding a reply to a thread that has been cold for three years.

Both the true (1) and false (0) cases are important, because the cost should be high for making any incorrect prediction, so the gradient also has to push down the outputs of the wrong classes.
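One way to see this concretely: in the chain-rule sum, i runs over softmax outputs and j over logits. For a wrong-class logit $z_j$, the only surviving term is the one where i is the true class, and that is an i ≠ j term. A small NumPy sketch (the logits are arbitrary, the labels follow the dog/cat/cow example) compares the full Jacobian with a hypothetical diagonal-only version:

```python
import numpy as np

def softmax(z):
    """Softmax of a 1-D array z, shifted by max(z) for numerical stability."""
    e = np.exp(z - np.max(z))
    return e / e.sum()

z = np.array([2.0, 1.0, 0.1])   # logits for dog, cat, cow
y = np.array([1.0, 0.0, 0.0])   # true class: dog
a = softmax(z)
dA = -y / a                      # only the dog entry is non-zero

full_jac = np.diag(a) - np.outer(a, a)   # includes the i != j terms
diag_only = np.diag(a * (1 - a))         # keeps only the i == j terms

print(dA @ full_jac)    # equals a - y: cat and cow logits get a positive gradient
print(dA @ diag_only)   # cat and cow entries are 0: no push-down on wrong classes
```

Dropping the i ≠ j terms would leave the wrong-class logits with zero gradient, even though they contribute to the loss through the softmax normalization.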