Hi everyone,
I would like to report a subtle bug in the notebook C1_M3_Lab_data_management.ipynb:
Description
In this notebook, there is a logic conflict between the base FlowerDataset and the SubsetWithTransform wrapper. The base dataset is initialized with a transform pipeline that includes transforms.ToTensor(). When SubsetWithTransform is later applied to the split subsets, it attempts to apply a second transformation pipeline (the augmentation) to an object that has already been converted into a Tensor.
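This results in a TypeError. The error is raised by ToTensor(), which the augmentation pipeline also contains: it expects a PIL Image or ndarray but receives a torch.Tensor. Some augmentations, such as RandomHorizontalFlip, do accept tensors in recent torchvision versions, so the crash is caused by the second ToTensor() conversion rather than by the augmentation itself.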
Steps to Reproduce
- Initialize FlowerDataset with a base transform that includes ToTensor().
- Split the dataset using random_split.
- Wrap the training subset in SubsetWithTransform using an augmentation pipeline (e.g., RandomHorizontalFlip).
- Access an element: train_dataset[0].

Note that this also applies to the validation and test datasets.
Technical Analysis
The issue lies in the nested call stack of the __getitem__ methods:
- SubsetWithTransform.__getitem__ calls self.subset[idx].
- This triggers the base FlowerDataset.__getitem__, which applies its internal self.transform.
- If the base transform includes ToTensor(), the image is returned as a Tensor.
- SubsetWithTransform then attempts to apply augmentation_transform to this Tensor, causing the crash.
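To make the order of the nested calls concrete without depending on torch, the following plain-Python model records each step as it happens. All classes here (FakePILImage, FakeTensor, BaseDataset, Subset, SubsetWithTransform) and the to_tensor function are hypothetical stand-ins that only mimic the relevant behavior of the real torchvision and torch.utils.data objects.

```python
# Hypothetical stand-ins so the call order can be traced without torch.
class FakePILImage:
    pass


class FakeTensor:
    pass


trace = []


def to_tensor(pic):
    """Mimics ToTensor(): accepts only image-like input and rejects tensors."""
    trace.append(f"to_tensor({type(pic).__name__})")
    if not isinstance(pic, FakePILImage):
        raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")
    return FakeTensor()


class BaseDataset:
    """Plays the role of FlowerDataset with an internal transform."""

    def __init__(self, transform=None):
        self.transform = transform

    def __getitem__(self, idx):
        trace.append("BaseDataset.__getitem__")
        image = FakePILImage()
        if self.transform is not None:
            image = self.transform(image)
        return image, 0


class Subset:
    """Plays the role of the Subset objects returned by random_split."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        trace.append("Subset.__getitem__")
        return self.dataset[self.indices[idx]]


class SubsetWithTransform:
    """Plays the role of the notebook's wrapper."""

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, idx):
        trace.append("SubsetWithTransform.__getitem__")
        image, label = self.subset[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label


wrapped = SubsetWithTransform(Subset(BaseDataset(transform=to_tensor), [0]), transform=to_tensor)
try:
    wrapped[0]
except TypeError as err:
    trace.append(f"TypeError: {err}")

for step in trace:
    print(step)
```

The recorded trace shows the base dataset's to_tensor running first on the image, and the wrapper's to_tensor then receiving the already converted FakeTensor and raising.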
Error Traceback
TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>
Suggested Fix
The base FlowerDataset should be re-initialized with transform=None when it is intended to be split and wrapped by SubsetWithTransform. This ensures that the raw PIL image is passed up the chain, allowing the wrapper to handle all transformations in a single pass.
Recommended Code Change:
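# Initialize the base dataset without transforms so it returns raw PIL images
dataset_raw = FlowerDataset(path_dataset, transform=None)
# ... perform split with random_split(dataset_raw, ...), yielding train_subset, val_subset, test_subset ...
# Apply transforms only at the subset level, so each sample is transformed exactly once
train_dataset = SubsetWithTransform(train_subset, transform=augmentation_transform)
val_dataset = SubsetWithTransform(val_subset, transform=base_transform)
test_dataset = SubsetWithTransform(test_subset, transform=base_transform)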
Note that the bug does not surface during normal execution of the notebook, since the subsets are not accessed after being defined. Nevertheless, I would encourage explicitly re-initializing FlowerDataset with transform=None, as this makes the double-transformation issue visible and helps students build a clearer mental model of the data pipeline.
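Because the problem only appears once a sample is actually fetched, a quick smoke test that touches the first element of every wrapped dataset right after it is built would surface it immediately. The smoke_test helper and the BrokenDataset class below are hypothetical illustrations; in the notebook one would pass train_dataset, val_dataset and test_dataset instead of the toy objects.

```python
def smoke_test(datasets):
    """Fetch the first sample of every dataset; return {name: error message} for failures."""
    failures = {}
    for name, dataset in datasets.items():
        try:
            dataset[0]  # forces the full __getitem__ chain, including all transforms
        except Exception as err:  # collect every failure instead of stopping at the first
            failures[name] = f"{type(err).__name__}: {err}"
    return failures


class BrokenDataset:
    """Hypothetical dataset whose __getitem__ fails like the double ToTensor() case."""

    def __getitem__(self, idx):
        raise TypeError("pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>")


# Plain lists support indexing, so they serve as working toy datasets here
report = smoke_test({"train": BrokenDataset(), "val": [(0.0, 1)], "test": [(0.5, 0)]})
print(report)
```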
I hope you find this helpful in further improving the resources of this great course.
Best regards,
Carl