Checkpoint model predictions are not the same as the original model's

My model class code is shown below

class FFNPL(pl.LightningModule):
    def __init__(self, prm):
        super(FFNPL, self).__init__()
        self.model = FFN(prm["model"])
        self.lr = prm["lr"]
        self.save_hyperparameters()
    
    def forward(self, x):
        return self.model.forward(x)
    
    def step(self, batch):
        X,y = batch
        yhat = self.forward(X)
        loss = F.binary_cross_entropy_with_logits(yhat,y)
        return loss,yhat,y
    
    def training_step(self, batch, batch_idx):
        loss,yhat,y = self.step(batch)
        return dict(loss=loss)
    
    def validation_step(self, batch, batch_idx):
        loss,yhat,y = self.step(batch)
        return dict(val_loss=loss, y=y, yhat=yhat)
    
    def validation_epoch_end(self, outputs):
        y = torch.cat([x["y"] for x in outputs])
        yhat = torch.sigmoid(torch.cat([x["yhat"] for x in outputs]))
        auc,ap = plf.classification.auroc(yhat,y,pos_label=1),plf.average_precision(yhat,y,pos_label=1)
        self.log("val_ap", ap)
        print(f"  Epoch {self.current_epoch:>2}  val_auc: {auc:.3%},  val_ap: {ap:.3%}")
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

Calls to trainer are shown below.

pl.seed_everything(42)
prm = {"model":{"input_dim":smat_train.shape[1], "hidden_dim":(256,128,32), "dropout_prob":(.5,.5,.5)}, "lr":1E-4}
modelpl = FFNPL(prm)

trainDL = utils.data.DataLoader(SparseData(smat_train, y_train), batch_size=128, num_workers=6, shuffle=True , pin_memory=True)
validDL = utils.data.DataLoader(SparseData(smat_valid, y_valid), batch_size=256, num_workers=6, shuffle=False, pin_memory=True)
testDL  = utils.data.DataLoader(SparseData(smat_test , y_test) , batch_size=256, num_workers=6, shuffle=False, pin_memory=True)

checkpoint_CB = plc.ModelCheckpoint(monitor="val_ap", save_top_k=15, mode="max", dirpath="checkpoints", filename="ffn-{epoch:02d}")
earlystopping_CB = plc.early_stopping.EarlyStopping(monitor="val_ap", patience=5, mode="max")
trainer = pl.Trainer(weights_summary=None, gpus=1, auto_scale_batch_size=None, deterministic=True, max_epochs=20, logger=False,
            progress_bar_refresh_rate=0, callbacks=[checkpoint_CB,earlystopping_CB])
trainer.fit(model=modelpl, train_dataloader=trainDL, val_dataloaders=validDL)

I then try to reproduce the AUC and AP on the validation data from each saved checkpoint:

for n in range(13):
    #trainer.checkpoint_callback.best_model_path
    s = f"checkpoints/ffn-epoch={n:02}.ckpt"
    model = FFNPL(prm).load_from_checkpoint(s)
    Xlst = []; ylst = []
    for X,y in validDL:
        Xlst.append(model.forward(X))
        ylst.append(y)
    yhat = torch.sigmoid(torch.cat(Xlst)) #.detach().numpy()
    y = torch.cat(ylst) #.detach().numpy()
    print(f"Epoch {n:2} val_auc: {plf.classification.auroc(yhat,y,pos_label=1):.3%}, val_ap: {plf.average_precision(yhat,y,pos_label=1):.3%}")

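Here's the output for the in-memory model, printed during training (the first "Epoch 0" line presumably comes from Lightning's validation sanity check, which runs before any training):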

Epoch  0  val_auc: 53.812%,  val_ap: 54.072%
Epoch  0  val_auc: 94.913%,  val_ap: 95.159%
Epoch  1  val_auc: 96.630%,  val_ap: 96.798%
Epoch  2  val_auc: 97.065%,  val_ap: 97.246%
Epoch  3  val_auc: 97.275%,  val_ap: 97.458%
Epoch  4  val_auc: 97.377%,  val_ap: 97.581%
Epoch  5  val_auc: 97.453%,  val_ap: 97.655%
Epoch  6  val_auc: 97.462%,  val_ap: 97.677%
Epoch  7  val_auc: 97.511%,  val_ap: 97.724%
Epoch  8  val_auc: 97.466%,  val_ap: 97.717%
Epoch  9  val_auc: 97.471%,  val_ap: 97.717%
Epoch 10  val_auc: 97.429%,  val_ap: 97.669%
Epoch 11  val_auc: 97.433%,  val_ap: 97.685%
Epoch 12  val_auc: 97.375%,  val_ap: 97.643%

Here's the output for the models loaded from the checkpoint files:

Epoch  0  val_auc: 93.226%, val_ap: 93.606%
Epoch  1  val_auc: 95.662%, val_ap: 95.941%
Epoch  2  val_auc: 96.314%, val_ap: 96.564%
Epoch  3  val_auc: 96.603%, val_ap: 96.857%
Epoch  4  val_auc: 96.791%, val_ap: 97.044%
Epoch  5  val_auc: 96.907%, val_ap: 97.171%
Epoch  6  val_auc: 96.945%, val_ap: 97.204%
Epoch  7  val_auc: 96.906%, val_ap: 97.217%
Epoch  8  val_auc: 96.892%, val_ap: 97.191%
Epoch  9  val_auc: 96.911%, val_ap: 97.236%
Epoch 10  val_auc: 96.912%, val_ap: 97.173%
Epoch 11  val_auc: 96.865%, val_ap: 97.171%
Epoch 12  val_auc: 96.897%, val_ap: 97.242%

I am looking for guidance on what I may be messing up. I would expect these to be much closer for corresponding epochs.

Found the issue. The gradients have to be explicitly turned off after the checkpointed model is loaded! Lightning puts the model in eval mode in its validation and test loops, as noted here https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html, but one has to do it explicitly in the situation above. For the sake of completeness:

model = FFNPL.load_from_checkpoint(s) 
model.eval()
...

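Actually, model.eval() doesn't turn off gradients; it toggles the behavior of BatchNorm and Dropout layers, which behave differently during training and evaluation. That is what matters here: the model in the question uses three dropout layers with probability 0.5, so a loaded model left in training mode keeps randomly zeroing activations, which would explain the lower scores. To turn off gradients you need torch.no_grad(). In short, use both model.eval() and torch.no_grad().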
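
To make the difference concrete, here is a minimal sketch using plain PyTorch. The small network is a hypothetical stand-in for the FFN in the question, not the original model; it only shows that dropout stays active until eval() is called and that torch.no_grad() is what stops gradient tracking.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the FFN in the question: one hidden layer with dropout.
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(16, 1))
X = torch.randn(32, 8)

# A freshly constructed nn.Module is in training mode, so dropout is active
# and two passes over the same input give different outputs.
train_a, train_b = net(X), net(X)
same_in_train_mode = torch.allclose(train_a, train_b)

# eval() disables dropout; no_grad() skips building the autograd graph.
net.eval()
with torch.no_grad():
    eval_a, eval_b = net(X), net(X)
same_in_eval_mode = torch.allclose(eval_a, eval_b)
tracks_grad = eval_a.requires_grad

print(same_in_train_mode, same_in_eval_mode, tracks_grad)
```

In the loop from the question, the same pattern would mean calling model.eval() once after loading and wrapping the pass over validDL in a with torch.no_grad(): block.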

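Also, call it on the class: model = FFNPL.load_from_checkpoint(s, prm=prm), since load_from_checkpoint is a class method. The keyword has to match the __init__ argument name, which is prm here. Because the module calls save_hyperparameters(), prm should already be stored in the checkpoint, so passing it is only needed if you want to override the saved value.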
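
The class-method point can be illustrated with a small pure-Python sketch. TinyModel and load_from_state are hypothetical names made up for this example and are not Lightning's API; the sketch only shows that calling a class method through an instance builds an extra object that is immediately discarded.

```python
class TinyModel:
    """Hypothetical stand-in for a LightningModule; not Lightning's real API."""

    constructed = 0  # counts how many instances have been built

    def __init__(self, prm):
        TinyModel.constructed += 1
        self.prm = prm

    @classmethod
    def load_from_state(cls, state, **overrides):
        # Rebuild from saved hyperparameters; keyword arguments override them
        # and must match the __init__ argument names.
        hparams = {**state["hparams"], **overrides}
        return cls(**hparams)


state = {"hparams": {"prm": {"lr": 1e-4}}}

# Calling through an instance works, but that instance is built and thrown away.
via_instance = TinyModel({"lr": 1.0}).load_from_state(state)
count_via_instance = TinyModel.constructed

# Calling on the class builds only the object that is returned.
TinyModel.constructed = 0
via_class = TinyModel.load_from_state(state)
count_via_class = TinyModel.constructed

print(count_via_instance, count_via_class)
```

Both calls return a model built from the saved hyperparameters, so the instance form in the question is wasteful rather than wrong.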
