I am trying, and failing, to fully understand how the following code works.
def format_image(image, label):
This simple function is defined as taking two arguments, and running it shows that it receives tensorflow.python.framework.ops.Tensor objects.
But the following line seems to pass train_examples, a tf.python.data.ops.dataset_ops.PrefetchDataset, to format_image() (via the map() function), like so:
train_batches = train_examples.shuffle(num_examples // 4).map(format_image).batch(BATCH_SIZE).prefetch(1)
When I do:
for element in train_examples.as_numpy_iterator():
    print(type(element))
    break
I get <class 'tuple'>
I (very) tentatively concluded that the map() line passes a tuple of tensors to the function (somehow wrapped inside a tf.data.Dataset?), rather than two separate tensors.
But if that is the case, why does Python allow it? The function expects two arguments, not a single tuple.
So, if I do this (with the tf.image.resize() line commented out):
format_image('confused', 1)
I get (as expected) no errors.
But if I do:
my_tuple = ('confused', 1)
format_image(my_tuple)
I get (as expected):
TypeError: format_image() missing 1 required positional argument: 'label'
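For comparison, the same behavior can be reproduced in plain Python. The format_image below is a hypothetical stub (the tf.image.resize() call is removed, as in the experiment above); the sketch only shows that a tuple has to be unpacked with * before it matches a two-parameter signature.

```python
def format_image(image, label):
    # Hypothetical stub: the real function also resizes the image with tf.image.resize().
    return image, label


my_tuple = ('confused', 1)

# Passing the tuple as ONE argument leaves 'label' unfilled, so Python raises TypeError.
try:
    format_image(my_tuple)
except TypeError as err:
    print(err)  # format_image() missing 1 required positional argument: 'label'

# Star-unpacking spreads the tuple's two items over the two parameters.
print(format_image(*my_tuple))  # ('confused', 1)
```

So the question is really whether something between the dataset and format_image() performs that unpacking step.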
So why doesn't train_examples.map(format_image) also throw this TypeError?
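One way to picture what map() appears to do: the examples in the tf.data.Dataset.map documentation show that when each element is a tuple, the mapped function takes the tuple's components as separate arguments, while a dict element is passed as a single argument. The toy_map below is a hypothetical, eager, pure-Python model of that rule only (no tensors, no graph tracing, no shuffling or batching); it is not TensorFlow's implementation.

```python
from typing import Any, Callable, Iterable, List


def toy_map(fn: Callable[..., Any], elements: Iterable[Any]) -> List[Any]:
    """Hypothetical model of how Dataset.map calls fn on each element.

    Tuple elements are unpacked into positional arguments; anything else
    (e.g. a dict or a single value) is passed as one argument.
    """
    results = []
    for element in elements:
        if isinstance(element, tuple):
            results.append(fn(*element))  # (image, label) -> fn(image, label)
        else:
            results.append(fn(element))   # e.g. a dict -> fn(the_dict)
    return results


def format_image(image, label):
    # Hypothetical stub standing in for the real resizing function.
    return image.upper(), label


examples = [('cat', 0), ('dog', 1)]  # made-up (image, label) pairs
print(toy_map(format_image, examples))  # [('CAT', 0), ('DOG', 1)]
```

Under this model, format_image(my_tuple) fails because nothing unpacks the tuple, whereas map() unpacks each (image, label) element before calling the function.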
(Since I can't get this information from the documentation or even from print statements, I think these powerful TensorFlow libraries are less accessible to newcomers. I mean, copy-pasting this stuff and seeing that it just works, without really knowing how, will probably cause problems somewhere down the line, imo.)