Logging only on epochs not working as intended?

Hi, I’m trying to log metrics only once per epoch, but it doesn’t seem to work as intended.

Here is my code:

class StackedLSTM(pl.LightningModule):
    """Two stacked LSTMs with dropout and a linear head, trained with BCE-with-logits.

    ``self.lstm1``, ``self.lstm2``, ``self.dropout``, ``self.fc`` and
    ``self.hparams.learning_rate`` are defined in the elided part of the class.
    """

    # NOTE(review): constructor / layer definitions elided in the original post.
    ...

    def forward(self, x):
        """Return the head's output at the last index of dimension 1.

        Assumes ``x`` is batch-first ``(batch, seq, features)``, so ``[:, -1]``
        picks the final time step. TODO confirm against the elided LSTM config.
        """
        out, _ = self.lstm1(x)
        out, _ = self.lstm2(self.dropout(out))
        logits = self.fc(self.dropout(out))[:, -1]
        return logits

    # custom logging function that I use for train/valid/test
    def log_metrics(self, loss, y, probas, stage):
        """Compute batch classification metrics and log them with epoch reduction.

        Args:
            loss: Loss tensor for the batch.
            y: Ground-truth labels (NumPy array).
            probas: Predicted positive-class probabilities (NumPy array);
                thresholded at 0.5 for accuracy and F1.
            stage: Prefix for the logged keys, e.g. 'train', 'valid', 'test'.

        Returns:
            ``(acc, ap, f1, auroc)`` when ``stage == 'test'``, otherwise ``None``.
            ``auroc`` is ``nan`` for a batch that contains a single class.

        Note:
            With ``on_step=False, on_epoch=True`` Lightning reduces the values
            logged at each step (mean by default). The epoch AP/F1/AUROC are
            therefore averages of per-batch scores, not scores computed over all
            predictions of the epoch. The epoch values are still plotted against
            the global step, which is why the TensorBoard x-axis shows steps.
        """
        preds = probas > 0.5
        acc = accuracy_score(y, preds)
        ap = average_precision_score(y, probas, average='weighted', pos_label=1)
        f1 = f1_score(y, preds, average='weighted', pos_label=1)

        metrics = {
            f'{stage}_loss': loss,
            f'{stage}_acc': torch.tensor(acc),
            f'{stage}_ap': torch.tensor(ap),
            f'{stage}_f1': torch.tensor(f1),
        }

        # ROC AUC is undefined when the batch holds only one class, and sklearn
        # raises ValueError for it. Crashing the whole run on one unlucky batch is
        # worse than skipping the key for that batch.
        try:
            auroc = roc_auc_score(y, probas, average='weighted')
        except ValueError:
            auroc = float('nan')
        else:
            metrics[f'{stage}_auroc'] = torch.tensor(auroc)

        self.log_dict(metrics, on_step=False, on_epoch=True, logger=True)

        if stage == 'test':
            return acc, ap, f1, auroc
        return None

    def _shared_step(self, batch):
        """Run the model on ``batch``; return (loss, labels as NumPy, probabilities as NumPy)."""
        x, y = batch
        logits = self(x)
        loss = nn.BCEWithLogitsLoss()(logits, y)
        # Detached CPU copies feed the sklearn metrics; ``loss`` keeps its graph.
        probas = torch.sigmoid(logits).detach().cpu().numpy()
        y_np = y.detach().cpu().numpy()
        return loss, y_np, probas

    def training_step(self, batch, batch_idx):
        """Log training loss/metrics and return the loss for backpropagation."""
        loss, y_np, probas = self._shared_step(batch)
        self.log_metrics(loss, y_np, probas, stage='train')
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/metrics and return the loss."""
        loss, y_np, probas = self._shared_step(batch)
        self.log_metrics(loss, y_np, probas, stage='valid')
        return loss

    def test_step(self, batch, batch_idx):
        """Log test metrics; return them together with the batch confusion matrix."""
        loss, y_np, probas = self._shared_step(batch)
        acc, ap, f1, auroc = self.log_metrics(loss, y_np, probas, stage='test')
        cm = confusion_matrix(y_np, probas > 0.5, normalize=None)
        return {'acc': acc, 'ap': ap, 'f1': f1, 'auroc': auroc}, cm

    def test_step_end(self, outputs):
        """Store the latest test-step outputs on ``self.results``.

        NOTE(review): the attribute is overwritten on every call. After testing,
        it holds only the most recent batch's metrics and confusion matrix.
        Accumulate them instead if results for the whole test set are needed.
        """
        self.results = outputs

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``hparams.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

        return optimizer

I define my own `log_metrics` function and call it in each stage (train/valid/test). When I run this, the metrics appear to be logged at every step, even though I pass `on_step=False` and `on_epoch=True`. It only behaves as expected when the whole dataset fits in a single batch (one batch = one epoch).

Am I missing something here? I’d much appreciate any help!

Okay, it turns out the logging does happen once per epoch, but the x-axis of the plots is in steps rather than epochs.

Maybe this has to do with `TensorBoardLogger`…
Edit: found a PR that discusses this issue: #3228

1 Like