Ok, the fundamental problem is that torch.where operates elementwise, not row-wise, which is what we really need here. Because the second dimension is 1 in the first test case, it at least doesn't throw an error. Instead, broadcasting aligns shapes from the trailing dimension, so the (3,) mask is treated as (1, 3) against the (3, 1) inputs, and you silently get a square (3, 3) result.
Here’s a test cell which demonstrates how this works with a small example:
# Play cell to understand torch.where
import torch
real = torch.ones(3, 1) * 0.5
fake = torch.ones(3, 1) * -0.75
fake_mask = torch.rand(len(real)) > 0.5  # 1D Boolean mask, shape (3,)
print(f"fake_mask.shape {fake_mask.shape}")
print(f"fake_mask = {fake_mask}")
target_images = torch.where(fake_mask, fake, real)
print(f"target_images.shape {target_images.shape}")
print(f"target_images {target_images}")
# Now try it with a 2D mask of the correct shape
fake_mask = torch.reshape(fake_mask, (3, 1))
print(f"fake_mask.shape {fake_mask.shape}")
print(f"fake_mask = {fake_mask}")
target_images = torch.where(fake_mask, fake, real)
print(f"target_images.shape {target_images.shape}")
print(f"target_images {target_images}")
# Now try a non-singleton second dimension
real = torch.ones(3, 5) * 0.5
fake = torch.ones(3, 5) * -0.75
# Go back to the 1D mask of shape (3,)
fake_mask = torch.reshape(fake_mask, (3,))
# This now throws an error, since (3,) cannot broadcast against (3, 5)
# target_images = torch.where(fake_mask, fake, real)
# print(f"target_images.shape {target_images.shape}")
# print(f"target_images {target_images}")
# Reshape to a 2D column mask and broadcasting works again
fake_mask = torch.reshape(fake_mask, (3, 1))
print(f"fake_mask.shape {fake_mask.shape}")
print(f"fake_mask = {fake_mask}")
target_images = torch.where(fake_mask, fake, real)
print(f"target_images.shape {target_images.shape}")
print(f"target_images {target_images}")
Running the above gives this result:
fake_mask.shape torch.Size([3])
fake_mask = tensor([False, True, False])
target_images.shape torch.Size([3, 3])
target_images tensor([[ 0.5000, -0.7500, 0.5000],
[ 0.5000, -0.7500, 0.5000],
[ 0.5000, -0.7500, 0.5000]])
fake_mask.shape torch.Size([3, 1])
fake_mask = tensor([[False],
[ True],
[False]])
target_images.shape torch.Size([3, 1])
target_images tensor([[ 0.5000],
[-0.7500],
[ 0.5000]])
fake_mask.shape torch.Size([3, 1])
fake_mask = tensor([[False],
[ True],
[False]])
target_images.shape torch.Size([3, 5])
target_images tensor([[ 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
[-0.7500, -0.7500, -0.7500, -0.7500, -0.7500],
[ 0.5000, 0.5000, 0.5000, 0.5000, 0.5000]])
So what that last section shows is that you can get this to work with torch.where, provided that you first reshape the computed mask to a 2D column tensor. Broadcasting then gives the correct result, even when the input does not have a trivial second dimension. With the default 1D Boolean mask, you either get the wrong square result (when the second dimension is 1) or a shape error (when it isn't).
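The same trick applies if real and fake are batches of images rather than column vectors: give the mask a singleton dimension for every non-batch dimension so that it broadcasts per batch element. Here's a minimal sketch of that pattern; the (N, C, H, W) batch shape and the tensor names here are assumptions for illustration, not from the original code:
import torch
# Hypothetical batch of N=4 RGB images, 8x8 pixels each
real_images = torch.ones(4, 3, 8, 8) * 0.5
fake_images = torch.ones(4, 3, 8, 8) * -0.75
# 1D Boolean mask with one entry per batch element, shape (4,)
fake_mask = torch.rand(len(real_images)) > 0.5
# Add a singleton dim for each non-batch dimension: (4,) -> (4, 1, 1, 1)
fake_mask = fake_mask.view(-1, 1, 1, 1)
# Broadcasting now selects whole images per batch element
target_images = torch.where(fake_mask, fake_images, real_images)
print(target_images.shape)  # torch.Size([4, 3, 8, 8])
Equivalently, fake_mask[:, None, None, None] or fake_mask.reshape(-1, 1, 1, 1) produce the same column-style mask.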