UNET - Understanding code in section 4 - Train the model and 4.1 - Create Predicted Masks


EPOCHS = 40
VAL_SUBSPLITS = 5
BUFFER_SIZE = 500
BATCH_SIZE = 32
processed_image_ds.batch(BATCH_SIZE)
train_dataset = processed_image_ds.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
print(processed_image_ds.element_spec)
model_history = unet.fit(train_dataset, epochs=EPOCHS)
def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0]

I would appreciate it if someone could explain all the code above. Although I passed the assignment, understanding the code is not obvious, even after googling quite a bit.

I am sure other students have the same doubts as I do (there have been previous questions from others that were left unanswered).

Thank you very much.

Hey @AntonioMaher,
Let’s try to understand the code piece-by-piece.

I think the first few lines should be pretty obvious: they just define some hyperparameters (number of epochs, shuffle buffer size, batch size) for training purposes.

processed_image_ds.batch(BATCH_SIZE)

I don’t think the above line of code has any effect: Dataset.batch() returns a new dataset rather than modifying the one it is called on, and since the return value isn’t assigned to anything here, the result is simply discarded. The next line performs the same operation anyway. I guess they only included this so that learners can find this method more easily and see that it can be called on its own as well. Other than that, no use case for this line comes to my mind.
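To convince yourself of this, here is a minimal sketch (not from the notebook, just a toy tf.data example) showing that batch() does not work in place:

```python
import tensorflow as tf

# A tiny dataset of 6 scalar elements
ds = tf.data.Dataset.range(6)

ds.batch(2)  # return value discarded, so this line does nothing to ds
print(len(list(ds)))  # still 6 individual elements

batched = ds.batch(2)  # the result must be assigned for batching to take effect
print(len(list(batched)))  # 3 batches of 2 elements each
```

This is exactly why the unassigned `processed_image_ds.batch(BATCH_SIZE)` line has no effect on training.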

train_dataset = processed_image_ds.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Now, here I am assuming we agree on the fact that processed_image_ds is an instance of tf.data.Dataset (if not, please check out section 2.1). You can easily find all of these methods and their use cases, such as cache(), shuffle(), batch(), etc., in the official TensorFlow documentation here.

Also, I would like to highlight one specific note about the cache() method, taken straight from the TensorFlow docs:

Note: cache will produce exactly the same elements during each iteration through the dataset. If you wish to randomize the iteration order, make sure to call shuffle after calling cache.
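Here is a small sketch (a toy dataset, not the notebook's) illustrating that note: with shuffle called after cache, every pass over the dataset sees the same cached elements, just in a (most likely) different order each time:

```python
import tensorflow as tf

# cache() memoizes the elements; shuffle() reorders them on each iteration
ds = tf.data.Dataset.range(10).cache().shuffle(10)

epoch_1 = [int(x) for x in ds]
epoch_2 = [int(x) for x in ds]

# Same elements in both epochs, order is randomized per iteration
print(sorted(epoch_1))  # [0, 1, 2, ..., 9]
print(sorted(epoch_2))  # [0, 1, 2, ..., 9]
```

If the order were reversed (shuffle before cache), the cached dataset would replay one fixed shuffled order on every epoch, which is usually not what you want for training.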

I guess that should suffice for the first code cell. Let’s move on to the create_mask function now. Let me take you on a short journey through some pieces of code to help you understand it.

for image, mask in processed_image_ds.take(1):
    sample_image, sample_mask = image, mask
    print(mask.shape)
display([sample_image, sample_mask])

Please find this piece of code in your notebook. Here, you will find that mask.shape = (96, 128, 1). We have written a display function to show this mask, and it is displayed nicely, so we want the mask predicted by the U-Net model to have the same shape. Now, let’s go to the code cell containing unet.summary() and check the output shape of the last layer, i.e., conv2d_65: you will find it is (None, 96, 128, 23). I am assuming that you are aware that None here is a symbolic representation of the batch size (in general, None is used whenever we don’t know the size of a particular dimension in advance, and the batch size is usually one of those dimensions). If not, please check out the description of the shape argument of tf.keras.Input(), which you can find here.
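A quick way to see that None batch dimension for yourself (a standalone snippet, using the same input size as the notebook):

```python
import tensorflow as tf

# The batch dimension is left unspecified (None) until real data is fed in;
# shape only describes a single example, here a 96x128 RGB image.
inp = tf.keras.Input(shape=(96, 128, 3))
print(inp.shape)  # (None, 96, 128, 3)
```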

I have mentioned all this just to make one small point: when we use this model to predict the mask for a single image, the predicted mask will have the shape (1, 96, 128, 23), which is exactly what happens in section 4.3 in the show_predictions function. In other words, pred_mask has a shape of (1, 96, 128, 23). And now, you can simply run the following example code

import tensorflow as tf
import numpy as np

def create_mask(pred_mask):
    # Collapse the class dimension: index of the highest score per pixel
    pred_mask = tf.argmax(pred_mask, axis=-1)
    # Restore a trailing channel dimension so the shape matches the true mask
    pred_mask = pred_mask[..., tf.newaxis]
    # Drop the batch dimension, keeping the first (and only) mask
    return pred_mask[0]

pred_mask = np.random.randn(1, 96, 128, 23)
print(pred_mask.shape)  # (1, 96, 128, 23)

new_mask = create_mask(pred_mask)
print(new_mask.shape)  # (96, 128, 1)

to see how create_mask transforms an input of shape (1, 96, 128, 23) into an output of shape (96, 128, 1). However, this function does more than just alter the shape of the input. I am assuming you are familiar with the fact that U-Net predicts a class for each pixel, so in this example the model is classifying every pixel into one of 23 classes. To find out which class a pixel belongs to, we use tf.argmax(pred_mask, axis=-1), i.e., taking the index of the maximum score along the last dimension (the one representing the 23 classes), which gives an output of shape (1, 96, 128). Then, to make sure it matches the shape of the true mask, we add a new trailing dimension using tf.newaxis, i.e., pred_mask = pred_mask[..., tf.newaxis], giving an output of shape (1, 96, 128, 1). Finally, pred_mask[0] simply selects the first mask from the batch, giving an output of shape (96, 128, 1), and voila, we are done!
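The per-pixel argmax step can also be seen on a toy example small enough to check by hand (hypothetical numbers, 1 image of 1x2 pixels with 3 classes instead of 23):

```python
import tensorflow as tf

# Toy "predicted mask": batch of 1, 1 row, 2 pixels, 3 class scores per pixel
logits = tf.constant([[[[0.1, 0.7, 0.2],       # pixel (0, 0): class 1 wins
                        [0.9, 0.05, 0.05]]]])  # pixel (0, 1): class 0 wins

mask = tf.argmax(logits, axis=-1)  # shape (1, 1, 2): class index per pixel
mask = mask[..., tf.newaxis]       # shape (1, 1, 2, 1): trailing channel dim
print(mask[0].numpy().squeeze(-1))  # [[1 0]]
```

Each pixel ends up labelled with the index of its highest-scoring class, which is exactly what the real create_mask does for every one of the 96x128 pixels and 23 classes.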

I hope this helps.

Cheers,
Elemento


Dear Elemento,

Thank you very much for the careful and detailed explanation. Much appreciated.

It is clear now!

It is amazing the time you and other mentors are willing to put in to help others.

:+1:

Hey @AntonioMaher,
I am glad I could help. Happy learning :nerd_face:

Cheers,
Elemento