I have updated the notebook. The get_data_loaders_with_validation function now correctly splits the dataset and applies the appropriate transformations to each split.
For reference, here’s the update:
def get_data_loaders_with_validation(batch_size, val_fraction=0.1, seed=None):
    """Create data loaders for training, validation, and testing on CIFAR-10.

    The training set is augmented (random crop + horizontal flip); the
    validation and test sets are not. To achieve this, the CIFAR-10 training
    data is loaded twice with different transforms and the same random index
    split is applied to both copies, so the validation samples are never
    augmented.

    Args:
        batch_size: The number of samples per batch in each data loader.
        val_fraction: The fraction of the training data to use for
            validation. Must satisfy 0 <= val_fraction < 1.
        seed: Optional integer seed for the train/validation split. When
            given, the split is reproducible across runs; when None, the
            global torch RNG is used (the original behavior).

    Returns:
        A tuple (train_loader, val_loader, test_loader).

    Raises:
        ValueError: If val_fraction is outside the range [0, 1).
    """
    # Local import keeps this cell self-contained in the notebook; torch is
    # already a dependency of torchvision, which the surrounding code uses.
    import torch

    if not 0 <= val_fraction < 1:
        raise ValueError(f"val_fraction must be in [0, 1), got {val_fraction}")

    # Training transform: augmentation + normalization.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Validation/test transform: normalization only, no augmentation.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Two views of the same underlying training data, differing only in
    # transform. A separate object is required so validation samples are
    # not augmented.
    full_trainset = datasets.CIFAR10(
        root='./cifar10', train=True, download=True, transform=transform_train)
    full_valset = datasets.CIFAR10(
        root='./cifar10', train=True, download=True, transform=transform_test)

    total_train = len(full_trainset)
    val_size = int(val_fraction * total_train)
    train_size = total_train - val_size

    # Generate the random index split once; a seeded generator makes the
    # split reproducible without touching the global RNG state.
    if seed is None:
        train_part, val_part = random_split(
            full_trainset, [train_size, val_size])
    else:
        split_generator = torch.Generator().manual_seed(seed)
        train_part, val_part = random_split(
            full_trainset, [train_size, val_size], generator=split_generator)

    # Apply the same indices to the appropriate backing dataset so that
    # train_set is augmented and val_set is not.
    train_set = Subset(full_trainset, train_part.indices)
    val_set = Subset(full_valset, val_part.indices)

    test_set = datasets.CIFAR10(
        root='./cifar10', train=False, download=True, transform=transform_test)

    # Only the training loader is shuffled; evaluation order is fixed.
    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(
        val_set, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=2)

    return train_loader, val_loader, test_loader

