Hey @AntonioMaher,
Let’s try to understand the code piece-by-piece.
I think this part should be pretty obvious. Just defining some variables for training purposes.
processed_image_ds.batch(BATCH_SIZE)
I don’t think the above line of code has any use, since the next line performs the same operation (and actually keeps the result). I guess it was included only so that learners can find this method more easily and know that it can be used on its own as well. Other than that, no use case for this line comes to my mind.
train_dataset = processed_image_ds.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
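As a side note, here is a minimal sketch (using a hypothetical toy dataset, not the notebook’s data) of why that stray line is a no-op: tf.data transformations like batch() return a new dataset rather than modifying the original in place, so a call whose result isn’t assigned is simply lost.

```python
import tensorflow as tf

# Hypothetical toy dataset standing in for processed_image_ds.
ds = tf.data.Dataset.range(8)

# batch() does NOT modify ds in place; it returns a NEW dataset.
ds.batch(4)                    # result discarded -- ds is unchanged
print(next(iter(ds)).numpy())  # still yields single elements: 0

batched = ds.batch(4)          # the returned dataset must be kept
print(next(iter(batched)).numpy())  # now yields a batch: [0 1 2 3]
```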
Now, here I am assuming we agree on the fact that processed_image_ds is an instance of tf.data.Dataset. If not, please check out Section 2.1. You can easily find all these methods and their use-cases, such as cache(), shuffle(), batch(), etc., in the official TensorFlow documentation here.
Also, I would like to highlight one specific note on the cache() method of this class, straight from the TensorFlow docs:

Note: cache will produce exactly the same elements during each iteration through the dataset. If you wish to randomize the iteration order, make sure to call shuffle after calling cache.
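To make the pipeline concrete, here is a minimal sketch with fake (image, mask) pairs standing in for processed_image_ds (the array sizes mirror the notebook’s shapes, but the data itself is made up):

```python
import tensorflow as tf

# Toy stand-in for processed_image_ds: 6 fake (image, mask) pairs.
images = tf.random.uniform((6, 96, 128, 3))
masks = tf.random.uniform((6, 96, 128, 1), maxval=23, dtype=tf.int32)
ds = tf.data.Dataset.from_tensor_slices((images, masks))

# Same pipeline as the notebook: cache, then shuffle, then batch.
train = ds.cache().shuffle(buffer_size=6).batch(2)

for img_batch, mask_batch in train.take(1):
    print(img_batch.shape)   # (2, 96, 128, 3)
    print(mask_batch.shape)  # (2, 96, 128, 1)
```

Note how the batch size becomes the leading dimension of every element the dataset yields, which ties in with the None dimension discussed below.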
I guess that should suffice for the first code cell. Let’s move on to the create_mask function now. Let me take you on a journey among some code pieces to help you understand this.
for image, mask in processed_image_ds.take(1):
    sample_image, sample_mask = image, mask
    print(mask.shape)
display([sample_image, sample_mask])
Please find this piece of code in your notebook. Here, you will find that mask.shape = (96, 128, 1). Now, we have written a display function to display this mask, and it is being displayed nicely. So, we want the mask predicted by the U-Net model to have the same shape. Now, let’s go to the code cell containing unet.summary(), and check out the output shape of the last layer, i.e., conv2d_65: you will find it is (None, 96, 128, 23). I am assuming that you are well aware of the fact that None here is a symbolic representation of the batch size (in general, None is used whenever we don’t know the size of a particular dimension in advance, and the batch size is usually one of those dimensions). If not, please check out the description of the shape argument of tf.keras.Input(), which you can find out here.
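You can see this None batch dimension for yourself in a two-line sketch:

```python
import tensorflow as tf

# shape= deliberately excludes the batch dimension; Keras prepends
# it as None, meaning "any batch size".
inp = tf.keras.Input(shape=(96, 128, 3))
print(inp.shape)  # (None, 96, 128, 3)
```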
I have mentioned all this just to make one small point: when we use this model to predict the mask for a single image, the predicted mask will have the shape (1, 96, 128, 23), which is exactly what is being done in Section 4.3 in the show_predictions function. In other words, pred_mask has a shape of (1, 96, 128, 23). And now, you can simply run the following example code
. And now, you can simply run the following example code
import tensorflow as tf
import numpy as np

def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0]

pred_mask = np.random.randn(1, 96, 128, 23)
print(pred_mask.shape)
new_mask = create_mask(pred_mask)
print(new_mask.shape)
to see how create_mask transforms an input of shape (1, 96, 128, 23) into an output of shape (96, 128, 1). However, you will find that the importance of this function goes beyond just altering the shape of the input. I am assuming you are familiar with the fact that U-Net predicts a class for each pixel; so in this example, the model is classifying each pixel into one of the 23 classes. In order to know which class a pixel belongs to, we use tf.argmax(pred_mask, axis=-1), i.e., taking the index of the maximum along the last dimension (the one representing the 23 classes), giving an output of shape (1, 96, 128). Then, to make sure it matches the shape of the mask, we add a new dimension at the end using tf.newaxis, i.e., pred_mask = pred_mask[..., tf.newaxis], giving an output of shape (1, 96, 128, 1). Finally, it simply selects the first mask from the batch, i.e., pred_mask[0], giving an output of shape (96, 128, 1), and voila, we are done!
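If the argmax step still feels abstract, here is a tiny worked sketch using NumPy (with a made-up 4-class score vector for a single pixel, then the full-scale shapes from the notebook):

```python
import numpy as np

# One "pixel" with scores over 4 hypothetical classes: index 2 has
# the largest score, so argmax picks class 2.
scores = np.array([[0.1, 0.3, 2.0, -1.0]])
print(np.argmax(scores, axis=-1))  # [2]

# Same idea at full scale, mirroring create_mask step by step:
pred = np.random.randn(1, 96, 128, 23)
classes = np.argmax(pred, axis=-1)   # (1, 96, 128)
print(classes.shape)
cls = classes[..., np.newaxis]       # (1, 96, 128, 1)
print(cls.shape)
print(cls[0].shape)                  # (96, 128, 1)
```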
I hope this helps.
Cheers,
Elemento