Trouble loading checkpoints (?)

Hi,

I’m having trouble understanding what is wrong with a multi-label image classifier I’ve built. The classifier achieves pretty decent performance under 5-fold cross-validation (0.86 AUC). However, when I load the best checkpoint from one of the folds and make predictions on the test set, performance is poor (0.50 AUC, basically random). Even more shocking, after loading the best checkpoint for one of the folds, predictions on the full training set also appear to be random (0.50 AUC).

I’m starting to wonder whether I’m doing something wrong when loading the checkpoint, as everything else looks alright.

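from typing import List

import pytorch_lightning as pl
import torch
# Assumption: create_model is timm's factory (it matches the in_chans argument
# and the resnest14d architecture used below).
from timm import create_model

# optimizer_factory, lr_scheduler_factory, loss_factory and metric_factory are
# my own project helpers (not shown here).


class ImageClassifier(pl.LightningModule):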
    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        pretrained=False,
        **kwargs,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.best_train_metric = None
        self.best_valid_metric = None

        self.model = create_model(
            model_name=self.hparams.arch,
            pretrained=pretrained,
            num_classes=num_classes,
            in_chans=in_channels,
        )

    def forward(self, x):
        x = self.model(torch.as_tensor(data=x))
        return x

    def configure_optimizers(self):
        optimizer = optimizer_factory(
            params=self.parameters(), hparams=self.hparams
        )

        scheduler = lr_scheduler_factory(
            optimizer=optimizer,
            hparams=self.hparams,
            data_loader=self.train_dataloader(),
        )
        return [optimizer], [scheduler]

    def compute_loss(self, y_hat, y):
        loss_fn = loss_factory(name=self.hparams.loss)
        loss = loss_fn(y_hat, y)
        return loss

    def compute_metric(self, y_hat, y):
        metric_fn = metric_factory(name=self.hparams.metric)
        try:  # if GPU metric
            metric = metric_fn(y_true=y, y_score=y_hat)
        except TypeError:  # if sklearn metric
            try:
                metric = metric_fn(
                    y_true=y.detach().cpu().numpy(),
                    y_score=y_hat.detach().cpu().numpy(),
                )
            except ValueError:
                metric = 0.50
        return metric

    def step(self, batch):
        x, y = batch
        y_hat = self(x)
        loss = self.compute_loss(y_hat=y_hat, y=y)

        return loss, y, y_hat.sigmoid()

    def training_step(self, batch, batch_idx):
        loss, y, y_hat = self.step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return {
            "loss": loss,
            "y_hat": y_hat,
            "y": y,
        }

    def training_epoch_end(self, outputs: List):
        y_hat = torch.cat([out["y_hat"] for out in outputs], dim=0)
        y = torch.cat([out["y"] for out in outputs], dim=0)

        train_metric = self.compute_metric(y_hat=y_hat, y=y)
        self.log("train_metric", train_metric)

    def validation_step(self, batch, batch_idx):
        loss, y, y_hat = self.step(batch)
        self.log("valid_loss", loss, on_step=True, on_epoch=True)
        return {"valid_loss": loss, "y_hat": y_hat, "y": y}

    def validation_epoch_end(self, outputs: List):
        y_hat = torch.cat([out["y_hat"] for out in outputs], dim=0)
        y = torch.cat([out["y"] for out in outputs], dim=0)

        valid_metric = self.compute_metric(y_hat=y_hat, y=y)
        self.log("valid_metric", valid_metric)

    def predict(self, dl):
        self.eval()
        self.to("cuda")

        for batch in dl():
            x = batch.float()
            x = x.to("cuda")
            with torch.no_grad():
                y_hat = self(x)
                yield y_hat.detach().cpu().numpy()

    def predict_proba(self, dl):
        self.eval()
        self.to("cuda")

        for batch in dl():
            x = batch.float()
            x = x.to("cuda")
            with torch.no_grad():
                y_hat = self(x)
                outs = y_hat.sigmoid()
                yield outs.detach().cpu().numpy()

Here is how I set up the callbacks and Trainer:

    checkpoint_callback = callbacks.ModelCheckpoint(
        monitor="valid_metric",
        mode="max",
        dirpath=constants.models_path,
        filename=f"arch={hparams.arch}_sz={hparams.sz}_fold={hparams.fold}",
        save_weights_only=True,
    )

    trainer = pl.Trainer(
        gpus=1,
        precision=hparams.precision,
        auto_lr_find=hparams.auto_lr,
        auto_scale_batch_size=hparams.auto_batch_size,
        max_epochs=hparams.epochs,
        callbacks=[checkpoint_callback],
    )

… and finally, here is how I’m loading the checkpoint before running inference:

checkpoint_path = Path('../models/arch=resnest14d_sz=128_fold=1.ckpt')
model = learner.ImageClassifier.load_from_checkpoint(
    checkpoint_path=checkpoint_path,
    pretrained=False
)
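
One way to rule the loading step in or out is to compare the tensors stored in the checkpoint file with the ones the loaded model actually holds. Below is a minimal sketch of that check. The helper name `weights_match` and the toy `nn.Linear` model are made up for illustration; the only thing taken from Lightning is that its checkpoints keep the weights under the "state_dict" key.

```python
import os
import tempfile

import torch
from torch import nn


def weights_match(model: nn.Module, checkpoint_path: str) -> bool:
    """Return True if every tensor in the checkpoint equals the model's tensor."""
    # Lightning checkpoints keep the weights under the "state_dict" key.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    saved = checkpoint["state_dict"]
    current = model.state_dict()
    if saved.keys() != current.keys():
        return False
    return all(torch.equal(saved[k], current[k].cpu()) for k in saved)


# Toy stand-in for the real classifier (hypothetical, illustration only).
torch.manual_seed(0)
toy_model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy.ckpt")
    torch.save({"state_dict": toy_model.state_dict()}, path)

    restored = nn.Linear(4, 2)  # fresh random initialisation
    match_before = weights_match(restored, path)
    restored.load_state_dict(torch.load(path, map_location="cpu")["state_dict"])
    match_after = weights_match(restored, path)

print(match_before, match_after)
```

With the real model, the same helper could be called on the object returned by `load_from_checkpoint` and the `.ckpt` path. Depending on the installed torch version, `torch.load` may need `weights_only=False` to read a full Lightning checkpoint.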

Just for reference, this is the output of one fold…

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 8.6 M
---------------------------------
8.6 M     Trainable params
0         Non-trainable params
8.6 M     Total params
Epoch 1 // train loss: 0.2380, train metric: 0.8133, valid loss: 0.2655, valid metric: 0.8048
Epoch 2 // train loss: 0.2179, train metric: 0.8333, valid loss: 0.2524, valid metric: 0.8311
Epoch 3 // train loss: 0.2035, train metric: 0.8453, valid loss: 0.2519, valid metric: 0.8435
Epoch 4 // train loss: 0.2164, train metric: 0.8540, valid loss: 0.2533, valid metric: 0.8497
Epoch 5 // train loss: 0.1730, train metric: 0.8627, valid loss: 0.2500, valid metric: 0.8532
Epoch 6 // train loss: 0.2098, train metric: 0.8719, valid loss: 0.2454, valid metric: 0.8575
Epoch 7 // train loss: 0.2150, train metric: 0.8776, valid loss: 0.2614, valid metric: 0.8620
Epoch 8 // train loss: 0.1875, train metric: 0.8839, valid loss: 0.2452, valid metric: 0.8638
Epoch 9 // train loss: 0.1770, train metric: 0.8903, valid loss: 0.2412, valid metric: 0.8637
Epoch 10 // train loss: 0.1656, train metric: 0.8956, valid loss: 0.2425, valid metric: 0.8719
Epoch 11 // train loss: 0.1882, train metric: 0.9035, valid loss: 0.2490, valid metric: 0.8707
Epoch 12 // train loss: 0.1955, train metric: 0.9055, valid loss: 0.2528, valid metric: 0.8666
Epoch 13 // train loss: 0.1510, train metric: 0.9115, valid loss: 0.2439, valid metric: 0.8702
Epoch 14 // train loss: 0.1599, train metric: 0.9159, valid loss: 0.2463, valid metric: 0.8709

In case it helps, I’m using BCEWithLogitsLoss as the loss function and roc_auc_score (macro) as the metric.
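
For reference, BCEWithLogitsLoss expects raw logits and applies the sigmoid internally, which is why step() computes the loss on the model output and only then calls sigmoid() for the metric. A small sketch on random tensors (the shapes are arbitrary, not taken from my data) showing that it matches BCELoss applied to the sigmoid outputs:

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(8, 3)                     # raw model outputs for 3 labels
targets = torch.randint(0, 2, (8, 3)).float()  # multi-label 0/1 targets

# The loss computed on logits...
loss_on_logits = nn.BCEWithLogitsLoss()(logits, targets)
# ...agrees with plain BCE computed on the probabilities.
loss_on_probs = nn.BCELoss()(logits.sigmoid(), targets)

same = torch.allclose(loss_on_logits, loss_on_probs, atol=1e-6)
print(same)
```

So passing logits to the loss and probabilities to the metric, as the class above does, is consistent.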

Are you able to spot anything wrong with my code?

After further investigation, I don’t think the issue is with loading the weights.

This is even more confusing, as I’m using the exact same evaluation metric during training and when evaluating the final predictions. Maybe the error is in how I compute the predictions (?).
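
One scenario that would produce exactly this pattern is predictions and labels ending up in a different row order, for example if the dataloader used for inference shuffles. I have not confirmed that this is what happens in my case; the sketch below uses synthetic data only, to show how a good macro AUC collapses to about 0.50 once the rows are misaligned.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
n_samples, n_labels = 2000, 3

# Synthetic multi-label targets and informative (but noisy) scores.
y_true = rng.integers(0, 2, size=(n_samples, n_labels))
y_score = y_true + rng.normal(scale=0.5, size=y_true.shape)

# Rows aligned: the scores line up with their labels.
auc_aligned = roc_auc_score(y_true, y_score, average="macro")

# Rows misaligned: same scores, but in a different order than the labels.
permutation = rng.permutation(n_samples)
auc_shuffled = roc_auc_score(y_true, y_score[permutation], average="macro")

print(f"aligned: {auc_aligned:.3f}, shuffled: {auc_shuffled:.3f}")
```

If that were the cause, the thing to check would be whether the labels used for scoring are collected in the same order as the batches yielded by predict_proba().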

I’ve also tried to validate the logic in the predict() method…

Finally, I’ve also tried to remove the custom filename pattern I use for the checkpoint, just to ensure the ModelCheckpoint is working fine and saving the best checkpoint given my metric. That also doesn’t seem to be the issue.

@iamgianluca Is this still unresolved? I’m facing a very similar issue when trying to load my trained PyG graph model and predict with it. I suspect something is wrong with the PyTorch Lightning module. Which version were you using?