Using AUROC as a validation metric

I am using the loss for optimization, but I want to monitor AUROC from torchmetrics for early stopping. The code I have written is as follows:

    def validation_step(self, batch, batch_idx):
        x = batch['src']
        y = batch['label']
        mask = batch['mask']

        # Encode the sequence, project each position to a single logit,
        # then average over the sequence dimension to get one logit per sample.
        x = self.base_model(x, mask)
        x = self.linear(x).mean(dim=1).squeeze(1)

        loss = F.binary_cross_entropy_with_logits(input=x, target=y)
        return {'loss': loss, 'preds': x, 'target': y}

    def validation_step_end(self, outputs):
        # Update the metric's running state; passing the Metric object
        # to self.log lets Lightning compute it once at epoch end.
        self.valid_auc(torch.sigmoid(outputs['preds']), outputs['target'].int())
        self.log('valid_auc', self.valid_auc)
        self.log('valid_loss', outputs['loss'])
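
In case it matters, here is roughly how self.valid_auc is set up (a sketch; Classifier, base_model, and hidden_dim are placeholder names, and older torchmetrics versions take no task argument for the binary case):

    import torch
    import torch.nn.functional as F
    import torchmetrics
    import pytorch_lightning as pl

    class Classifier(pl.LightningModule):
        def __init__(self, base_model, hidden_dim):
            super().__init__()
            self.base_model = base_model
            self.linear = torch.nn.Linear(hidden_dim, 1)
            # A Metric object keeps running state across batches, so a
            # single compute() at epoch end sees every validation sample.
            self.valid_auc = torchmetrics.AUROC(task='binary')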

The early stopping callback looks like:

    early_stopping_cb = EarlyStopping(
        monitor='valid_auc',       # the logged metric name
        min_delta=args.min_delta,
        patience=args.patience,
        mode='max',                # higher AUROC is better
        strict=True)               # error out if 'valid_auc' is missing
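
For completeness, the callback is attached to the trainer in the usual way (a sketch; the other Trainer arguments are omitted):

    # Early stopping then checks 'valid_auc' after every validation epoch.
    trainer = pl.Trainer(callbacks=[early_stopping_cb])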

My question is whether AUROC is being aggregated correctly across mini-batches. Is there anything else I need to do? And is there a good way to verify that AUROC is computed over the entire validation set, as opposed to being an average of the per-batch AUROC values?

hey @EvanZ

If you set self.log(..., on_epoch=True), which is the default behavior during validation, Lightning logs a weighted average across batches instead of a plain mean of the batch-level values, where the weights come from the batch_size of each batch. In your case there is an extra detail: because you log the torchmetrics Metric object itself (self.valid_auc), the metric accumulates its internal state across batches and compute() runs once at the end of the epoch, so valid_auc is calculated over the entire validation set rather than averaged per batch.
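
If you want to verify this yourself, you can accumulate the raw predictions and targets and compare a one-shot AUROC over the full validation set against the logged value. Here is a rough sketch using the validation_epoch_end hook from the same Lightning version as your snippet (validation_step_end would need to return its outputs dict so they reach this hook, and older torchmetrics versions omit the task argument):

    def validation_epoch_end(self, outputs):
        # `outputs` is the list of dicts collected over the epoch.
        preds = torch.cat([o['preds'] for o in outputs])
        target = torch.cat([o['target'] for o in outputs]).int()
        # AUROC computed once over the whole validation set; this should
        # agree with the epoch-level 'valid_auc' in your logs.
        manual_auc = torchmetrics.functional.auroc(
            torch.sigmoid(preds), target, task='binary')
        self.log('valid_auc_manual', manual_auc)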

Also, we have moved discussions over to GitHub Discussions. You might want to post there instead to get a quicker response, as the forums will be marked read-only after some time.

Thank you