Stable training loss but oscillating validation curves training ViT

I am training a modified vision transformer on imaging data with ~1000 channels.

I am limited in the number of samples: I only have 10 images (~200x200x1000) available, which I have converted into patches, yielding around 15k patches, each with an associated label, in a balanced dataset. I have also performed PCA on the channels to reduce their dimensionality. The current split is 6 training, 2 validation, and 2 test images.
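For reference, the patch extraction and channel PCA described above can be sketched roughly like this (a minimal sketch, not my actual pipeline; the array shapes here are toy stand-ins and the function name is made up):

```python
import numpy as np
from sklearn.decomposition import PCA

def extract_patches(img, patch=8, overlap=0.5):
    """Slide a patch x patch window over H and W with fractional overlap."""
    stride = max(1, int(patch * (1 - overlap)))
    h, w, _ = img.shape
    patches = []
    for y in range(0, h - patch + 1, stride):
        for x in range(0, w - patch + 1, stride):
            patches.append(img[y:y + patch, x:x + patch, :])
    return np.stack(patches)

# Toy stand-in for one image (real ones are ~200x200x1000).
img = np.random.rand(40, 40, 100).astype(np.float32)

# PCA on the channel axis: treat every pixel as a sample of C channels,
# then reshape back to H x W x n_components.
pixels = img.reshape(-1, img.shape[-1])
pca = PCA(n_components=25)
reduced = pca.fit_transform(pixels).reshape(img.shape[0], img.shape[1], 25)

patches = extract_patches(reduced, patch=8, overlap=0.5)
print(patches.shape)  # (n_patches, 8, 8, 25)
```

With 50% overlap the stride is half the patch size, which is why the patch count grows roughly 4x compared to non-overlapping tiling.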

These are the training and validation curves for my best results so far. The patches for these runs were sized 8x8x25 with 50% overlap:
(Plots: training loss, validation accuracy, validation balanced accuracy, validation F1, validation loss)

My problem is understanding how to move forward with these results. The validation metrics suggest the model is training and learning, but they fluctuate a lot and I am not sure how to mitigate that.

What I have tried:

  • Different patch sizes (4x4, 8x8, 16x16, etc.)
  • Different channel counts after PCA (8, 16, 32, etc.)
  • Different overlaps when generating patches (20%, 50%, etc.)
  • Lowering learning rate
  • Lowering weight decay
  • Balancing dataset
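For the dataset-balancing attempt, one common approach is inverse-frequency sampling weights so each mini-batch is roughly class-balanced in expectation (a minimal sketch; the label counts below are placeholders, not my real class distribution):

```python
import numpy as np

# Hypothetical labels for ~15k patches with an imbalanced class distribution.
labels = np.array([0] * 9000 + [1] * 4000 + [2] * 2000)

# Inverse-frequency weights: rarer classes get sampled proportionally more,
# so every class contributes equal total probability mass.
counts = np.bincount(labels)
weights = 1.0 / counts[labels]
weights /= weights.sum()

# Draw one "balanced" mini-batch of 256 patch indices.
rng = np.random.default_rng(0)
batch = rng.choice(len(labels), size=256, replace=True, p=weights)
print(np.bincount(labels[batch], minlength=3))
```

In PyTorch the same idea is usually handed to a `WeightedRandomSampler` inside the `DataLoader` rather than sampled by hand.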

These are the ways I tried to mitigate the fluctuation and improve the overall accuracy. Instead, they resulted in poorer performance, such as plateauing early in training and even more extreme fluctuation.
