An important miss from instructors

I was doing the graded assignment for week #1 (Programming Assignment: Deep N-grams) when I stumbled upon a rather glaring miss from the instructors on the topic of how datasets handle batching.

Concretely, while implementing the create_batch_dataset function, it was very hard for me to unearth an important detail about batching. In this function we are asked to batch our input twice: 1.) to make fixed-length sequences of tokens and 2.) to make batches of those sequences. In between these two operations, we also convert each sequence (after the first batching) into a tuple (sequence[:-1], sequence[1:]). Doing this changes the shape of each entry of the dataset from [ (17,), (17,), … ] to [ ((16,), (16,)), ((16,), (16,)), … ].
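Here’s a minimal sketch of the pipeline described above, useful for checking the shape at each stage via element_spec. It assumes a sequence length of 16 (so windows of 17 tokens), a toy corpus of integer token ids, a batch size of 3, and a helper named split_input_target that is hypothetical here; the real create_batch_dataset in the assignment may differ (for example, it may also shuffle).

```python
import tensorflow as tf

SEQ_LENGTH = 16  # assumed: windows of SEQ_LENGTH + 1 = 17 tokens, as in the post
BATCH_SIZE = 3   # assumed small batch size for illustration

def split_input_target(sequence):
    # Hypothetical helper: input is every token but the last,
    # target is the same window shifted left by one.
    return sequence[:-1], sequence[1:]

# Toy stand-in for the encoded corpus: token ids 0..99 (assumption).
data = tf.range(100, dtype=tf.int64)

ds = tf.data.Dataset.from_tensor_slices(data)              # element: scalar, shape ()
windows = ds.batch(SEQ_LENGTH + 1, drop_remainder=True)    # element: shape (17,)
pairs = windows.map(split_input_target)                    # element: ((16,), (16,))
batched = pairs.batch(BATCH_SIZE, drop_remainder=True)     # second batching

print(windows.element_spec)
print(pairs.element_spec)
print(batched.element_spec)
```

Note that the last print already reveals the surprise: the batched element is a tuple of two specs with shape (3, 16), not a group of three (input, target) tuples.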

Now, intuitively anyone would think that when the next batching is done, it’ll just collect batch_size of these tuples together, and that’s how the dataset generator pattern would work.

But that isn’t so. During the second batching operation, TensorFlow unpacks these tuples and collects all the first elements together and all the second elements together. This means the shape isn’t what you would expect.

To take an example, let’s say our second batch size is 3. Then one would expect running .take(1) (get 1 batch) to yield something like a sequence of 3 tuples: [ ((16,), (16,)), ((16,), (16,)), ((16,), (16,)) ]. But that is not the case. Confusingly enough, the element returned by .take(1) is a single tuple of 2 tensors, where the first tensor stacks all 3 first entries and the second tensor stacks all 3 second entries, so each has shape (3, 16).
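A small sketch of exactly this example, assuming toy data of 6 windows of 17 consecutive integers and a batch size of 3. It shows that one batch is a single (inputs, targets) tuple of (3, 16) tensors, and one way to get back the "list of 3 pairs" view by zipping the rows of the two tensors.

```python
import tensorflow as tf

# Toy data (assumption): 6 windows of 17 consecutive token ids.
windows = tf.reshape(tf.range(6 * 17, dtype=tf.int64), (6, 17))

pairs = tf.data.Dataset.from_tensor_slices(windows).map(
    lambda seq: (seq[:-1], seq[1:])  # (input, target), each of shape (16,)
)
batched = pairs.batch(3, drop_remainder=True)

for element in batched.take(1):   # take(1) is a dataset holding one batch
    inputs, targets = element     # ONE tuple of two tensors, not 3 tuples
    print(type(element), len(element))
    print(inputs.shape, targets.shape)  # both have 3 rows of 16 tokens

# To recover the "3 (input, target) pairs" view, pair the rows back up.
inputs, targets = next(iter(batched))
examples = list(zip(inputs, targets))
print(len(examples), examples[0][0].shape)
```

Row i of inputs and row i of targets still belong to the same original window, so no information is lost; it is only arranged column-wise instead of pair-wise.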

It took me several hours to figure this out, and I think the instructors should have explained this clearly to avoid wasted effort here. Here’s a ChatGPT explanation for this:

====

You’re absolutely correct in your observation that tf.data.Dataset.batch() does more than simply group elements together. When a dataset contains tuples, batch() treats each component of the tuple separately and batches them independently. This behavior is deliberate and built into TensorFlow’s batch() method to support common machine learning use cases like batching input-target pairs.
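To see that this component-wise stacking is general, here’s a sketch (toy data assumed, as before) comparing three element structures: a tuple, a dict, and a single stacked tensor. Tuples and dicts are batched per component; only if the pair is first combined into one tensor with tf.stack does each batch entry keep its input and target together. unbatch() reverses the batching and restores per-example tuples.

```python
import tensorflow as tf

# Toy data (assumption): 6 windows of 17 consecutive token ids.
windows = tf.reshape(tf.range(6 * 17, dtype=tf.int64), (6, 17))
base = tf.data.Dataset.from_tensor_slices(windows)

# 1) Tuple elements: batch() stacks each component separately.
as_tuple = base.map(lambda s: (s[:-1], s[1:])).batch(3)
# 2) Dict elements: same thing, one stacked tensor per key.
as_dict = base.map(lambda s: {"inputs": s[:-1], "targets": s[1:]}).batch(3)
# 3) Stacking the pair into ONE tensor first keeps each pair together.
as_stacked = base.map(lambda s: tf.stack([s[:-1], s[1:]])).batch(3)

print(as_tuple.element_spec)    # tuple of two TensorSpecs, each shape (None, 16)
print(as_dict.element_spec)     # dict of two TensorSpecs, each shape (None, 16)
print(as_stacked.element_spec)  # one TensorSpec, shape (None, 2, 16)

# unbatch() undoes the batching and yields per-example (input, target) tuples.
restored = as_tuple.unbatch()
print(restored.element_spec)    # tuple of two TensorSpecs, each shape (16,)
```

The batch dimension shows as None here because drop_remainder is left at its default (False). The tuple form is what Keras model.fit expects from a dataset, (inputs, targets) with a leading batch dimension on each, which is presumably why batch() behaves this way.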