# Improved Implementation of Softmax - Trouble Understanding the Logic

Week 2 - Multiclass Classification - Improved Implementation of Softmax

Hi all, I’m confused about this part of the lecture: what is the point of using a linear activation in the final layer instead of softmax, and then using softmax to compute the probabilities from that linear output?

First, the stated reason, avoiding “rounding errors”, is confusing. I understand the first example given in the lecture.

But when it’s extended to logistic regression, I don’t see how it would make a difference. It looks like “a” is just a variable set equal to the sigmoid activation function. How would bypassing “a” make any difference? Unlike the first example, “a” doesn’t seem to be any convoluted rearrangement of the formula.

Hi @DeepInData, great question!

Making the output layer linear lets the loss function work directly on the raw outputs (logits) and apply softmax internally; you then apply softmax separately only when you need the probabilities. This is a technique to improve numerical stability, especially when the outputs are very large or very small. It does not change the fundamental logic of your model; it just makes the computations more robust and less prone to errors.
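
To make this concrete, here is a minimal pure-Python sketch (not TensorFlow's implementation; the function names are hypothetical) comparing two ways of computing softmax cross-entropy: applying softmax first and then taking the log, versus working directly from the logits with the log-sum-exp trick. In Keras, passing `from_logits=True` to a loss such as `tf.keras.losses.SparseCategoricalCrossentropy` tells it to expect logits; its internals are not reproduced here.

```python
import math


def ieee_log(x: float) -> float:
    """Natural log that returns -inf for 0.0 (as IEEE-754/NumPy do) instead of raising."""
    return -math.inf if x == 0.0 else math.log(x)


def naive_softmax_ce(logits: list[float], target: int) -> float:
    """Softmax as the output activation, then cross-entropy on the probabilities."""
    exps = [math.exp(z) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return -ieee_log(probs[target])


def softmax_ce_from_logits(logits: list[float], target: int) -> float:
    """Cross-entropy computed directly from the linear outputs (logits)."""
    m = max(logits)  # shift by the max so the largest exponent is exp(0) = 1
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


if __name__ == "__main__":
    print(naive_softmax_ce([2.0, 1.0, 0.1], 0), softmax_ce_from_logits([2.0, 1.0, 0.1], 0))
    print(naive_softmax_ce([0.0, -800.0], 1), softmax_ce_from_logits([0.0, -800.0], 1))
```

With moderate logits such as `[2.0, 1.0, 0.1]` the two agree to about machine precision. With `[0.0, -800.0]` and target 1, `math.exp(-800)` underflows to `0.0`, so the naive version takes `log(0)` and returns `inf`, while the from-logits version returns `800.0`. For large positive logits such as 1000, `math.exp` in the naive version raises `OverflowError` instead.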

Here’s a thread from Raymond that explains this in more detail.

In the lecture, Andrew says all of this is to avoid ‘rounding errors’. Is that still right, or is it more about the extremely large or small values in Raymond’s post?

Hello, @DeepInData, I think they are not mutually exclusive.

The code below compares the formulae after (`l1`) and before (`l2`) the simplification shown in my post.
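
For anyone without the linked post at hand, here is a sketch of the algebra, assuming the simplification there is the usual rewrite of binary cross-entropy in terms of $z$ that `l1` implements, with $p = \sigma(z)$:

$$
p = \frac{1}{1 + e^{-z}}, \qquad \log p = -\log\left(1 + e^{-z}\right), \qquad \log(1 - p) = -z - \log\left(1 + e^{-z}\right)
$$

$$
\ell_2 = -y \log p - (1 - y)\log(1 - p) = z - zy + \log\left(1 + e^{-z}\right)
$$

For $z < 0$, substituting $\log(1 + e^{-z}) = -z + \log(1 + e^{z})$ keeps the exponent non-positive, and both cases combine into the form used for `l1`:

$$
\ell_1 = \max(0, z) - zy + \log\left(1 + e^{-|z|}\right)
$$

$$
\frac{\partial \ell}{\partial z} = 1 - y - \frac{e^{-z}}{1 + e^{-z}} = 1 - y - (1 - p) = p - y
$$

So `l1` and `l2` are the same function mathematically and `p - y` is their exact gradient; they differ only in how floating-point rounding treats each form.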

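```python
import tensorflow as tf

for k in range(-4, 4):
    _z = 10 ** k
    print(f'\nz: {_z}')

    for dtype in [tf.float64, tf.float32]:
        z = tf.constant(_z, dtype=dtype)
        y = tf.constant(0., dtype=dtype)  # label y = 0

        # l1: loss computed directly from z (the simplified form)
        with tf.GradientTape() as g:
            g.watch(z)
            l1 = tf.maximum(0, z) - z * y + tf.math.log(1 + tf.math.exp(-tf.abs(z)))
        dl1_dz = g.gradient(l1, z)

        # l2: compute the probability p first, then the binary cross-entropy
        with tf.GradientTape() as g:
            g.watch(z)
            p = tf.math.reciprocal(1 + tf.math.exp(-z))
            l2 = -y * tf.math.log(p) - (1 - y) * tf.math.log(1 - p)
        dl2_dz = g.gradient(l2, z)

        print(
            f'{dtype} dl0_dz: {p - y: 26.20f}',  # gradient formula p - y applied directly
            f'{dtype} dl1_dz: {dl1_dz: 26.20f} l1: {l1: 26.20f}',
            f'{dtype} dl2_dz: {dl2_dz: 26.20f} l2: {l2: 26.20f}',
            sep='\n',
        )
```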
``````
z: 0.0001
<dtype: 'float64'> dl0_dz:     0.50002499999997918056
<dtype: 'float64'> dl1_dz:     0.50002499999997918056 l1:     0.69319718180994527312
<dtype: 'float64'> dl2_dz:     0.50002499999997918056 l2:     0.69319718180994538415
<dtype: 'float32'> dl0_dz:     0.50002503395080566406
<dtype: 'float32'> dl1_dz:     0.50002497434616088867 l1:     0.69319719076156616211
<dtype: 'float32'> dl2_dz:     0.50002509355545043945 l2:     0.69319725036621093750

z: 0.001
<dtype: 'float64'> dl0_dz:     0.50024999997916663741
<dtype: 'float64'> dl1_dz:     0.50024999997916674843 l1:     0.69364730555994014161
<dtype: 'float64'> dl2_dz:     0.50024999997916663741 l2:     0.69364730555994003058
<dtype: 'float32'> dl0_dz:     0.50024998188018798828
<dtype: 'float32'> dl1_dz:     0.50024998188018798828 l1:     0.69364732503890991211
<dtype: 'float32'> dl2_dz:     0.50024992227554321289 l2:     0.69364726543426513672

z: 0.01
<dtype: 'float64'> dl0_dz:     0.50249997916687494381
<dtype: 'float64'> dl1_dz:     0.50249997916687494381 l1:     0.69815968050786236798
<dtype: 'float64'> dl2_dz:     0.50249997916687483279 l2:     0.69815968050786225696
<dtype: 'float32'> dl0_dz:     0.50249999761581420898
<dtype: 'float32'> dl1_dz:     0.50249993801116943359 l1:     0.69815969467163085938
<dtype: 'float32'> dl2_dz:     0.50250011682510375977 l2:     0.69815969467163085938

z: 0.1
<dtype: 'float64'> dl0_dz:     0.52497918747894001257
<dtype: 'float64'> dl1_dz:     0.52497918747894001257 l1:     0.74439666007357085942
<dtype: 'float64'> dl2_dz:     0.52497918747894012359 l2:     0.74439666007357097044
<dtype: 'float32'> dl0_dz:     0.52497917413711547852
<dtype: 'float32'> dl1_dz:     0.52497923374176025391 l1:     0.74439668655395507812
<dtype: 'float32'> dl2_dz:     0.52497917413711547852 l2:     0.74439662694931030273

z: 1
<dtype: 'float64'> dl0_dz:     0.73105857863000489605
<dtype: 'float64'> dl1_dz:     0.73105857863000478503 l1:     1.31326168751822280889
<dtype: 'float64'> dl2_dz:     0.73105857863000500707 l2:     1.31326168751822280889
<dtype: 'float32'> dl0_dz:     0.73105859756469726562
<dtype: 'float32'> dl1_dz:     0.73105859756469726562 l1:     1.31326162815093994141
<dtype: 'float32'> dl2_dz:     0.73105865716934204102 l2:     1.31326174736022949219

z: 10
<dtype: 'float64'> dl0_dz:     0.99995460213129760962
<dtype: 'float64'> dl1_dz:     0.99995460213129760962 l1:    10.00004539889921773010
<dtype: 'float64'> dl2_dz:     0.99995460213226738944 l2:    10.00004539890018584458
<dtype: 'float32'> dl0_dz:     0.99995458126068115234
<dtype: 'float32'> dl1_dz:     0.99995458126068115234 l1:    10.00004577636718750000
<dtype: 'float32'> dl2_dz:     0.99949508905410766602 l2:     9.99958610534667968750

z: 100
<dtype: 'float64'> dl0_dz:     1.00000000000000000000
<dtype: 'float64'> dl1_dz:     1.00000000000000000000 l1:   100.00000000000000000000
<dtype: 'float64'> dl2_dz:                        inf l2:                        inf
<dtype: 'float32'> dl0_dz:     1.00000000000000000000
<dtype: 'float32'> dl1_dz:     1.00000000000000000000 l1:   100.00000000000000000000
<dtype: 'float32'> dl2_dz:                        nan l2:                        inf

z: 1000
<dtype: 'float64'> dl0_dz:     1.00000000000000000000
<dtype: 'float64'> dl1_dz:     1.00000000000000000000 l1:  1000.00000000000000000000
<dtype: 'float64'> dl2_dz:                        nan l2:                        inf
<dtype: 'float32'> dl0_dz:     1.00000000000000000000
<dtype: 'float32'> dl1_dz:     1.00000000000000000000 l1:  1000.00000000000000000000
<dtype: 'float32'> dl2_dz:                        nan l2:                        inf
``````

Note that rounding errors are more significant in 32-bit than in 64-bit floating point. Even so, 32-bit is sometimes preferred because it halves memory use, an important consideration when training large models.
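
A small pure-Python sketch of why precision matters here, assuming float32 arithmetic can be emulated by rounding every intermediate result through `struct` (TensorFlow's float32 kernels are not guaranteed to round identically). It finds the smallest integer `z` at which the naive `1 / (1 + exp(-z))` rounds to exactly `1.0`; from that point on `1 - p` is `0` and `log(1 - p)` is `-inf`. The helper names are hypothetical.

```python
import math
import struct


def to_float32(x: float) -> float:
    """Round a Python float (IEEE-754 float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def to_float64(x: float) -> float:
    """Python floats are already float64, so no extra rounding is needed."""
    return x


def naive_sigmoid(z: float, rnd) -> float:
    """1 / (1 + exp(-z)), rounding every intermediate result with `rnd`."""
    e = rnd(math.exp(-z))
    denom = rnd(1.0 + e)
    return rnd(1.0 / denom)


def first_saturating_z(rnd, max_z: int = 100) -> int:
    """Smallest integer z for which the naive sigmoid rounds to exactly 1.0."""
    for z in range(max_z):
        if naive_sigmoid(float(z), rnd) == 1.0:
            return z
    raise ValueError("no saturation found below max_z")


if __name__ == "__main__":
    print(first_saturating_z(to_float32), first_saturating_z(to_float64))
    print(struct.calcsize("f"), struct.calcsize("d"))  # bytes per value
```

In this emulation the threshold is `z = 17` for float32 and `z = 37` for float64. That is consistent with both types giving `l2 = inf` at `z = 100` and `z = 1000` in the output above, while `l1` never forms `p` and stays finite. `struct.calcsize` also shows the 4-byte vs 8-byte storage behind the 50% memory saving.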

For each `z`, compare the six values of `dlx_dz` (x = 0, 1, 2, each in float64 and float32); this gradient is what drives the weight update. `dl0_dz`, which applies the gradient formula `p - y` directly, is included for comparison.
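
To see why `p - y` stays well behaved while `dl2_dz` can become `nan`, here is a hypothetical pure-Python sketch that differentiates `l2` by hand with the chain rule through `p`. It maps division by zero to IEEE `inf`/`nan` explicitly, since plain Python raises `ZeroDivisionError`. TensorFlow's gradient kernels differ in detail, which is presumably why the output shows `inf` in one case and `nan` in others.

```python
import math


def ieee_div(a: float, b: float) -> float:
    """a / b with IEEE-754 results (inf, -inf or nan) for b == 0.0 instead of raising."""
    if b != 0.0:
        return a / b
    if a == 0.0:
        return math.nan
    return math.copysign(math.inf, a) * math.copysign(1.0, b)


def sigmoid(z: float) -> float:
    """Naive sigmoid, as used for p in l2."""
    return 1.0 / (1.0 + math.exp(-z))


def chain_rule_grad(z: float, y: float) -> float:
    """dl2/dz computed as (dl2/dp) * (dp/dz), going through p like l2 does."""
    p = sigmoid(z)
    dl_dp = ieee_div(-y, p) + ieee_div(1.0 - y, 1.0 - p)  # d/dp of -y*log(p) - (1-y)*log(1-p)
    dp_dz = p * (1.0 - p)  # derivative of the sigmoid
    return dl_dp * dp_dz


def direct_grad(z: float, y: float) -> float:
    """dl/dz = p - y, the simplified formula used for dl0_dz."""
    return sigmoid(z) - y


if __name__ == "__main__":
    for z in [1.0, 10.0, 40.0]:
        print(z, chain_rule_grad(z, 0.0), direct_grad(z, 0.0))
```

For moderate `z` both agree. At `z = 40` in float64, `p` rounds to exactly `1.0`, so the chain rule multiplies `inf` by `0.0` and returns `nan`, while `p - y` is simply `1.0`.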

Let us know what you observe in the results, or modify the code to make the differences clearer!

Cheers.