W2A2: What does train_dataset.take(1) mean in the MobileNet assignment?

I am doing the MobileNet assignment, and there is this code that fetches sample images from the training set.

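import matplotlib.pyplot as plt

# train_dataset and class_names are defined in earlier cells of the notebook
for images, labels in train_dataset.take(1):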
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

In the first line of the code, what does train_dataset.take(1) mean? From the help, I could see that it creates a new dataset with at most count elements. Here count is 1, which means a new dataset with 1 image is created. Is my understanding right? If so, how are we plotting 9 images when there is only 1 image in the dataset? I am confused about how this works.


Hi Yaswanth64,

You are correct that one image is assigned to the variable image and then to first_image. This first_image is then augmented 9 times, so you have a single image that is augmented and plotted, 9 times over.


I am not sure I understood your statement about augmentation correctly. As I understand it, augmentation creates more data from existing data (cropping, flipping, etc.). But here, with train_dataset.take(1), we have taken a single image, yet the plot shows different images rather than augmented versions of one image from the images variable.

Below is a screenshot of the output. You can see that these are different images, not the same image augmented.

[Screenshot: output of the plotting cell above]


Hi Yaswanth64,

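Sorry, I thought you were referring to a different cell, so my answer was off.
train_dataset.take(1) takes one batch of images and labels from train_dataset: the dataset is batched, so each of its elements is a whole batch rather than a single image. The inner loop then plots images[0] through images[8], the first 9 images of that batch. If you want to see all the images in the batch (batch size 32), you can use: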

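To check that take(1) returns one batch rather than one image, here is a minimal self-contained sketch. It assumes TensorFlow 2.x and replaces the assignment's train_dataset with a hypothetical in-memory dataset of 100 blank 160 x 160 RGB images, batched by 32 to match the batch size mentioned above; only the shapes matter here, not the pixel values.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the assignment's train_dataset: 100 blank
# 160x160 RGB images with alternating 0/1 labels, held in memory.
fake_images = np.zeros((100, 160, 160, 3), dtype=np.uint8)
fake_labels = np.arange(100) % 2

# Group the images into batches of 32, so each dataset element is a batch.
train_dataset = tf.data.Dataset.from_tensor_slices(
    (fake_images, fake_labels)).batch(32)

# take(1) keeps only the first element of the dataset, i.e. one whole batch.
first = train_dataset.take(1)
print(int(first.cardinality()))  # 1

for images, labels in first:
    print(images.shape)     # (32, 160, 160, 3)
    print(labels.shape)     # (32,)
    # Indexing inside the batch gives a single image; this is what
    # images[i] does in the plotting loop for i = 0..8.
    print(images[0].shape)  # (160, 160, 3)

# 100 images in batches of 32 -> 4 batches (32 + 32 + 32 + 4).
print(len(list(train_dataset)))  # 4
```

Without the .batch(32) call, the same take(1) would yield a single (image, label) pair, which is the reading in the original question; the difference comes entirely from the dataset being batched.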