Plotting handwritten digits (C2_W2)

Hey hey,

I am playing around with the code to build a softmax neural network to recognise handwritten digits (as in the week 2 assignment).

This is the code to visualize 64 random handwritten digits from the training dataset.

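```python
import numpy as np
import matplotlib.pyplot as plt

# X (one flattened 20x20 image per row) and Y (labels) are assumed
# to be loaded already, as in the week 2 assignment
m, n = X.shape

fig, ax = plt.subplots(8, 8, figsize=(5, 5))
fig.tight_layout(pad=0.15, rect=[0, 0.03, 1, 0.91])

for i, axes in enumerate(ax.flat):
    # Pick a random training example
    random_index = np.random.randint(m)

    # Reshape the 400 pixel values into a 20x20 image; .T transposes it
    # so the digit is shown in the right orientation
    X_random_reshaped = X[random_index].reshape((20, 20)).T

    # Show the image and use its label as the subplot title
    axes.imshow(X_random_reshaped, cmap="gray")
    axes.set_title(Y[random_index, 0])
    axes.set_axis_off()

fig.suptitle("Label, image", fontsize=14)
```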

I don't understand what enumerate(ax.flat) does. What exactly are the variables i and axes iterating over?

Thanks! 🙂

Cheers
Nadi

Hello, Nadi @nadidixit,

enumerate pairs each element of ax.flat with a counter that starts at 0 by default. In other words, if you print i inside the loop, you will see 0, 1, 2, …
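For illustration, here is a minimal sketch of the same behaviour on a plain Python list (the list `letters` is just a hypothetical stand-in for ax.flat): enumerate hands back (counter, element) pairs, and `for i, letter in ...` unpacks each pair.

```python
# enumerate pairs each item with a running counter (starting at 0 by default)
letters = ["a", "b", "c"]

pairs = list(enumerate(letters))
print(pairs)  # [(0, 'a'), (1, 'b'), (2, 'c')]

# The same thing written as a loop that unpacks each (counter, item) pair
for i, letter in enumerate(letters):
    print(i, letter)

# A different starting value can be passed via the `start` argument
print(list(enumerate(letters, start=1)))  # [(1, 'a'), (2, 'b'), (3, 'c')]
```

In the plotting loop, i plays the role of the counter and axes the role of letter.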

ax is an 8×8 NumPy array of objects (Axes) used for plotting, and ax.flat steps through that grid one element at a time, row by row. So what the loop does is get one such object from the array at a time, together with its own unique number assigned by enumerate. Unfortunately this code does not actually use i, so it is not the best example to discuss how i may be of help.
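A small NumPy sketch of what ax.flat does, assuming an 8×8 array of integers as a stand-in for the real grid of Axes objects so the example runs without plotting anything. It also shows one way i could be useful: divmod turns the flat counter back into a (row, column) position in the grid.

```python
import numpy as np

# Stand-in for the 8x8 grid returned by plt.subplots(8, 8):
# the integers 0..63 take the place of the Axes objects
grid = np.arange(64).reshape(8, 8)

# .flat iterates over every element in row-major order
flat_items = list(grid.flat)
print(len(flat_items))  # 64

for i, item in enumerate(grid.flat):
    # Convert the flat counter i back into a (row, column) position
    row, col = divmod(i, grid.shape[1])
    assert grid[row, col] == item
    if i in (0, 1, 7, 8):
        print(i, (row, col))
```

Because plt.subplots(8, 8) returns its Axes in a NumPy array, ax.flat on the real grid behaves the same way.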

Nadi, when I learn, I often find it helpful to print the variables concerned, because it may give me an immediate answer or some leads to look into further. As a learner, I would recommend the same to you. 😉
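Following that advice, a minimal sketch that prints i and axes on a smaller 2×3 grid (an assumption to keep the printout short). The non-interactive "Agg" backend is also an assumption so the snippet runs without opening a window.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt

# A small 2x3 grid instead of 8x8 keeps the printout short
fig, ax = plt.subplots(2, 3)
print(type(ax), ax.shape)  # a NumPy array of Axes with shape (2, 3)

seen = []
for i, axes in enumerate(ax.flat):
    print(i, axes)  # the counter and the Axes object it is paired with
    seen.append((i, axes))

plt.close(fig)
```

Each printed line shows a number from 0 to 5 next to one Axes object, in row-by-row order.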

Cheers,
Raymond


Thanks Raymond!