Conv2DTranspose (which I will subsequently call Conv2T) was used a number of times while explaining different image segmentation architectures. At first, I did not really understand how it worked and only glossed over it. But after going through the model summary, I was able to figure out how Conv2T works, and I will explain it using the FCN-8 decoder architecture and code.
In order to understand how Conv2T works, you should be familiar with the way a typical Conv2D works. Conv2T is trying to reverse the effect of Conv2D.
Leaving out the channel dimension of an input feature (whether it is an image or an intermediate activation), Conv2D reduces the width and height of the input feature according to the following formula.
```
output = ((input - kernel_size) / strides) + 1
```
where input is either the width or the height of the input, and output is the corresponding width or height after performing the convolution. Note that this assumes there is no padding, in order to simplify things.
With this formula in mind, suppose the input image is (7, 7), you use a kernel size of (3, 3), and a stride of (2, 2). Our final output will be:
output = ( (7 - 3) / 2) + 1
output = 3 (The output feature will be (3, 3))
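If you want to verify this yourself, here is a minimal check in Keras (using a dummy single-channel input; the variable names are just for illustration):

```python
import tensorflow as tf

# A (7, 7) input, a (3, 3) kernel, a (2, 2) stride and no padding ('valid')
# should give a (3, 3) output, as the formula predicts.
x = tf.random.normal((1, 7, 7, 1))   # (batch, height, width, channels)
conv = tf.keras.layers.Conv2D(filters=1, kernel_size=(3, 3),
                              strides=(2, 2), padding='valid')
print(conv(x).shape)                 # (1, 3, 3, 1)
```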
NOW, on to Conv2T. Conv2T simply reverses this formula.
To avoid confusion, let us make input the subject of the convolution formula above. Thus,
input = (output - 1) * strides + kernel_size
However, since we are reversing the process, the input of the actual Conv2D becomes the output of Conv2T, and the output of the Conv2D becomes the input of Conv2T.
Therefore, the formula for the reversal performed by Conv2T is properly given below:
output = (input - 1) * strides + kernel_size
Let us try this formula with the example we did in Conv2D.
Our input is (3, 3), kernel size of (3, 3), and a stride of (2, 2).
output = ((3 - 1) * 2) + 3
output = 7 (Which is the original size of the image we wanted)
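Again, a minimal check in Keras (with a dummy single-channel input) confirms the arithmetic:

```python
import tensorflow as tf

# Reversing the earlier example: a (3, 3) input, a (3, 3) kernel and a (2, 2)
# stride should give back a (7, 7) output, since (3 - 1) * 2 + 3 = 7.
x = tf.random.normal((1, 3, 3, 1))
conv_t = tf.keras.layers.Conv2DTranspose(filters=1, kernel_size=(3, 3),
                                         strides=(2, 2), padding='valid')
print(conv_t(x).shape)               # (1, 7, 7, 1)
```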
From what I have seen so far, it is important to note that Conv2T is not a perfect way to reverse the process: it recovers the original height and width, but not the original pixel values. Still, it works.
Let us demonstrate how this works with the FCN-8 architecture.
In the typical FCN-8 encoder that we used (VGG-16, with some additional layers), we saw that the image went from 224 to 112 (p1), to 56 (p2), to 28 (p3), to 14 (p4), and finally to 7 (p5). See part of the model summary below; note the pooling layers, whose outputs are p1 through p5:
```
Layer (type)                 Output Shape           Param #     Connected to
input_1 (InputLayer)         [(None, 224, 224, 3)]  0
block1_conv1 (Conv2D)        (None, 224, 224, 64)   1792        ['input_1[0][0]']
block1_conv2 (Conv2D)        (None, 224, 224, 64)   36928       ['block1_conv1[0][0]']
block1_pool2 (MaxPooling2D)  (None, 112, 112, 64)   0           ['block1_conv2[0][0]']
block2_conv1 (Conv2D)        (None, 112, 112, 128)  73856       ['block1_pool2[0][0]']
block2_conv2 (Conv2D)        (None, 112, 112, 128)  147584      ['block2_conv1[0][0]']
block2_pool2 (MaxPooling2D)  (None, 56, 56, 128)    0           ['block2_conv2[0][0]']
block3_conv1 (Conv2D)        (None, 56, 56, 256)    295168      ['block2_pool2[0][0]']
block3_conv2 (Conv2D)        (None, 56, 56, 256)    590080      ['block3_conv1[0][0]']
block3_conv3 (Conv2D)        (None, 56, 56, 256)    590080      ['block3_conv2[0][0]']
block3_pool3 (MaxPooling2D)  (None, 28, 28, 256)    0           ['block3_conv3[0][0]']
block4_conv1 (Conv2D)        (None, 28, 28, 512)    1180160     ['block3_pool3[0][0]']
block4_conv2 (Conv2D)        (None, 28, 28, 512)    2359808     ['block4_conv1[0][0]']
block4_conv3 (Conv2D)        (None, 28, 28, 512)    2359808     ['block4_conv2[0][0]']
block4_pool3 (MaxPooling2D)  (None, 14, 14, 512)    0           ['block4_conv3[0][0]']
block5_conv1 (Conv2D)        (None, 14, 14, 512)    2359808     ['block4_pool3[0][0]']
block5_conv2 (Conv2D)        (None, 14, 14, 512)    2359808     ['block5_conv1[0][0]']
block5_conv3 (Conv2D)        (None, 14, 14, 512)    2359808     ['block5_conv2[0][0]']
block5_pool3 (MaxPooling2D)  (None, 7, 7, 512)      0           ['block5_conv3[0][0]']
conv6 (Conv2D)               (None, 7, 7, 4096)     102764544   ['block5_pool3[0][0]']
conv7 (Conv2D)               (None, 7, 7, 4096)     16781312    ['conv6[0][0]']
```
Also note that it is the MaxPooling layers that do the job of reducing the width and height here (the Conv2D layers in this encoder preserve them); pooling slides a window over the input just like a convolution does, so the same output-size formula applies.
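As a quick sanity check of that claim (with a dummy tensor), the formula gives (224 - 2) / 2 + 1 = 112 for a (2, 2) pooling window with a (2, 2) stride:

```python
import tensorflow as tf

# MaxPooling follows the same sliding-window arithmetic as a convolution:
# a (224, 224) input with a (2, 2) window and a (2, 2) stride gives (112, 112).
x = tf.random.normal((1, 224, 224, 64))
pool = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))
print(pool(x).shape)                 # (1, 112, 112, 64)
```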
From the FCN-8 decoder, we know that we first have to upsample p5 by 2x. This means doubling the width and the height, so our desired output is a 14 x 14 result.
From the code:
tf.keras.layers.Conv2DTranspose(n_classes, kernel_size=(4,4), strides=(2,2), use_bias=False)
We see that we used a kernel_size of 4 and strides of 2.
From the Conv2T formula:
output = (input - 1) * strides + kernel_size
The output of this will be (7 - 1) * 2 + 4, which equals 16. But our desired output is 14. This is why a cropping layer, tf.keras.layers.Cropping2D(cropping=(1,1)), follows this layer: it crops one pixel from each edge (top, bottom, left, and right), removing two from the height and two from the width. And our final output becomes 14.
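To see these two layers working together, here is a minimal sketch. I feed a random stand-in for the 7 x 7 encoder output (4096 channels, matching conv7 in the decoder summary further below) through the transposed convolution and the cropping layer, with n_classes = 12 as in that summary:

```python
import tensorflow as tf

n_classes = 12                         # number of segmentation classes, as in the summary below
x = tf.random.normal((1, 7, 7, 4096))  # random stand-in for the 7 x 7 encoder output

# 2x upsampling: (7 - 1) * 2 + 4 = 16
up = tf.keras.layers.Conv2DTranspose(n_classes, kernel_size=(4, 4),
                                     strides=(2, 2), use_bias=False)(x)
print(up.shape)                        # (1, 16, 16, 12)

# crop one pixel from each edge: 16 - 2 = 14
cropped = tf.keras.layers.Cropping2D(cropping=(1, 1))(up)
print(cropped.shape)                   # (1, 14, 14, 12)
```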
This output is then combined with p4 (passed through a 1x1 convolution to match the number of channels), which has the same height and width as the upsampled p5.
This 2x upsampling with Conv2T is also performed on the combination of p4 and p5, again followed by a cropping layer (since (14 - 1) * 2 + 4 = 30), to get an output height and width of 28, which is then combined with p3.
This combined result is then upsampled 8x using the following settings: kernel_size=(8,8), strides=(8,8).
Plugging this into our Conv2T formula, we get (28 - 1) * 8 + 8, which equals 224.
This gives us the final result: an output upsampled to the same shape (height and width) as the input image.
Notice how this particular layer is not followed by a cropping layer, since its output already matches the target size.
See the summary of the decoder part below; note the Conv2DTranspose layers doing the major upsampling:
```
Layer (type)                          Output Shape          Param #   Connected to
conv2d_transpose (Conv2DTranspose)    (None, 16, 16, 12)    786432    ['conv7[0][0]']
cropping2d (Cropping2D)               (None, 14, 14, 12)    0         ['conv2d_transpose[0][0]']
conv2d (Conv2D)                       (None, 14, 14, 12)    6156      ['block4_pool3[0][0]']
add (Add)                             (None, 14, 14, 12)    0         ['cropping2d[0][0]', 'conv2d[0][0]']
conv2d_transpose_1 (Conv2DTranspose)  (None, 30, 30, 12)    2304      ['add[0][0]']
cropping2d_1 (Cropping2D)             (None, 28, 28, 12)    0         ['conv2d_transpose_1[0][0]']
conv2d_1 (Conv2D)                     (None, 28, 28, 12)    3084      ['block3_pool3[0][0]']
add_1 (Add)                           (None, 28, 28, 12)    0         ['cropping2d_1[0][0]', 'conv2d_1[0][0]']
conv2d_transpose_2 (Conv2DTranspose)  (None, 224, 224, 12)  9216      ['add_1[0][0]']
```
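For completeness, here is a minimal sketch of how such a decoder could be wired up in Keras, reconstructed from the summary above. The function and variable names are my own, and the 1x1 Conv2D projections on the pooling outputs are inferred from the parameter counts (6156 and 3084), so treat this as an illustration rather than the original code:

```python
import tensorflow as tf

n_classes = 12  # as in the summary above


def fcn8_decoder(conv7, pool4, pool3):
    # 2x upsample conv7: (7 - 1) * 2 + 4 = 16, then crop to 14
    x = tf.keras.layers.Conv2DTranspose(n_classes, kernel_size=(4, 4),
                                        strides=(2, 2), use_bias=False)(conv7)
    x = tf.keras.layers.Cropping2D(cropping=(1, 1))(x)

    # project pool4 to n_classes channels with a 1x1 convolution and add
    p4 = tf.keras.layers.Conv2D(n_classes, kernel_size=(1, 1))(pool4)
    x = tf.keras.layers.Add()([x, p4])

    # 2x upsample again: (14 - 1) * 2 + 4 = 30, then crop to 28
    x = tf.keras.layers.Conv2DTranspose(n_classes, kernel_size=(4, 4),
                                        strides=(2, 2), use_bias=False)(x)
    x = tf.keras.layers.Cropping2D(cropping=(1, 1))(x)

    # project pool3 and add
    p3 = tf.keras.layers.Conv2D(n_classes, kernel_size=(1, 1))(pool3)
    x = tf.keras.layers.Add()([x, p3])

    # final 8x upsample: (28 - 1) * 8 + 8 = 224, no cropping needed
    return tf.keras.layers.Conv2DTranspose(n_classes, kernel_size=(8, 8),
                                           strides=(8, 8), use_bias=False)(x)


# Feeding in inputs with the encoder's shapes reproduces the summary above.
conv7 = tf.keras.Input(shape=(7, 7, 4096))
pool4 = tf.keras.Input(shape=(14, 14, 512))
pool3 = tf.keras.Input(shape=(28, 28, 256))
print(fcn8_decoder(conv7, pool4, pool3).shape)   # (None, 224, 224, 12)
```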
Thank you. I am open to any contribution to this post.