I am doing the MobileNet assignment, and there is this code that fetches sample images from the training set:
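```python
import matplotlib.pyplot as plt

# train_dataset and class_names are defined earlier in the notebook
for images, labels in train_dataset.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
```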
In the first line of the code, what does `train_dataset.take(1)` mean? From the help text I could see that it creates a new dataset with at most `count` elements. Here the count is 1, which means a new dataset with 1 image is created. Is my understanding right? If yes, how are we plotting 9 images when we have only 1 image in the `images` dataset? I am confused about how this works.
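A small self-contained check of what `take(1)` returns when the dataset is batched. It assumes TensorFlow 2.x and uses a hypothetical 100-image random dataset, batched by 32, in place of the real `train_dataset`. Because each element of a batched dataset is a whole batch, `take(1)` yields one batch, and the plotting loop above then picks the first 9 of its images with `images[i]`:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for train_dataset: 100 random 8x8 RGB "images"
# with binary labels, batched by 32.
rng = np.random.default_rng(0)
all_images = rng.integers(0, 256, size=(100, 8, 8, 3)).astype("float32")
all_labels = rng.integers(0, 2, size=(100,))
dataset = tf.data.Dataset.from_tensor_slices((all_images, all_labels)).batch(32)

# Every element of a batched dataset is a whole batch, so take(1)
# yields a single (images, labels) pair holding 32 examples.
for batch_images, batch_labels in dataset.take(1):
    print(tuple(batch_images.shape))  # (32, 8, 8, 3)
    print(tuple(batch_labels.shape))  # (32,)

# unbatch() splits batches back into single examples, so take(1)
# now yields exactly one image.
for single_image, single_label in dataset.unbatch().take(1):
    print(tuple(single_image.shape))  # (8, 8, 3)

# 100 examples in batches of 32 -> 4 batches (32 + 32 + 32 + 4).
print(len(list(dataset)))  # 4
```

So `take(1)` does create a dataset with one element, but that element is a batch of images, not a single image.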
Hi Yaswanth64,
You are correct that 1 image is assigned to the variable `image` and then to `first_image`. This `first_image` is then augmented 9 times, so you have a single image that is augmented and plotted 9 times.
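To contrast this with the batch case, here is a NumPy-only sketch of the idea described above: one image (such as `first_image`) turned into nine augmented copies. The helper name is hypothetical, and random horizontal and vertical flips are used as a stand-in for the notebook's augmentation step, so every panel comes from the same source image:

```python
import numpy as np


def augment_nine_times(image, seed=0):
    """Return a (9, H, W, C) stack of randomly flipped copies of ONE image.

    The flips are a hypothetical stand-in for the notebook's augmentation;
    the point is that all nine results derive from the same input image.
    """
    rng = np.random.default_rng(seed)
    variants = []
    for _ in range(9):
        augmented = image
        if rng.random() < 0.5:
            augmented = augmented[:, ::-1, :]  # horizontal flip
        if rng.random() < 0.5:
            augmented = augmented[::-1, :, :]  # vertical flip
        variants.append(augmented)
    return np.stack(variants)
```

In a 3x3 plot of this output, all nine panels show the same picture in different orientations, unlike the `take(1)` loop, which shows 9 different images from one batch.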
I am not sure I understood your statement about augmentation correctly. As I understand it, augmentation means generating more data from existing data (cropping, flipping, etc.), but here, with `train_dataset.take(1)`, we have taken a single image, yet the plot shows different images rather than augmented versions of the one in the `images` variable.
Below is the screenshot of the output. You can see that there are different images and not the same image augmented.
Hi Yaswanth64,
Sorry, I thought you were referring to a different cell, which is why my answer was off.
`train_dataset.take(1)` takes one batch of images and labels from `train_dataset`. If you want to see all the images in the batch (with batch size 32), you can use:
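One way to do it, as a sketch: the helper below is hypothetical and plots every image of a batch on a grid with 8 columns, so a 32-image batch fills 4 rows. It assumes matplotlib and NumPy and takes plain arrays, so it does not depend on TensorFlow itself:

```python
import math

import matplotlib.pyplot as plt
import numpy as np


def plot_batch(images, labels, class_names, ncols=8):
    """Plot every image of one batch on a grid with `ncols` columns.

    `images` is an (N, H, W, 3) array with pixel values in 0..255 and
    `labels` an (N,) array of integer class indices (hypothetical helper).
    """
    images = np.asarray(images)
    labels = np.asarray(labels)
    n = len(images)
    nrows = math.ceil(n / ncols)  # e.g. 32 images / 8 columns -> 4 rows
    fig = plt.figure(figsize=(2 * ncols, 2 * nrows))
    for i in range(n):
        ax = fig.add_subplot(nrows, ncols, i + 1)
        ax.imshow(images[i].astype("uint8"))
        ax.set_title(class_names[int(labels[i])])
        ax.axis("off")
    return fig
```

In the notebook, it would be called as `plot_batch(images.numpy(), labels.numpy(), class_names)` inside the `for images, labels in train_dataset.take(1):` loop. Because the grid size is computed from the batch length, a smaller final batch also works.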