I have been working with TensorFlow for my CNN project until now. While switching to PyTorch, I can't figure out a few things about the input pipeline. The training dataset consists of around 500 thousand images. In TensorFlow, tf.data.Dataset can be used to create an iterable that is fed to the model. We do not need to load every image into the iterable object in this process, which keeps RAM usage from exploding. In PyTorch, I can't figure out how to do the same. Using Dataset from torch.utils.data loads each image and its corresponding label into the iterable, which takes a heavy toll on RAM. What am I missing?
If loading individual images inside `__getitem__()` of `Dataset` exhausts RAM, odds are good that the `batch_size` of the `DataLoader` is set too high, since a whole batch (plus any batches prefetched by worker processes) is held in memory at once.
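A minimal sketch of the lazy-loading pattern, assuming each image is stored as its own file. The names `LazyImageDataset` and `make_fake_files` are hypothetical, and small tensors saved with `torch.save` stand in for real image files (with JPEG/PNG files you would typically open each one with PIL inside `__getitem__` instead). Only file paths and labels live in memory; `DataLoader` calls `__getitem__` once per index to assemble each batch.

```python
import os
import tempfile

import torch
from torch.utils.data import DataLoader, Dataset


class LazyImageDataset(Dataset):
    """Map-style dataset that keeps only file paths and labels in memory.

    Each sample is read from disk inside __getitem__, so RAM use is
    roughly proportional to the batch size, not to the dataset size.
    """

    def __init__(self, paths, labels):
        # Only lightweight metadata is stored here; no image is loaded yet.
        self.paths = list(paths)
        self.labels = list(labels)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Assumed on-disk format: one tensor per file, written with torch.save.
        image = torch.load(self.paths[idx])
        return image, self.labels[idx]


def make_fake_files(root, n=10, shape=(3, 32, 32)):
    """Hypothetical helper: write n random 'images' to disk as stand-ins."""
    paths, labels = [], []
    for i in range(n):
        path = os.path.join(root, f"img_{i}.pt")
        torch.save(torch.rand(shape), path)
        paths.append(path)
        labels.append(i % 2)
    return paths, labels


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        paths, labels = make_fake_files(root)
        loader = DataLoader(LazyImageDataset(paths, labels), batch_size=4, shuffle=True)
        for images, targets in loader:
            # Only the current batch (at most 4 images) is materialized here.
            print(images.shape, targets.shape)
```

With this layout, memory use scales with `batch_size` (and, when `num_workers > 0`, with the number of prefetched batches) rather than with the total number of images.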
Thanks @balaji.ambresh. Actually, the mistake was mine. Instead of applying the per-image transformations inside the `__getitem__()` method, I applied them inside the `__init__()` method. Now that I've corrected this, the RAM issue is gone.
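To illustrate the difference between the two placements, here is a sketch with a hypothetical `TransformedDataset` whose transform runs only when a sample is requested. `fake_load` and `CountingNormalize` are made-up stand-ins for an image loader and a torchvision-style transform; the call counter shows that constructing a dataset of 500,000 items transforms nothing up front.

```python
import torch
from torch.utils.data import Dataset


class TransformedDataset(Dataset):
    """Applies a per-sample transform on access instead of in __init__."""

    def __init__(self, load_fn, num_items, transform=None):
        # The eager anti-pattern this avoids (every transformed image kept in RAM):
        #     self.images = [transform(load_fn(i)) for i in range(num_items)]
        self.load_fn = load_fn      # hypothetical: maps an index to a raw image
        self.num_items = num_items
        self.transform = transform

    def __len__(self):
        return self.num_items

    def __getitem__(self, idx):
        image = self.load_fn(idx)           # load only this sample
        if self.transform is not None:
            image = self.transform(image)   # transform only this sample
        return image


class CountingNormalize:
    """Toy transform: maps [0, 1) to [-1, 1) and counts how often it runs."""

    def __init__(self):
        self.calls = 0

    def __call__(self, x):
        self.calls += 1
        return x * 2.0 - 1.0


def fake_load(idx):
    # Stand-in for reading an image file; deterministic per index.
    g = torch.Generator().manual_seed(idx)
    return torch.rand(3, 32, 32, generator=g)


if __name__ == "__main__":
    norm = CountingNormalize()
    ds = TransformedDataset(fake_load, num_items=500_000, transform=norm)
    print(norm.calls)           # 0: nothing is transformed at construction time
    x = ds[42]
    print(norm.calls, x.shape)  # 1 torch.Size([3, 32, 32])
```

Because the work happens per index, a `DataLoader` wrapped around this dataset only ever transforms the samples of the batches it is currently producing.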
Thanks for confirming. Cheers.