Week 3, programming assignment 2: how does this code snippet work?


I have trouble understanding how this code snippet from the image segmentation with U-Net programming exercise (2.2 - Preprocess Your Data) works. I understand what it does, but I don’t understand the code. (My problem is basically a Python issue, sorry if that is not appropriate here - but Googling didn’t help.)

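```python
import tensorflow as tf


def process_path(image_path, mask_path):
    # Read and decode the RGB image, then scale pixel values to floats in [0, 1].
    img = tf.io.read_file(image_path)
    img = tf.image.decode_png(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    # Read the mask and collapse its 3 channels into 1 by taking the per-pixel max.
    mask = tf.io.read_file(mask_path)
    mask = tf.image.decode_png(mask, channels=3)
    mask = tf.math.reduce_max(mask, axis=-1, keepdims=True)

    return img, mask


def preprocess(image, mask):
    # Resize both to 96x128; 'nearest' avoids blending neighboring mask labels.
    input_image = tf.image.resize(image, (96, 128), method='nearest')
    input_mask = tf.image.resize(mask, (96, 128), method='nearest')

    return input_image, input_mask


# `dataset` is created earlier in the notebook; each element is an
# (image_path, mask_path) pair of strings.
image_ds = dataset.map(process_path)
processed_image_ds = image_ds.map(preprocess)
```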

It’s the last two lines. Here, map() applies the functions process_path() and preprocess() to each element of dataset and image_ds, respectively, but I don’t understand why those functions are passed without their positional arguments here. My guesses:

  • image_path and mask_path have been defined earlier in the notebook, so that’s why process_path() knows these global variables, and they don’t have to be provided using map().
  • image and mask, however, have not been defined. preprocess() knows what they are because, with map(), preprocess() is applied to each item of image_ds (which consists of an image and its mask), so it assumes that image is the first element of this item, and mask the second - and these variables also do not have to be provided using map().

Is that about right?

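Yes, your second point is how it works. The first one isn’t: inside process_path(), image_path and mask_path are its own parameters, so any global variables with the same names are shadowed and ignored, and map() fills the parameters in from each element of dataset.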
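To make the mechanism concrete, here is a small pure-Python model of what tf.data.Dataset.map() does with each element. It is a sketch, not TensorFlow’s implementation: the helper `model_map` and the placeholder bodies of `process_path()` and `preprocess()` (which only tag strings instead of decoding and resizing images) are hypothetical. The key line is `func(*element)`: a tuple element is unpacked into positional arguments, which is why the functions are handed to map() without arguments.

```python
from typing import Any, Callable, Iterable, Iterator


def model_map(func: Callable[..., Any], elements: Iterable[Any]) -> Iterator[Any]:
    """Hypothetical stand-in for tf.data.Dataset.map().

    Tuple elements are unpacked into positional arguments; any other
    element is passed as a single argument. Like tf.data, it is lazy.
    """
    for element in elements:
        if isinstance(element, tuple):
            yield func(*element)
        else:
            yield func(element)


# Placeholder versions of the notebook's functions (hypothetical bodies):
# they only wrap strings so we can see which argument ended up where.
def process_path(image_path, mask_path):
    return f"decoded({image_path})", f"decoded({mask_path})"


def preprocess(image, mask):
    return f"resized({image})", f"resized({mask})"


# Like `dataset`: each element is an (image_path, mask_path) pair of strings.
dataset = [("img_0.png", "mask_0.png"), ("img_1.png", "mask_1.png")]

# Same shape as the notebook: pass the functions themselves, not calls.
image_ds = list(model_map(process_path, dataset))
processed_image_ds = list(model_map(preprocess, image_ds))

print(processed_image_ds[0])
# ('resized(decoded(img_0.png))', 'resized(decoded(mask_0.png))')
```

Because each element of image_ds is again an (image, mask) pair, the second map() unpacks it into preprocess(image, mask) the same way; the parameter names only have to be consistent inside each function.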

The object-oriented programming is getting a little intense here. You have to look at the definition of the map() method of the tf.data.Dataset class. Here’s the documentation. It takes a function reference as its argument, and here’s what it says about that function’s signature:

The input signature of map_func is determined by the structure of each element in this dataset.

The documentation gives some examples in that section to show what this means, but understanding them depends on knowing how lambda functions work in Python. The upshot is what you said in your second point: if each element of the dataset is a tuple, map() unpacks it and passes its components to the function as separate positional arguments, in order.
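Passing process_path by name is equivalent to passing a lambda that forwards both arguments, which is the form the documentation’s examples use. The sketch below checks that in plain Python; `apply_to_element` is a hypothetical helper that calls a function the way Dataset.map() does for a tuple element, and the body of `process_path` is a placeholder.

```python
def apply_to_element(func, element):
    """Hypothetical helper: call func the way Dataset.map() does for a tuple element."""
    return func(*element)


def process_path(image_path, mask_path):
    # Placeholder body (hypothetical): just join the two paths.
    return f"{image_path}+{mask_path}"


element = ("img_0.png", "mask_0.png")  # one (image_path, mask_path) pair

# Like dataset.map(process_path): hand over the function object itself.
via_reference = apply_to_element(process_path, element)

# Like dataset.map(lambda image_path, mask_path: process_path(image_path, mask_path)):
# the lambda receives the same two arguments and forwards them.
via_lambda = apply_to_element(
    lambda image_path, mask_path: process_path(image_path, mask_path), element
)

print(via_reference)  # img_0.png+mask_0.png
print(via_reference == via_lambda)  # True

# Writing process_path() with parentheses would call it immediately,
# with no arguments, so it fails before any map() could use it.
try:
    process_path()
except TypeError as err:
    print("TypeError:", err)
```

The lambda form is only needed when the element structure does not match the function’s parameters, for example to drop or reorder components.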

Thanks, Paul! It makes sense now. I now also see that it works the same way for dataset, whose items consist of string pairs (an image path and a mask path).

Also, I was actually looking at the wrong map() function: Python’s built-in one, not the one for tf.data.Dataset. That caused some unnecessary confusion.
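For comparison, Python’s built-in map() passes each element to the function as a single argument, so mapping a two-parameter function over a list of pairs fails. The standard-library itertools.starmap() unpacks each tuple instead, which is closer to what tf.data.Dataset.map() does with tuple elements. The function `pair_paths` is a hypothetical placeholder for process_path().

```python
from itertools import starmap


def pair_paths(image_path, mask_path):
    # Hypothetical placeholder for process_path(): just combine the names.
    return f"{image_path} -> {mask_path}"


pairs = [("img_0.png", "mask_0.png"), ("img_1.png", "mask_1.png")]

# Built-in map(): each tuple arrives as ONE argument, so the call is
# pair_paths(("img_0.png", "mask_0.png")) and mask_path is missing.
# map() is lazy, so the error appears when list() consumes it.
builtin_failed = False
try:
    list(map(pair_paths, pairs))
except TypeError as err:
    builtin_failed = True
    print("built-in map failed:", err)

# starmap(): each tuple is unpacked, i.e. pair_paths("img_0.png", "mask_0.png").
unpacked = list(starmap(pair_paths, pairs))
print(unpacked)  # ['img_0.png -> mask_0.png', 'img_1.png -> mask_1.png']

# Built-in map() can also take several iterables and pull one item from each,
# which is another way to supply two positional arguments.
zipped = list(map(pair_paths, ["img_0.png", "img_1.png"], ["mask_0.png", "mask_1.png"]))
print(zipped == unpacked)  # True
```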