Here’s my function to create a “telltale” 4D tensor:
# routine to generate a telltale 4D tensor to play with
import torch

def testarray(shape):
    (d1, d2, d3, d4) = shape
    A = torch.zeros(*shape, dtype=torch.int32)
    # each entry encodes its own indices as decimal digits
    for ii1 in range(d1):
        for ii2 in range(d2):
            for ii3 in range(d3):
                for ii4 in range(d4):
                    A[ii1, ii2, ii3, ii4] = ii1 * 1000 + ii2 * 100 + ii3 * 10 + ii4
    return A
So the value at each position of the tensor spells out that position's indices, one decimal digit per dimension, in order. That is to say A[1,2,3,4] = 1234. Of course this is only readable if every dimension has size at most 10, so that each index fits in a single digit.
Now let’s see what happens when we use the two “flatten” methods that we see above.
In the first case, let’s create a batch of shape (3, 2, 2, 3): 3 samples, each of which is a 2 x 2 x 3 tensor. You can think of each sample as a 2 x 2 image with 3 RGB values per pixel, except that the values here encode positions rather than normal 0 - 255 intensities.
sample3 = testarray([3, 2, 2, 3])
print(f"sample3 =\n{sample3}")
sample3 =
tensor([[[[   0,    1,    2],
          [  10,   11,   12]],

         [[ 100,  101,  102],
          [ 110,  111,  112]]],


        [[[1000, 1001, 1002],
          [1010, 1011, 1012]],

         [[1100, 1101, 1102],
          [1110, 1111, 1112]]],


        [[[2000, 2001, 2002],
          [2010, 2011, 2012]],

         [[2100, 2101, 2102],
          [2110, 2111, 2112]]]], dtype=torch.int32)
Now apply the first method of flattening using the view() method on the tensor:
sampleView = sample3.view(-1, 2*2*3)
print(f"sampleView =\n{sampleView}")
print(f"sampleView.shape =\n{sampleView.shape}")
sampleView =
tensor([[   0,    1,    2,   10,   11,   12,  100,  101,  102,  110,  111,  112],
        [1000, 1001, 1002, 1010, 1011, 1012, 1100, 1101, 1102, 1110, 1111, 1112],
        [2000, 2001, 2002, 2010, 2011, 2012, 2100, 2101, 2102, 2110, 2111, 2112]],
       dtype=torch.int32)
sampleView.shape =
torch.Size([3, 12])
So the output is a 3 x 12 2D tensor. There are 3 rows, and the thousands digit of every entry in a row is the index of that row: 0 in the first row, 1 in the second row and 2 in the third row. Within each row, the remaining dimensions are unrolled starting from the last one: the last index varies fastest, then the third, then the second.
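The ordering within a row can be written as a formula: the element at [i1, i2, i3, i4] should land in row i1 at column i2*(d3*d4) + i3*d4 + i4. The sketch below checks that formula on a stand-in tensor filled with torch.arange values; the names x, flat and all_match are illustrative only and do not come from the code above.

```python
import torch

d1, d2, d3, d4 = 3, 2, 2, 3
# Stand-in tensor: any contiguous 4D tensor works for checking the ordering.
x = torch.arange(d1 * d2 * d3 * d4).reshape(d1, d2, d3, d4)
flat = x.view(-1, d2 * d3 * d4)

all_match = True
for i1 in range(d1):
    for i2 in range(d2):
        for i3 in range(d3):
            for i4 in range(d4):
                # The last index moves one column at a time, the third index
                # moves d4 columns, and the second index moves d3*d4 columns.
                col = i2 * (d3 * d4) + i3 * d4 + i4
                if flat[i1, col] != x[i1, i2, i3, i4]:
                    all_match = False

print(all_match)  # True
```

This is the usual row-major layout, which is why the last index changes fastest as you read along a row.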
Now let’s try the other method using the flatten() function. We recreate the input with 4 samples this time:
sample4 = testarray([4, 2, 2, 3])
print(f"sample4 =\n{sample4}")
sample4 =
tensor([[[[   0,    1,    2],
          [  10,   11,   12]],

         [[ 100,  101,  102],
          [ 110,  111,  112]]],


        [[[1000, 1001, 1002],
          [1010, 1011, 1012]],

         [[1100, 1101, 1102],
          [1110, 1111, 1112]]],


        [[[2000, 2001, 2002],
          [2010, 2011, 2012]],

         [[2100, 2101, 2102],
          [2110, 2111, 2112]]],


        [[[3000, 3001, 3002],
          [3010, 3011, 3012]],

         [[3100, 3101, 3102],
          [3110, 3111, 3112]]]], dtype=torch.int32)
Now we apply flatten(), keeping dimension 0 and flattening everything from dimension 1 onward:
sampleFlatten = torch.flatten(sample4, start_dim=1)
print(f"sampleFlatten =\n{sampleFlatten}")
print(f"sampleFlatten.shape =\n{sampleFlatten.shape}")
sampleFlatten =
tensor([[   0,    1,    2,   10,   11,   12,  100,  101,  102,  110,  111,  112],
        [1000, 1001, 1002, 1010, 1011, 1012, 1100, 1101, 1102, 1110, 1111, 1112],
        [2000, 2001, 2002, 2010, 2011, 2012, 2100, 2101, 2102, 2110, 2111, 2112],
        [3000, 3001, 3002, 3010, 3011, 3012, 3100, 3101, 3102, 3110, 3111, 3112]],
       dtype=torch.int32)
sampleFlatten.shape =
torch.Size([4, 12])
So flatten() orders the elements exactly the way view() does. We get 4 rows, each with a consistent first index, and the other dimensions are “unfurled” in the order 3 - 2 - 1, that is, with the last dimension varying fastest.
So either method works, and on a tensor like this one both give the same result.
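To compare the two methods directly, it may help to run both on the same tensor and test with torch.equal. The sketch below does that, and also shows one case that is not covered above: on a non-contiguous tensor (here made with permute(), my own choice for illustration) view() raises an error, while flatten() still works by making a copy. All variable names are illustrative.

```python
import torch

x = torch.arange(3 * 2 * 2 * 3, dtype=torch.int32).reshape(3, 2, 2, 3)

# Same tensor, both methods: the results should be element-for-element equal.
same = torch.equal(x.view(-1, 2 * 2 * 3), torch.flatten(x, start_dim=1))
print(same)  # True

# permute() only rearranges strides, so xp is no longer contiguous in memory.
xp = x.permute(0, 3, 1, 2)
try:
    xp.view(-1, 12)
    view_failed = False
except RuntimeError:
    # view() cannot express this shape without copying the data.
    view_failed = True
print(view_failed)  # True

# flatten() falls back to copying when a view is not possible.
flat_p = torch.flatten(xp, start_dim=1)
print(flat_p.shape)  # torch.Size([3, 12])
```

So for tensors straight out of a layer or a data loader the two are interchangeable, and flatten() is the more forgiving choice once permute() or transpose() has been applied.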