Hi, I am writing a PyTorch Lightning module with a prediction step to generate data.
The `predict_step` looks as follows:
```python
def predict_step(self, batch, batch_idx, dataloader_idx: int = 0):
    labels, feats, mask = batch['labels'], batch['feats'], batch['mask']
    # transformations over feats, labels, mask to generate tensors
    return simulation_out
```
The `on_predict_batch_end` hook looks like this:
```python
def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
    simulation_out = outputs
    outs = transform(simulation_out)
    # save outs to disk
```
With the output returned from `predict_step`, I do some transformations that involve converting torch tensors to numpy arrays and operating on those arrays. I save the simulation output to disk and delete the intermediate variables.
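To make the transform-and-save step concrete, here is a minimal sketch of what I am doing. The function name, the `out_dir` path, and the `arr * 2.0` transformation are placeholders for my actual code; the real point is that the tensor is detached and moved to CPU before the numpy conversion, and intermediates are deleted afterwards:

```python
import os

import numpy as np
import torch


def transform_and_save(simulation_out: torch.Tensor, batch_idx: int,
                       out_dir: str = "outputs") -> np.ndarray:
    """Hypothetical sketch of the per-batch transform/save step."""
    os.makedirs(out_dir, exist_ok=True)
    # Detach from the autograd graph and move to CPU before .numpy(),
    # so no reference into the CUDA allocator is kept alive.
    arr = simulation_out.detach().cpu().numpy()
    outs = arr * 2.0  # stand-in for the real numpy transformations
    np.save(os.path.join(out_dir, f"batch_{batch_idx}.npy"), outs)
    del arr  # drop the intermediate reference before returning
    return outs
```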
However, I am getting an OOM error when I run this on large datasets, and I am not sure what to look for. I am already deleting intermediate arrays in `on_predict_batch_end`. What else should I check when debugging?