EarlyStopping not working as expected

I am trying to get early stopping to work in my code.

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import pytorch_lightning.callbacks as plc
import pytorch_lightning.metrics.functional as plf

class FFNPL(pl.LightningModule):
    def __init__(self, prm):
        super().__init__()
        self.model = FFN(prm["model"])   # FFN is defined elsewhere
        self.lr = prm["lr"]

    def forward(self, x):
        return self.model(x)
    
    def step(self, batch):
        X, y = batch
        yhat = self(X)
        loss = F.binary_cross_entropy_with_logits(yhat, y)
        return y, yhat, loss
    
    def training_step(self, batch, batch_idx):
        y, yhat, loss = self.step(batch)
        return dict(loss=loss)

    def validation_step(self, batch, batch_idx):
        y, yhat, loss = self.step(batch)
        return dict(val_loss=loss, y=y, yhat=yhat)
    
    def validation_epoch_end(self, outputs):
        y = torch.cat([x["y"] for x in outputs])
        yhat = torch.sigmoid(torch.cat([x["yhat"] for x in outputs]))
        auc = plf.classification.auroc(yhat, y, pos_label=1)
        ap = plf.average_precision(yhat, y, pos_label=1)
        self.log("val_ap", ap)
        print(f"  Epoch {self.current_epoch}  val_auc: {auc:.2%},  val_ap: {ap:.2%}")
    
    def test_step(self, batch, batch_idx):
        y, yhat, loss = self.step(batch)
        return dict(y=y, yhat=yhat)

    def test_epoch_end(self, outputs):
        y = torch.cat([x["y"] for x in outputs])
        yhat = torch.sigmoid(torch.cat([x["yhat"] for x in outputs]))
        auc = plf.classification.auroc(yhat, y, pos_label=1)
        ap = plf.average_precision(yhat, y, pos_label=1)
        print(f"Test: auc: {auc:.2%},  ap: {ap:.2%}")
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

pl.seed_everything(42)
prm = {"model": {"input_dim": smat_train.shape[1],
                 "hidden_dim": (256, 256, 256, 64),
                 "dropout_prob": (.5, .5, .5, .5)},
       "lr": 1e-4}
model = FFNPL(prm)

trainer = pl.Trainer(auto_scale_batch_size="power", gpus=1, deterministic=True, max_epochs=20, progress_bar_refresh_rate=0,
                     callbacks=[plc.early_stopping.EarlyStopping(monitor="val_ap", patience=3)])
trainer.fit(model=model, train_dataloader=trainDL, val_dataloaders=validDL)

When I comment out EarlyStopping, I get the following output:

  | Name  | Type | Params
-------------------------------
0 | model | FFN  | 5.4 M 
-------------------------------
5.4 M     Trainable params
0         Non-trainable params
5.4 M     Total params

  Epoch 0  val_auc: 47.87%,  val_ap: 48.39%
  Epoch 0  val_auc: 74.66%,  val_ap: 74.06%
  Epoch 1  val_auc: 94.57%,  val_ap: 94.72%
  Epoch 2  val_auc: 96.22%,  val_ap: 96.32%
  Epoch 3  val_auc: 96.77%,  val_ap: 96.89% <==
  Epoch 4  val_auc: 96.97%,  val_ap: 97.12%
  Epoch 5  val_auc: 97.19%,  val_ap: 97.35%
  Epoch 6  val_auc: 97.26%,  val_ap: 97.42%
  Epoch 7  val_auc: 97.33%,  val_ap: 97.49%
  Epoch 8  val_auc: 97.35%,  val_ap: 97.51%
  Epoch 9  val_auc: 97.36%,  val_ap: 97.52%
  Epoch 10  val_auc: 97.41%,  val_ap: 97.58%
  Epoch 11  val_auc: 97.41%,  val_ap: 97.57%
  Epoch 12  val_auc: 97.40%,  val_ap: 97.58%
  Epoch 13  val_auc: 97.40%,  val_ap: 97.59%
  Epoch 14  val_auc: 97.44%,  val_ap: 97.62%
  Epoch 15  val_auc: 97.42%,  val_ap: 97.61%
  Epoch 16  val_auc: 97.40%,  val_ap: 97.61%
  Epoch 17  val_auc: 97.43%,  val_ap: 97.63%
  Epoch 18  val_auc: 97.44%,  val_ap: 97.64%
  Epoch 19  val_auc: 97.39%,  val_ap: 97.61%

When EarlyStopping is turned on (as shown above), training stops after Epoch 3 (marked with <== above). What am I messing up?

You need to set `mode`. EarlyStopping defaults to `mode='min'`, so it treats your steadily increasing val_ap as a metric that has stopped improving and halts once `patience` epochs pass without a decrease, which is exactly the stop after Epoch 3 you are seeing:

EarlyStopping(monitor="val_ap", patience=3, mode='max')
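
As a side note, early stopping only halts training; it does not restore the best weights. If you also want to keep the checkpoint from the best epoch, the usual companion is a ModelCheckpoint with the same monitor and mode. A minimal sketch, reusing the plc alias from your script (the callback name and save_top_k value are just one way to wire it):

checkpoint = plc.ModelCheckpoint(monitor="val_ap", mode="max", save_top_k=1)  # keep only the best epoch
trainer = pl.Trainer(gpus=1, deterministic=True, max_epochs=20, progress_bar_refresh_rate=0,
                     callbacks=[plc.early_stopping.EarlyStopping(monitor="val_ap", patience=3, mode="max"),
                                checkpoint])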

Thanks, that worked. This is my first PL script. Would you write any part of it differently to make it more efficient? I need to change it now to use Optuna and would like it to be as efficient as possible.

It looks good enough. One thing to note: auto_scale_batch_size="power" only takes effect if you call trainer.tune(...) before trainer.fit(...); as written it does nothing, so either drop it or add the tune call.
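
For the Optuna part, the usual pattern is an objective function that builds a model from trial-suggested hyperparameters, fits it, and returns the monitored metric. Here is a rough sketch under the assumption that you reuse your trainDL/validDL and tune only the learning rate; the search range and n_trials are placeholders:

import optuna

def objective(trial):
    prm = {"model": {"input_dim": smat_train.shape[1],
                     "hidden_dim": (256, 256, 256, 64),
                     "dropout_prob": (.5, .5, .5, .5)},
           "lr": trial.suggest_float("lr", 1e-5, 1e-3, log=True)}  # placeholder range
    model = FFNPL(prm)
    trainer = pl.Trainer(gpus=1, deterministic=True, max_epochs=20, progress_bar_refresh_rate=0,
                         callbacks=[plc.early_stopping.EarlyStopping(monitor="val_ap", patience=3, mode="max")])
    trainer.fit(model=model, train_dataloader=trainDL, val_dataloaders=validDL)
    # last value logged for the monitored metric
    return trainer.callback_metrics["val_ap"].item()

study = optuna.create_study(direction="maximize")  # val_ap should be maximized
study.optimize(objective, n_trials=20)  # placeholder trial count

Optuna also ships optuna.integration.PyTorchLightningPruningCallback(trial, monitor="val_ap"), which you can add to the callbacks list to stop unpromising trials early instead of running every trial to the early-stopping point.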