Quite Confusing

  1. image_one_hot_labels = one_hot_labels[:, :, None, None]
    Why do we need line 1?

  2. image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])

And again, what is the purpose of line 2?

Hoping for your kind support.

Line 1 adds two extra dimensions so that image_one_hot_labels has the shape the next line needs.
The ':' in [:, :, None, None] means "keep this existing dimension as it was", and each None adds a new dimension of size 1 at that position. So a tensor of shape (number of labels, number of classes) becomes (number of labels, number of classes, 1, 1).
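A minimal sketch of what the indexing does to the shape, assuming the notebook's get_one_hot_labels behaves like torch.nn.functional.one_hot (F.one_hot is used here as a stand-in). It also shows two equivalent ways to add the same size-1 dimensions:

```python
import torch
import torch.nn.functional as F

# Assumption: F.one_hot stands in for the notebook's get_one_hot_labels helper.
labels = torch.tensor([1, 2, 7])
one_hot_labels = F.one_hot(labels, num_classes=10)  # shape (3, 10)

# ':' keeps an existing dimension; each None inserts a new dimension of size 1.
image_one_hot_labels = one_hot_labels[:, :, None, None]
print(image_one_hot_labels.shape)  # torch.Size([3, 10, 1, 1])

# Two equivalent ways to add the same two trailing dimensions.
via_unsqueeze = one_hot_labels.unsqueeze(-1).unsqueeze(-1)
via_view = one_hot_labels.view(3, 10, 1, 1)
print(torch.equal(image_one_hot_labels, via_unsqueeze))  # True
print(torch.equal(image_one_hot_labels, via_view))  # True
```

The values are untouched; only the shape changes, which is what lets repeat work on the new dimensions afterwards.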

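Line 2 takes the result of line 1 and repeats it along the two new dimensions, so there is one copy of each label's one-hot value for every pixel of an MNIST image. The arguments to repeat give the number of copies per dimension: 1 for the batch and class dimensions (leave them as they are), and mnist_shape[1] and mnist_shape[2] (the height and width of an MNIST image) for the last two. You can see examples and a more detailed description in the torch.Tensor.repeat documentation.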
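A sketch of line 2 at full MNIST size, assuming mnist_shape is (channels, height, width) = (1, 28, 28) and again using F.one_hot as a stand-in for get_one_hot_labels. After the repeat, each class channel is a 28x28 "image" that is all ones for the sample's label and all zeros otherwise:

```python
import torch
import torch.nn.functional as F

# Assumptions: mnist_shape is (channels, height, width) = (1, 28, 28),
# and F.one_hot stands in for the notebook's get_one_hot_labels.
mnist_shape = (1, 28, 28)
labels = torch.tensor([1, 2, 7])
one_hot_labels = F.one_hot(labels, num_classes=10)
image_one_hot_labels = one_hot_labels[:, :, None, None]  # (3, 10, 1, 1)

# repeat(1, 1, H, W): copy the batch and class dims once (unchanged),
# and copy each size-1 value H times down and W times across.
image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])
print(image_one_hot_labels.shape)  # torch.Size([3, 10, 28, 28])

# Every pixel of channel c holds the same value: 1 if c is the label, else 0.
first = image_one_hot_labels[0]  # the sample whose label is 1
print(first[1].min().item(), first[1].max().item())  # 1 1
print(first[0].max().item())  # 0
```

As a design note, image_one_hot_labels.expand(-1, -1, 28, 28) would give the same values as a view without copying memory; repeat makes a real copy, which is simpler to reason about when the tensor is later concatenated or modified.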

To see how this works for yourself, try a few test lines that print the shapes and values at each step. For example, you could add a cell like this:

import torch
# get_one_hot_labels is the helper defined earlier in the notebook
labels = torch.tensor([1, 2, 7])
print(f"labels: {labels.shape}, {labels}")
one_hot_labels = get_one_hot_labels(labels, 10)  # shape (3, 10)
print(f"one_hot_labels: {one_hot_labels.shape}, {one_hot_labels}")
image_one_hot_labels = one_hot_labels[:, :, None, None]  # shape (3, 10, 1, 1)
print(f"image_one_hot_labels: {image_one_hot_labels.shape}, {image_one_hot_labels[0]} ...")
# repeat 2x2 instead of 28x28 so the printed output stays readable
image_one_hot_labels = image_one_hot_labels.repeat(1, 1, 2, 2)  # shape (3, 10, 2, 2)
print(f"image_one_hot_labels after repeat: {image_one_hot_labels.shape}, {image_one_hot_labels[0]} ...")
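Why the labels need to become full-size "images" is easiest to see when they are combined with a batch of images. Assuming, as is common for conditional GANs, that the label images are stacked onto the image channels with torch.cat (that step is not shown in this thread), the shapes line up like this. The random tensor below is a hypothetical stand-in for real MNIST images:

```python
import torch
import torch.nn.functional as F

# Hypothetical batch: 3 random tensors standing in for MNIST images (N, 1, 28, 28).
fake_images = torch.randn(3, 1, 28, 28)
labels = torch.tensor([1, 2, 7])

# Build the label "images" exactly as lines 1 and 2 do.
one_hot_labels = F.one_hot(labels, num_classes=10)
image_one_hot_labels = one_hot_labels[:, :, None, None].repeat(1, 1, 28, 28)

# Assumption: labels are stacked onto the image channels (dim=1).
# torch.cat requires every other dimension to match, which is why the
# labels had to be expanded to (N, 10, 28, 28) first.
combined = torch.cat((fake_images, image_one_hot_labels.float()), dim=1)
print(combined.shape)  # torch.Size([3, 11, 28, 28])
```

The .float() call converts the integer one-hot tensor to the images' dtype before concatenating.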