UNET Notebook: train/test split

Hi, I have added a test set to the UNET notebook. I was impressed by how good the segmentation was, so I wanted to make sure it was not overfitting.

I don't know how to attach the Jupyter notebook here, sorry, so I have pasted the code below. The training is indeed overfitting (as the plots I added show), and it was nice to see the moment when it started to overfit.
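If you want to pin down that moment numerically instead of reading it off the plot, one option is a patience rule like the one used for early stopping: take the evaluation with the lowest test loss, and treat it as the onset of overfitting once the test loss has failed to improve for a few consecutive evaluations. Below is a minimal sketch in plain Python, assuming the test losses were recorded as a list of floats; `overfit_onset` and `patience` are hypothetical names, not part of the notebook.

```python
def overfit_onset(test_losses, patience=3):
    """Return the index of the lowest test loss once it has not been beaten
    for `patience` consecutive evaluations, or None if that never happens."""
    best_idx = 0
    for i, loss in enumerate(test_losses):
        if loss < test_losses[best_idx]:
            best_idx = i  # new best: reset the patience window
        elif i - best_idx >= patience:
            return best_idx  # no improvement for `patience` evaluations
    return None


# Made-up losses, for illustration only
losses = [0.69, 0.52, 0.41, 0.38, 0.39, 0.40, 0.43, 0.45]
print(overfit_onset(losses, patience=3))  # → 3
```

The per-step test losses from a single random batch are noisy, so a small patience may fire too early; a larger patience or a smoothed curve is probably safer.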

I appreciate all the work the professors put into preparing the notebooks and the rest of the course material. They are all great! Thanks a lot!

I hope someone finds the code and the discussion helpful.

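import numpy as np
import torch
import matplotlib.pyplot as plt
from skimage import io
from torch.utils.data import DataLoader, Subset
# crop, UNet, show_tensor_images, tqdm, criterion and the hyperparameters
# (target_shape, batch_size, lr, n_epochs, display_step, ...) come from the original U-Net notebook

# Load all image slices and their labels, scaled to [0, 1]
volumes = torch.Tensor(io.imread('train-volume.tif'))[:, None, :, :] / 255
labels = torch.Tensor(io.imread('train-labels.tif', plugin="tifffile"))[:, None, :, :] / 255
labels = crop(labels, torch.Size([len(labels), 1, target_shape, target_shape]))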
total_dataset = torch.utils.data.TensorDataset(volumes, labels)

# Shuffle the slice indices and hold out 20% of them as a test set
indices = np.random.permutation(len(total_dataset))
train_fraction = 0.8
cut = int(len(total_dataset) * train_fraction)
dataset = Subset(total_dataset, indices[:cut])
test_dataset = Subset(total_dataset, indices[cut:])
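Because `np.random.permutation` is called without a seed, every run gets a different train/test split, which makes runs hard to compare. A minimal sketch of a reproducible split, assuming a fixed seed is acceptable: `split_indices` is a hypothetical helper that shuffles the indices with its own seeded `random.Random`, so it does not depend on the global random state.

```python
import random


def split_indices(n, train_fraction=0.8, seed=0):
    """Shuffle range(n) with a fixed seed and cut it into train/test index lists."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # same seed -> same order every run
    cut = int(n * train_fraction)
    return indices[:cut], indices[cut:]


train_idx, test_idx = split_indices(10)
print(len(train_idx), len(test_idx))  # → 8 2
```

The two index lists can be passed to `Subset` in place of `indices[:cut]` and `indices[cut:]`. PyTorch also provides `torch.utils.data.random_split`, which accepts integer lengths and a seeded `generator=torch.Generator().manual_seed(0)` for the same purpose.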

def train():
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Shuffled so that next(iter(testloader)) draws a random test batch at each step
    testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    unet = UNet(input_dim, label_dim).to(device)
    unet_opt = torch.optim.Adam(unet.parameters(), lr=lr)
    cur_step = 0

    train_loss = []
    test_loss = []
    for epoch in range(n_epochs):
        for real, labels in tqdm(dataloader):
            cur_batch_size = len(real)
            real = real.to(device)
            labels = labels.to(device)

            ### Update U-Net ###
            unet_opt.zero_grad()
            pred = unet(real)
            unet_loss = criterion(pred, labels)
            unet_loss.backward()
            unet_opt.step()
            train_loss.append(unet_loss.item())

            ### Evaluate on a random test batch (no gradients, no weight update) ###
            with torch.no_grad():
                real_test, labels_test = next(iter(testloader))
                real_test = real_test.to(device)
                labels_test = labels_test.to(device)
                pred_test = unet(real_test)
                # Separate name, so the training loss is not overwritten
                test_unet_loss = criterion(pred_test, labels_test)
            test_loss.append(test_unet_loss.item())

            if cur_step % display_step == 0:
                print(f"Epoch {epoch}: Step {cur_step}: train loss: {unet_loss.item()}, test loss: {test_unet_loss.item()}")
                show_tensor_images(
                    crop(real, torch.Size([len(real), 1, target_shape, target_shape])),
                    size=(input_dim, target_shape, target_shape)
                )
                show_tensor_images(labels, size=(label_dim, target_shape, target_shape))
                show_tensor_images(torch.sigmoid(pred), size=(label_dim, target_shape, target_shape))
                # Train vs. test loss per step; the gap opens once overfitting starts
                plt.plot(train_loss, label='train loss')
                plt.plot(test_loss, label='test loss')
                plt.xlabel('step')
                plt.ylabel('loss')
                plt.legend()
                plt.show()
            cur_step += 1
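The loop estimates the test loss from a single random test batch per step, so the test curve is noisy. A less noisy alternative, sketched here under the assumption that `criterion` returns the mean loss of a batch (the PyTorch default `reduction='mean'`), is to average over every batch of the test loader, weighting each batch by its size. `mean_loss_over_loader` is a hypothetical helper written against plain callables, so it can be checked without PyTorch.

```python
def mean_loss_over_loader(model, loader, criterion):
    """Average a per-batch mean loss over all batches, weighted by batch size.

    Assumes criterion returns the mean loss over a batch and every sample has
    the same number of elements, so the result is the mean over the whole set.
    """
    total, count = 0.0, 0
    for inputs, targets in loader:
        batch_loss = float(criterion(model(inputs), targets))  # works for 0-dim tensors too
        total += batch_loss * len(inputs)
        count += len(inputs)
    return total / count


# Toy check: model doubles its input, criterion is mean squared error
model = lambda xs: [2 * x for x in xs]
mse = lambda preds, ys: sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(ys)
loader = [([1, 2], [2, 4]), ([3], [5])]
print(mean_loss_over_loader(model, loader, mse))  # → 0.3333333333333333
```

In the notebook it could be called once per epoch, for example inside `with torch.no_grad():` as `mean_loss_over_loader(lambda x: unet(x.to(device)), testloader, lambda p, y: criterion(p, y.to(device)))`. Running it every step would be noticeably slower than the single-batch estimate.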


Hi Erick!
Thanks for your contribution! Have you thought about adding this to a GitHub project? It is a really great way to share work with the open source community.