Training on 2700 medical images (3D) and validating on 600. I'm using four V100 GPUs on a Linux VM with 32 cores and 128 GB of memory, with batch_size 8 and 50 epochs.
Originally, I was returning the batch tensors in the dict from training_step like the docs suggest, but I noticed that midway through my first epoch training would just stop and hang, and the iteration time went from 2.5 s to 15 s, then 40 s, and so on.
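For context, a common cause of exactly this kind of progressive slowdown (not necessarily what's happening here) is that tensors returned from training_step are still attached to the autograd graph, so every stored prediction keeps its entire per-batch computation graph and activations alive in GPU memory. A minimal sketch of the difference, using plain torch rather than a full LightningModule:

```python
import torch

# Simulate a prediction tensor as it comes out of a model inside
# training_step: it is still attached to the autograd graph.
def make_batch_preds():
    x = torch.randn(8, 16, requires_grad=True)
    w = torch.randn(16, 1, requires_grad=True)
    return x @ w  # result carries the computation graph

attached = make_batch_preds()
print(attached.requires_grad)   # True: storing this keeps the graph alive

# Safer pattern: detach from the graph and move to CPU before storing,
# so the per-batch graph can be freed right after backward().
stored = attached.detach().cpu()
print(stored.requires_grad)     # False: just a plain copy of the values
```

If the preds/labels were stored without `.detach().cpu()`, GPU memory would grow every batch, which matches the steadily increasing iteration times.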
I thought this was because the train_outputs dict was getting huge and slowing things down, so I'm wondering if there is a better way to store the preds/labels. I also tried the Metrics API, calculating the epoch AUC through it, but the same thing happened: training just stops and hangs.
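One lightweight alternative worth trying: accumulate small, detached CPU copies of the preds/labels yourself and compute AUC once per epoch (e.g. in an epoch-end hook), instead of routing everything through the returned dict. A sketch under those assumptions, using sklearn's `roc_auc_score` for the epoch-level metric; the batch loop here is a stand-in for your real training loop:

```python
import torch
from sklearn.metrics import roc_auc_score

preds_buf, labels_buf = [], []

# Per batch (e.g. inside training_step): store detached CPU copies only.
for _ in range(5):                                  # stand-in for the batch loop
    logits = torch.randn(8, requires_grad=True)     # fake model output
    labels = torch.randint(0, 2, (8,))              # fake binary labels
    preds_buf.append(torch.sigmoid(logits).detach().cpu())
    labels_buf.append(labels.cpu())

# Once per epoch (e.g. in an epoch-end hook): one cheap AUC computation,
# then clear the buffers so memory stays flat across epochs.
y_score = torch.cat(preds_buf).numpy()
y_true = torch.cat(labels_buf).numpy()
auc = roc_auc_score(y_true, y_score)
preds_buf.clear()
labels_buf.clear()
print(f"epoch AUC: {auc:.3f}")
```

Clearing the buffers at epoch end matters: without it, the lists grow across epochs just like the outputs dict did. Note that with 4 GPUs under DDP, each rank would only see its own shard of predictions, so a gather across ranks (or a DDP-aware metric) would still be needed for an exact global AUC.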
When I remove the collection of preds/labels and just calculate the epoch loss, everything runs smoothly.