Compute Precision Recall Curve without OOM

Hello.
I’m interested in training on the GPU and computing metrics in the validation loop.
I need to use heavyweight metrics such as PrecisionRecallCurve and ROC over 512x512 images in a multilabel segmentation problem, but I easily run out of memory (OOM).

The metrics compute correctly when I disable the available GPU, but then 80% of the time is spent in the model's forward() function.

So the metrics work perfectly on the CPU: the .update() and .compute() steps both run properly.

The problem appears when I enable the GPU, initialize the metric on the CPU, run inference on the GPU, and calculate the metric on the CPU.

I initialized the metric like this:

PrecisionRecallCurve(pos_label=1, num_classes=1).to('cpu')

and in step_end:

precision_recall_curve.update(image.to('cpu'),
                              mask.to('cpu'))

So far so good, it seems that the metric is updating correctly.

The problem is:

precision, recall, threshold = precision_recall_curve.compute()

raises the following exception:

work = _default_pg.allgather([tensor_list], [tensor])
RuntimeError: Tensors must be CUDA and dense

I guess that this kind of hybrid setup is not supported.

I guess the only solution I have is to compute the metrics per sample or per batch and then perform some kind of aggregation, since the whole validation set (which is not too big) does not fit in VRAM.
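
To make the per-batch aggregation idea concrete, here is a minimal sketch in plain PyTorch. It assumes that a fixed grid of thresholds is acceptable, so it gives a binned approximation rather than the exact curve that PrecisionRecallCurve returns. BinnedPRAccumulator and num_thresholds are hypothetical names of my own, not part of any library. Only three count vectors are kept between batches, so memory does not grow with the size of the validation set.

```python
import torch


class BinnedPRAccumulator:
    """Hypothetical helper: keeps per-threshold TP/FP/FN counts on the CPU."""

    def __init__(self, num_thresholds: int = 101):
        # Fixed threshold grid in [0, 1]; the default size is an assumption.
        self.thresholds = torch.linspace(0.0, 1.0, num_thresholds)
        self.tp = torch.zeros(num_thresholds, dtype=torch.long)
        self.fp = torch.zeros(num_thresholds, dtype=torch.long)
        self.fn = torch.zeros(num_thresholds, dtype=torch.long)

    @torch.no_grad()
    def update(self, probs: torch.Tensor, target: torch.Tensor) -> None:
        # Move the batch to the CPU and flatten it; only counts survive this call.
        probs = probs.detach().cpu().flatten().float()
        target = target.detach().cpu().flatten().bool()
        # Loop over thresholds to avoid building a (thresholds x pixels) matrix.
        for i, t in enumerate(self.thresholds):
            pred = probs >= t
            self.tp[i] += (pred & target).sum()
            self.fp[i] += (pred & ~target).sum()
            self.fn[i] += (~pred & target).sum()

    def compute(self):
        tp, fp, fn = self.tp.double(), self.fp.double(), self.fn.double()
        # Convention used here: precision is 1 when nothing is predicted positive,
        # recall is 0 when there are no positive targets.
        precision = torch.where(tp + fp > 0, tp / (tp + fp).clamp(min=1), torch.ones_like(tp))
        recall = torch.where(tp + fn > 0, tp / (tp + fn).clamp(min=1), torch.zeros_like(tp))
        return precision, recall, self.thresholds


# Usage: two small "batches" accumulated one after the other.
acc = BinnedPRAccumulator(num_thresholds=3)
acc.update(torch.tensor([0.1, 0.4]), torch.tensor([0, 0]))
acc.update(torch.tensor([0.35, 0.8]), torch.tensor([1, 1]))
precision, recall, thresholds = acc.compute()
```

Since this accumulator is not a torchmetrics Metric, nothing is gathered across processes in compute(). In a multi-process run each process would hold only its own counts, so the tp/fp/fn vectors would still need to be summed across processes by hand.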

I would like to know if anyone has faced a similar problem and managed to solve it.

Thanks