UNet Assignment

Hi, I'm trying to understand the output of the UNet. The shape of the output is (None, 96, 128, 23), where the last value, 23, is the number of classes. How are the classes mapped to colors when the mask is displayed? I'm wondering how the color of the predicted mask is determined and rendered. Also, what is the purpose of the create_mask function that uses tf.argmax? I'm not sure what that function does. I notice its output is (96, 128, 1). Why is it not (96, 128, 3), since the mask is colored? Any ideas? Thank you.

Hi @Richard_Lai,

These are interesting questions. Looking at the code cells for the display and create_mask functions, I had the same questions. I think you are asking two separate things.

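  • The create_mask function uses tf.argmax along the last (class) axis to select, for each pixel, the class with the highest prediction value among the 23 possible classes. The function then returns a single image with the shape (96, 128, 1), where each pixel value is the predicted class number (0 to 22). A class number is one value per pixel, not a color, which is why there are no 3 color channels.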

  • The color displayed for each pixel of a predicted mask is determined in the plt.imshow call. According to the documentation for matplotlib.pyplot.imshow, single-channel (scalar) data is rendered as a pseudo-color image: each value is normalized to the range [0, 1] (by default, the lowest value in the image maps to 0 and the highest to 1) and then looked up in a colormap, which is viridis unless another cmap is given. A (96, 128, 1) mask is treated as this kind of scalar data. Had we passed a (96, 128, 3) image instead, imshow would have displayed the RGB colors specified by the 3 channels directly, without using a colormap.
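Here is a minimal sketch of both steps, assuming NumPy and matplotlib are installed. It uses np.argmax as a stand-in for tf.argmax (same behavior along axis=-1) so it runs without TensorFlow; the helper names create_mask_np and class_colors are hypothetical, and the "predictions" are random numbers for illustration only. Fixing vmin=0 and vmax=22 is a suggested tweak: with the default per-image scaling, the same class can get a different color in two images that contain different sets of classes.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

NUM_CLASSES = 23  # last dimension of the model output (None, 96, 128, 23)

def create_mask_np(pred):
    """Hypothetical NumPy stand-in for create_mask.

    pred: array of shape (batch, 96, 128, NUM_CLASSES) with per-class scores.
    Returns a (96, 128, 1) array of class indices for the first image.
    """
    mask = np.argmax(pred, axis=-1)  # (batch, 96, 128): best class per pixel
    mask = mask[..., np.newaxis]     # (batch, 96, 128, 1): add a channel axis
    return mask[0]                   # keep only the first image in the batch

def class_colors(mask, cmap_name="viridis"):
    """Hypothetical helper: the normalize + colormap lookup imshow does on scalar data.

    vmin/vmax are pinned to the full class range so class k always gets the same color.
    """
    norm = Normalize(vmin=0, vmax=NUM_CLASSES - 1)
    cmap = plt.get_cmap(cmap_name)
    return cmap(norm(mask[..., 0]))  # (96, 128, 4) RGBA array

rng = np.random.default_rng(0)
pred = rng.random((1, 96, 128, NUM_CLASSES))  # fake scores, illustration only

mask = create_mask_np(pred)
print(mask.shape)  # (96, 128, 1)

rgba = class_colors(mask)
print(rgba.shape)  # (96, 128, 4)

# The same mapping through imshow. The 2-D slice is passed here; matplotlib
# also treats a trailing size-1 channel axis as scalar data.
fig, ax = plt.subplots()
im = ax.imshow(mask[..., 0], cmap="viridis", vmin=0, vmax=NUM_CLASSES - 1)
```

Computing the RGBA array yourself is optional; it just makes visible what imshow does internally, and pinning vmin/vmax keeps the class-to-color mapping stable across images.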

Hope I answered your questions.