Hi, I have added a test set to the U-Net notebook. I was impressed by how good the segmentation was, so I wanted to make sure it was not overfitting.
I don't know how to attach the Jupyter notebook here, sorry, so I have pasted the code below. The training is indeed overfitting (as the plots I added show), and it was nice to see the moment it started to overfit.
I appreciate all the work the professors put into preparing the notebooks and the rest of the course material. They are all great! Thanks a lot!
I hope someone finds the code and the discussion helpful.
from skimage import io
import numpy as np
from torch.utils.data import Subset
# Load the image stack and the label stack, insert a channel axis
# ([:, None, :, :]) and divide by 255 (assumes 8-bit pixel values -- confirm).
# NOTE: the file names must use plain ASCII quotes; the typographic quotes
# produced by copy/paste are a SyntaxError in Python.
volumes = torch.Tensor(io.imread('train-volume.tif'))[:, None, :, :] / 255
labels = torch.Tensor(io.imread('train-labels.tif', plugin="tifffile"))[:, None, :, :] / 255
# Crop the labels to the U-Net output size; `crop` and `target_shape` come
# from earlier cells of the notebook.
labels = crop(labels, torch.Size([len(labels), 1, target_shape, target_shape]))
total_dataset = torch.utils.data.TensorDataset(volumes, labels)

# Random train/test split: shuffle the sample indices once, then give the
# first `train_fraction` of them to training and the rest to testing.
# `.tolist()` hands Subset plain Python ints rather than numpy integers.
indices = np.random.permutation(len(total_dataset)).tolist()
train_fraction = 0.8
cut = int(len(total_dataset) * train_fraction)
dataset = Subset(total_dataset, indices[:cut])
test_dataset = Subset(total_dataset, indices[cut:])
def train():
    """Train a U-Net on `dataset` while tracking the loss on `test_dataset`.

    Relies on names defined in earlier notebook cells: DataLoader, UNet,
    criterion, crop, show_tensor_images, tqdm, plt, and the hyperparameters
    input_dim, label_dim, target_shape, batch_size, lr, n_epochs,
    display_step and device.

    After every optimisation step it records two values:
    - the training-batch loss;
    - the loss on one randomly drawn test batch.

    Every `display_step` steps it prints both losses, shows sample
    inputs/labels/predictions, and plots the two loss curves.
    """
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    unet = UNet(input_dim, label_dim).to(device)
    unet_opt = torch.optim.Adam(unet.parameters(), lr=lr)
    cur_step = 0
    # Losses are stored as plain Python floats (via .item()). Keeping the
    # tensors themselves would hold on to their autograd history and device
    # memory, and matplotlib may be unable to plot GPU / grad-requiring tensors.
    train_loss = []
    test_loss = []

    for epoch in range(n_epochs):
        for real, labels in tqdm(dataloader):
            real = real.to(device)
            labels = labels.to(device)

            ### Update U-Net ###
            unet_opt.zero_grad()
            pred = unet(real)
            unet_loss = criterion(pred, labels)
            unet_loss.backward()
            unet_opt.step()
            train_loss.append(unet_loss.item())

            ### Evaluate on one test batch ###
            # A fresh iterator with shuffle=True yields a random test batch
            # each step. eval() puts any dropout/batch-norm layers into
            # inference mode; train() restores training mode afterwards.
            unet.eval()
            with torch.no_grad():
                real_test, labels_test = next(iter(testloader))
                real_test = real_test.to(device)
                labels_test = labels_test.to(device)
                pred_test = unet(real_test)
                test_unet_loss = criterion(pred_test, labels_test)
            unet.train()
            # Kept separate from `unet_loss` so the printed U-Net (training)
            # loss is not silently replaced by the test loss.
            test_loss.append(test_unet_loss.item())

            if cur_step % display_step == 0:
                print(
                    f"Epoch {epoch}: Step {cur_step}: U-Net loss: {unet_loss.item()}"
                    f" | Test loss: {test_unet_loss.item()}"
                )
                show_tensor_images(
                    crop(real, torch.Size([len(real), 1, target_shape, target_shape])),
                    size=(input_dim, target_shape, target_shape),
                )
                show_tensor_images(labels, size=(label_dim, target_shape, target_shape))
                show_tensor_images(torch.sigmoid(pred), size=(label_dim, target_shape, target_shape))
                plt.figure()
                plt.plot(train_loss, label='train loss')
                plt.plot(test_loss, label='test loss')
                plt.legend()
                plt.show()
            cur_step += 1


train()