Data augmentation for semantic segmentation: a Keras bug

While trying to add data augmentation to the Semantic Segmentation assignment (Course 4, Week 3), I ran into a major issue.
The tf.keras data-augmentation layer RandomFlip has a bug in TF 2.13.0: successive applications with the same seed produce different flips, so an image and its mask do not get flipped the same way.
The issue is discussed on Stack Overflow. There I read that 2.12.1 does not have this problem, and that the problem in 2.13.0 is being worked on.
I could not install 2.12 over my existing installation; pip gave me a strange incompatibility message. I was able to install the “nightly” TF build with pip instead, but the bug was still there.
The solution I found was to use the tf.image data augmentation functions instead (code below). For semantic segmentation, where the same transform must be applied to the image and the mask, one must use the “stateless” versions of the flip functions (e.g., tf.image.stateless_random_flip_left_right), which take an explicit seed and always give the same result for the same seed.
The code below does this in the “preprocess” function, which is then mapped over the image dataset.
Note the modest tricks I used to change the seed for each element while keeping it the same for the image and its mask.
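Why the seed has to be drawn with a TF op rather than Python's random module: a minimal sketch, assuming TF 2.x in eager mode and a toy dataset of five integers standing in for the (image, mask) pairs. Dataset.map traces the mapped function once into a graph, so Python code inside it (random.randint, updates to a global seed) runs at trace time only and its result is frozen as a constant. A TF random op such as tf.random.uniform is part of the graph and runs again for every element. The names python_seed and graph_seed are hypothetical, for illustration only.

```python
import random

import tensorflow as tf

# Hypothetical toy dataset: five integers standing in for (image, mask) pairs.
ds = tf.data.Dataset.range(5)

def python_seed(_):
    # Hypothetical helper: random.randint runs once, while Dataset.map traces
    # this function, and the value is baked into the graph as a constant.
    return tf.constant(random.randint(0, 10000), dtype=tf.int64)

def graph_seed(_):
    # Hypothetical helper: tf.random.uniform is a graph op, so it is
    # evaluated again for each element.
    return tf.random.uniform([], maxval=2**31, dtype=tf.int64)

python_seeds = [int(s) for s in ds.map(python_seed)]
graph_seeds = [int(s) for s in ds.map(graph_seed)]
print(python_seeds)  # the same value five times
print(graph_seeds)   # a fresh draw for each element
```

In other words, a seed computed in plain Python inside the mapped function gives every image in the dataset the same flips.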
Note also that “seed” must be a tensor of shape [2], i.e. two integers (this applies only to the “stateless” versions).
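A minimal sketch of the seed handling, assuming TF 2.x in eager mode and a hypothetical 4x4 toy image with an integer mask that copies it pixel for pixel. The seed is a shape-[2] int32 tensor, and tf.random.experimental.stateless_split derives two seeds from it, one for each flip. As far as I can tell from the TF source, both stateless flips draw a single uniform number from their seed in the same way, so giving them the same seed would make them flip together. Because the image and the mask receive the same seeds, the mask stays aligned with the image.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: a 4x4 single-channel image and an integer mask
# whose pixels equal the image pixels, so alignment is easy to check.
image = tf.reshape(tf.range(16, dtype=tf.float32), (4, 4, 1))
mask = tf.cast(image, tf.int32)

# Stateless ops take a seed of shape [2] (int32 or int64), not a Python int.
seed = tf.constant([1, 2], dtype=tf.int32)

# Split it into two seeds, one per flip, so the two decisions are independent.
lr_seed, ud_seed = tf.unstack(tf.random.experimental.stateless_split(seed, num=2))

def augment(x):
    # Hypothetical helper: apply both stateless flips with fixed seeds.
    x = tf.image.stateless_random_flip_left_right(x, lr_seed)
    x = tf.image.stateless_random_flip_up_down(x, ud_seed)
    return x

aug_image = augment(image)
aug_mask = augment(mask)

# Same seeds, same flips: the mask is still aligned with the image.
print(np.array_equal(tf.cast(aug_image, tf.int32).numpy(), aug_mask.numpy()))  # True
```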
Comments and suggestions on better ways to do this will be greatly appreciated.

import tensorflow as tf

def process_path(image_path, mask_path):
    img = tf.io.read_file(image_path)
    img = tf.image.decode_png(img, channels=1)
    img = tf.image.convert_image_dtype(img, tf.float32)

    mask = tf.io.read_file(mask_path)
    mask = tf.image.decode_png(mask, channels=0, dtype=tf.dtypes.uint16)
    return img, mask

def preprocess(image, mask):
    # Draw fresh seeds for every element with a TF op. Python-side randomness
    # (random.randint, updating a global seed) runs only once, when
    # dataset.map traces this function, so every element would get the same flips.
    # Row 0 seeds the left/right flip, row 1 the up/down flip; reusing one seed
    # for both would tie the two flip decisions together.
    seeds = tf.random.uniform([2, 2], maxval=2**31 - 1, dtype=tf.int32)
    input_image = tf.image.resize(image, (512, 512), method='nearest')
    input_mask = tf.image.resize(mask, (512, 512), method='nearest')
    # The same seed for image and mask gives both the same flip.
    input_image = tf.image.stateless_random_flip_left_right(input_image, seeds[0])
    input_mask = tf.image.stateless_random_flip_left_right(input_mask, seeds[0])
    input_image = tf.image.stateless_random_flip_up_down(input_image, seeds[1])
    input_mask = tf.image.stateless_random_flip_up_down(input_mask, seeds[1])
    return input_image, input_mask

# dataset: (image_path, mask_path) pairs, built earlier as in the assignment
image_ds = dataset.map(process_path)
processed_image_ds = image_ds.map(preprocess)
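One alternative that avoids seed bookkeeping altogether (not what the code above does, just a commonly used pattern): concatenate the image and mask along the channel axis, flip the stacked tensor with the stateful tf.image.random_flip_left_right and tf.image.random_flip_up_down, and split it again. A single tensor gets a single flip decision, so the image and mask cannot drift apart. A minimal sketch, assuming both have the same height and width and that the image's channel count is known statically (true after decode_png with channels=1); flip_pair is a hypothetical name.

```python
import tensorflow as tf

def flip_pair(image, mask):
    """Hypothetical helper: flip an (image, mask) pair with one shared decision."""
    # Static channel count of the image (assumed known, e.g. 1 after decode_png).
    n_img = image.shape[-1]
    # Cast the mask to the image dtype so the two can be concatenated.
    stacked = tf.concat([image, tf.cast(mask, image.dtype)], axis=-1)
    # Each op draws one random decision and applies it to all channels.
    stacked = tf.image.random_flip_left_right(stacked)
    stacked = tf.image.random_flip_up_down(stacked)
    # Split back into image and mask, restoring the mask dtype.
    out_image = stacked[..., :n_img]
    out_mask = tf.cast(stacked[..., n_img:], mask.dtype)
    return out_image, out_mask
```

It could be called inside preprocess, after the resize, in place of the four stateless flip calls. Casting a uint16 mask to float32 and back is exact, since float32 represents every integer up to 2**24.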

Thanks for posting your information.