Hi, I am writing a PyTorch Lightning module with a prediction step to generate data.
The prediction step's structure looks as follows:
```python
def predict_step(self, batch, batch_idx, dataloader_idx: int = 0):
    labels, feats, mask = batch['labels'], batch['feats'], batch['mask']
    # transformations over feats, labels, mask to generate tensors
    return simulation_out
```
The `on_predict_batch_end` hook looks like this:
```python
def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
    simulation_out = outputs
    outs = transform(simulation_out)
    # save outs to disk
```
With the output generated from `predict_step`, I apply some transformations that involve converting torch tensors to NumPy arrays and operating on them. I save the simulation output to disk and delete the intermediate variables.
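Concretely, the transform-and-save step looks roughly like this sketch (the post-processing shown is a placeholder; the file layout and function names are just for illustration):

```python
import os

import numpy as np
import torch


def transform(simulation_out: torch.Tensor) -> np.ndarray:
    # Detach from the autograd graph and move to CPU before converting
    # to NumPy; .numpy() only works on CPU tensors, and detaching
    # releases the reference to the computation graph.
    arr = simulation_out.detach().cpu().numpy()
    # Placeholder post-processing: scale and clip.
    return np.clip(arr * 2.0, 0.0, 1.0)


def save_outputs(outs: np.ndarray, batch_idx: int, out_dir: str = "preds") -> str:
    # Write each batch to its own file instead of accumulating in memory.
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"batch_{batch_idx}.npy")
    np.save(path, outs)
    return path
```

The per-batch files can then be concatenated offline once prediction finishes.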
However, I am getting an OOM error when I run this on large datasets, and I am not sure what to look for. I am already deleting intermediate arrays in the `on_predict_batch_end` step. What else should I check when debugging?