Hi, I don't understand this part of the code. Can someone explain it?
for epoch in range(n_epochs):
    # Dataloader returns the batches of real images and their labels
    for real, labels in tqdm(dataloader):
        # len() of a tensor is the size of its first dimension, i.e. the
        # number of samples in this batch. Presumably read per batch because
        # the final batch may be smaller than the nominal batch size.
        cur_batch_size = len(real)
        # Move the batch of real images to the target device. Nothing is
        # flattened here: `.to(device)` only moves/copies the tensor and
        # leaves its shape unchanged.
        real = real.to(device)
        # Encode the integer class labels as one-hot vectors.
        # Presumably shape (batch, n_classes) -- TODO confirm against
        # get_one_hot_labels, whose definition is not shown here.
        one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
        # Why `None` in the index: indexing a tensor with `None` inserts a new
        # axis of size 1 at that position (same as numpy.newaxis, or
        # torch.unsqueeze). The two `:` keep the existing dimensions and the
        # two `None`s append two singleton dimensions:
        #   (batch, n_classes) -> (batch, n_classes, 1, 1)
        # For a 2-D input this is equivalent to
        #   one_hot_labels.unsqueeze(-1).unsqueeze(-1)
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        # Tensor.repeat(a, b, c, d) tiles dimension i that many times.
        # Factors of 1 keep batch and n_classes unchanged. The two singleton
        # dimensions are tiled out to the image height and width:
        #   (batch, n_classes, 1, 1) -> (batch, n_classes, H, W)
        # Each sample's label therefore becomes n_classes constant "image
        # channels". Assuming get_one_hot_labels is a true one-hot encoding,
        # the channel of the sample's class is all 1s and the others are
        # all 0s.
        # Assumes mnist_shape is (channels, H, W), e.g. (1, 28, 28) -- confirm.
        # NOTE(review): the consumer is not visible here, but this
        # (batch, n_classes, H, W) layout is what lets the labels be
        # concatenated with image tensors along dim 1 (e.g. torch.cat).
        image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])
        # Debug output: prints the entire tensor on every batch, which is slow
        # and very noisy -- remove once the shapes are understood.
        print(image_one_hot_labels,type(image_one_hot_labels),image_one_hot_labels.shape)
Why is it `image_one_hot_labels = one_hot_labels[:, :, None, None]`?
What do the `None` entries in the index do, and can someone explain what is happening in this part?
And also here: `image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])`