Training the conditional GAN -- Shapes of tensors explanation

Hi, I am not getting this part of the code. Can someone explain this?

for epoch in range(n_epochs):
    # Dataloader returns the batches and the labels
    for real, labels in tqdm(dataloader):
        cur_batch_size = len(real)
        # Move the batch of real images to the device
        real = real.to(device)
        # print("real", type(real), len(real), real.shape, real[1].shape)

        # One-hot encode the labels, then expand them into image-sized planes
        one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
        # print(one_hot_labels, type(one_hot_labels), one_hot_labels.shape)
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])
        print(image_one_hot_labels, type(image_one_hot_labels), image_one_hot_labels.shape)


Why is image_one_hot_labels = one_hot_labels[:, :, None, None]?

Why is it None, and can someone explain what is happening in this part?
And also here: image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])


Hey @A_MR,
Welcome to the community. The answer to it is pretty straightforward. Consider the below line of code:

one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)

I think this line is self-explanatory: we are getting the one-hot labels for all the images in our batch, where each batch contains cur_batch_size images. So, the shape of one_hot_labels at this point is (cur_batch_size, n_classes).

Now, consider the next line of code:

image_one_hot_labels = one_hot_labels[:, :, None, None]

With this line of code, we are adding two trailing dimensions to our one_hot_labels tensor. Indexing with None inserts a new axis of size 1 at that position, so the shape goes from (cur_batch_size, n_classes) to (cur_batch_size, n_classes, 1, 1). The new dimensions exist, but they don't yet span the height and width of an image.
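A quick sketch of the shape change, using hypothetical sizes (a batch of 4 labels with 10 classes, not necessarily the assignment's exact values):

```python
import torch

# Hypothetical toy batch: 4 labels, 10 classes
one_hot_labels = torch.nn.functional.one_hot(
    torch.tensor([3, 1, 4, 1]), num_classes=10
).float()
print(one_hot_labels.shape)        # torch.Size([4, 10])

# Indexing with None inserts a new axis of size 1 at that position
# (equivalent to calling unsqueeze twice)
image_one_hot_labels = one_hot_labels[:, :, None, None]
print(image_one_hot_labels.shape)  # torch.Size([4, 10, 1, 1])
```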

image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])

And finally, using the above line of code, we expand the tensor along these new dimensions: 28 along the 3rd dimension and 28 along the 4th, i.e., the height and width of an MNIST image. The repeat function fills all 28 * 28 = 784 cells of each plane with the same value.
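Continuing the same hypothetical sizes (batch of 4, 10 classes, 28x28 images), the repeat step looks like this:

```python
import torch

# Hypothetical sizes: batch of 4, 10 classes, 28x28 images
one_hot_labels = torch.eye(10)[torch.tensor([3, 1, 4, 1])]   # shape (4, 10)
image_one_hot_labels = one_hot_labels[:, :, None, None]      # shape (4, 10, 1, 1)

# repeat(1, 1, 28, 28) tiles each size-1 trailing axis 28 times,
# turning every 0/1 entry into a constant 28x28 plane
image_one_hot_labels = image_one_hot_labels.repeat(1, 1, 28, 28)
print(image_one_hot_labels.shape)  # torch.Size([4, 10, 28, 28])

# The plane for image 0, class 3 is all ones (its label is 3);
# every other class plane for image 0 is all zeros
print(image_one_hot_labels[0, 3].unique())  # tensor([1.])
```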

I hope this helps; feel free to ask if you have any queries :innocent:


Why do we add the 1, 1 in the first 2 dimensions?


Hey @Ahmed_Emam1,
Welcome to the community. I am assuming we are clear on the fact that image_one_hot_labels is concatenated with the real and fake images before they are fed to the discriminator.

Now, in the previous line of code, i.e.

image_one_hot_labels = one_hot_labels[:, :, None, None]

We have already given image_one_hot_labels 4 dimensions: the first represents the images in a batch, the second represents the channels (here, one channel per class), and the third and fourth represent the height and width respectively. The label of an image is the same at every spatial position, so we only need to repeat the label values across the height and the width. There is nothing to duplicate along the batch or channel dimensions, and that's why we pass 1's as the first 2 arguments of repeat.
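A minimal sketch of this, again with hypothetical shapes (a batch of 4 one-channel 28x28 images and 10 classes):

```python
import torch

# Hypothetical shapes: batch of 4 one-channel 28x28 images, 10 classes
real = torch.randn(4, 1, 28, 28)
one_hot_labels = torch.eye(10)[torch.tensor([3, 1, 4, 1])]
image_one_hot_labels = one_hot_labels[:, :, None, None].repeat(1, 1, 28, 28)

# The 1's tile the batch and class dimensions exactly once, i.e. leave them
# untouched; only height and width are expanded to match the images, so the
# label planes can be concatenated channel-wise for the discriminator
disc_input = torch.cat([real, image_one_hot_labels], dim=1)
print(disc_input.shape)  # torch.Size([4, 11, 28, 28])
```

With 1 image channel and 10 class channels, the discriminator here would see 11 input channels per image.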

I hope this helps, and if it doesn't, I urge you to watch the lecture videos again. Prof. Andrew has clearly explained how the labels are formatted in Conditional GANs when feeding them to the generator and to the discriminator.

Regards,
Elemento
