OOM error on batch prediction

Hi, I am writing a PyTorch Lightning module with a prediction step that generates data.

The ‘predict_step’ method is structured as follows:

def predict_step(self, batch, batch_idx, dataloader_idx: int = 0):
    labels, feats, mask = batch['labels'], batch['feats'], batch['mask']
    # transformations over feats, labels, mask to generate tensors (details omitted)
    simulation_out = ...  # placeholder for the generated tensors
    return simulation_out

The ‘on_predict_batch_end’ hook looks like this:

def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
    simulation_out = outputs
    outs = transform(simulation_out)  # tensor -> NumPy transformations (details omitted)
    save(outs)  # write outs to disk (saving code omitted)

I take the output of ‘predict_step’ and apply some transformations to it, which involve converting some torch tensors to NumPy arrays and operating on those arrays. I then save the simulation output to disk and delete the intermediate variables.
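A minimal sketch of the per-batch step described above, assuming the prediction is a single tensor. The helper ‘save_batch_outputs’, the batch_NNNNNN.npy file naming, and the "* 2.0" stand-in for ‘transform’ are all hypothetical. It copies the tensor to host memory with .detach().cpu() before calling .numpy() (which only works on CPU tensors) and returns only the file path, so nothing but the file on disk outlives the call.

```python
import os
import tempfile

import numpy as np
import torch


def save_batch_outputs(simulation_out: torch.Tensor, out_dir: str, batch_idx: int) -> str:
    """Hypothetical helper: move one batch of predictions off the device,
    transform it in NumPy, write it to disk, and keep no references to the arrays."""
    # Detach and copy to host memory first; .numpy() fails on CUDA tensors.
    arr = simulation_out.detach().cpu().numpy()
    # Placeholder transformation standing in for the real ``transform``.
    outs = arr * 2.0
    path = os.path.join(out_dir, f"batch_{batch_idx:06d}.npy")
    np.save(path, outs)
    # Only the path is returned; arr and outs go out of scope here.
    return path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        p = save_batch_outputs(torch.ones(4, 3), d, batch_idx=0)
        print(os.path.basename(p), np.load(p).shape)  # prints: batch_000000.npy (4, 3)
```

Inside the hook, the call would be something like save_batch_outputs(outputs, out_dir, batch_idx), with out_dir being whatever output directory the module already uses.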

However, I get an OOM error when I run this on large datasets. I am already deleting the intermediate arrays in the ‘on_predict_batch_end’ hook, so I am not sure what else to look for. What else should I check when debugging this?
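One way to narrow this down is to log memory after every batch and check whether it climbs steadily, which would suggest something is still holding references across batches. Below is a minimal sketch using the standard-library tracemalloc module; ‘MemoryGrowthLogger’ is a hypothetical helper. It rests on an assumption worth stating: tracemalloc sees Python objects and NumPy buffers (NumPy reports its allocations to it), but not torch tensor storage or GPU memory, so for a CUDA OOM torch.cuda.memory_allocated() would be the reading to log instead.

```python
import tracemalloc

import numpy as np


class MemoryGrowthLogger:
    """Hypothetical helper: record traced Python/NumPy memory after each batch
    so that steady growth across batches stands out."""

    def __init__(self) -> None:
        self.readings: list[int] = []

    def start(self) -> None:
        tracemalloc.start()

    def record(self) -> int:
        # get_traced_memory() returns (current, peak) in bytes.
        current, _peak = tracemalloc.get_traced_memory()
        self.readings.append(current)
        return current

    def grew_by(self) -> int:
        # Difference between the last and first reading, in bytes.
        return self.readings[-1] - self.readings[0] if self.readings else 0

    def stop(self) -> None:
        tracemalloc.stop()


if __name__ == "__main__":
    logger = MemoryGrowthLogger()
    logger.start()
    kept = []
    for _ in range(5):
        kept.append(np.ones(1_000_000))  # deliberately retained: simulates a leak
        logger.record()
    print(f"grew by ~{logger.grew_by() / 1e6:.0f} MB over 5 batches")
    logger.stop()
```

In the module, start() could be called from ‘on_predict_start’, record() at the end of ‘on_predict_batch_end’, and grew_by() inspected after prediction finishes; a reading that stays flat points away from the Python/NumPy side.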