In this tutorial, they show how to add a callback that accesses and logs one batch of validation predictions.
However, I am struggling to understand how to access predictions for the entire validation dataset. At the end of each epoch, I want to create a confusion matrix that includes all validation examples. The only workaround I can think of is to:
- define a `self.val_outputs` list in `__init__()` of the LightningModule
- override the `on_validation_batch_end()` hook and, each time it is called, append the latest batch predictions to `self.val_outputs`
- override the `on_validation_epoch_end()` hook to compute the confusion matrix from `self.val_outputs`, and at the end reset it with `self.val_outputs = []`
This seems like an ugly workaround, though, and I suspect there is a cleaner way to do this. Any advice?