Hparams not restored when using load_from_checkpoint (default argument values are the problem?)

Problem

My model trains fine, and the saved checkpoint does contain the hparams used in training. However, when I load the model with MyModel.load_from_checkpoint(), these hparams are not restored.

Code breakdown

Sorry, the following code is the smallest runnable version I could make that replicates my issue. I’m trying to support various non-Lightning pre-trained PyTorch weights and models (which I hope to make open source).
In short, this is what each class does:

  • ModelBase purpose: Load the model’s pre-trained weights based on the given sample_rate value, and set the other hparams to the values the model was trained with.
  • Linear3 purpose: One of the many models that can be selected. Inherits from ModelBase.
  • ModelTrainer purpose: Do transfer learning. Can be swapped out so the user can choose whether to train for a multi-class or multi-label task.
  • MyModel purpose: Inherits from both ModelTrainer and Linear3 to construct a “normal” PyTorch Lightning LightningModule.
  • SimpleDataset & SimpleDatamodule purpose: Just to be able to use Trainer and its save_checkpoint function.
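Because MyModel inherits from both ModelTrainer and Linear3, Python's method resolution order (MRO) decides which __init__ runs next at each super().__init__() call. The plain-Python sketch below reproduces only that class layout; Base is a hypothetical stand-in for pl.LightningModule and the bodies are stripped down, so it illustrates the call chain rather than the real model:

```python
# Hypothetical stand-ins mirroring the class layout, without Lightning.
call_order = []        # order in which the __init__ methods start
received_by_base = {}  # arguments that reach ModelBase.__init__


class Base:  # plays the role of pl.LightningModule
    def __init__(self):
        call_order.append("Base")


class ModelBase(Base):
    def __init__(self, pretrained_hparams, **kwargs):
        call_order.append("ModelBase")
        received_by_base.update(pretrained_hparams=pretrained_hparams, **kwargs)
        super().__init__()


class Linear3(ModelBase):
    def __init__(self, sample_rate, **kwargs):
        call_order.append("Linear3")
        super().__init__(sample_rate=sample_rate, **kwargs)


class ModelTrainer(Base):
    def __init__(self, learning_rate=1e-3, **kwargs):
        call_order.append("ModelTrainer")
        # In MyModel's MRO, the next class after ModelTrainer is Linear3, not Base
        super().__init__(learning_rate=learning_rate, **kwargs)


class MyModel(ModelTrainer, Linear3):
    def __init__(self, unfreeze_epoch=1, **kwargs):
        call_order.append("MyModel")
        super().__init__(unfreeze_epoch=unfreeze_epoch, **kwargs)


MyModel(sample_rate=8000, pretrained_hparams=True)
print([c.__name__ for c in MyModel.__mro__])
# ['MyModel', 'ModelTrainer', 'Linear3', 'ModelBase', 'Base', 'object']
print(call_order)
# ['MyModel', 'ModelTrainer', 'Linear3', 'ModelBase', 'Base']
```

The recorded order matches the "Init MyModel / ModelTrainer / Linear3 / ModelBase" lines in the script output, and every constructor argument reaches ModelBase.__init__, which is where self.save_hyperparameters() is called.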

The following code can also be downloaded from: https://gist.github.com/NumesSanguis/558add315378cda55e28b6d5f63f56b2


Script


from abc import abstractmethod
import torch
from torch import nn as nn
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy, to_categorical
from pytorch_lightning import Trainer
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F


class ModelBase(pl.LightningModule):
    def __init__(self, pretrained_hparams: bool, **kwargs):  # **kwargs  # sample_rate
        print(f"Init ModelBase, hparams:\n{self.hparams}\n")
        super().__init__()
        print(f"Init ModelBase after, hparams:\n{self.hparams}\n")

        # use PANNs.load_from_checkpoint when loading weights after transfer learning
        if pretrained_hparams:
            # save all arguments in self.hparams
            self.save_hyperparameters()
            print("Argument hparams: ", self.hparams)
            # needed hparams for non-lightning pre-trained weights
            self.set_pretrained_hparams()
            # print("All hparams: ", self.hparams)

    @abstractmethod
    def forward(self, x):
        pass

    def set_pretrained_hparams(self):
        if self.hparams["sample_rate"] == 8000:
            self.hparams["hlayer1"] = 400
        elif self.hparams["sample_rate"] == 16000:
            self.hparams["hlayer1"] = 800
        self.hparams["classes_num"] = 3

    def load_non_lightning_weights(self, weights_path):
        # checkpoint = torch.load(weights_path)
        # self.load_state_dict(checkpoint['model'])
        pass


# 1 variant
class Linear3(ModelBase):
    def __init__(self, sample_rate, **kwargs):
        print(f"Init Linear3, hparams:\n{self.hparams}\n")
        super().__init__(sample_rate=sample_rate, **kwargs)
        print(f"Init Linear3 after, hparams:\n{self.hparams}\n")
        # 1 sec of audio
        self.input_layer = nn.Linear(self.hparams["sample_rate"], self.hparams["hlayer1"], bias=True)
        self.hidden_layer = nn.Linear(self.hparams["hlayer1"], 128, bias=True)
        self.output_layer = nn.Linear(128, self.hparams["classes_num"], bias=True)

    def forward(self, input):
        x = F.relu_(self.input_layer(input))
        x = F.relu_(self.hidden_layer(x))
        output = self.output_layer(x)  # torch.sigmoid()
        return output


class ModelTrainer(pl.LightningModule):
    # arguments should NOT be positional due to inheritance; always give them a default value
    def __init__(self, learning_rate=1e-3, **kwargs):  # **kwargs
        print(f"Init ModelTrainer, hparams:\n{self.hparams}\n")
        # everything included in init call will be included in self.hparams (here only kwargs is included);
        # meaning only those will be saved in a .ckpt file
        super().__init__(learning_rate=learning_rate, **kwargs)  # **kwargs
        print(f"Init ModelTrainer after, hparams:\n{self.hparams}\n")
        self.criterion = nn.CrossEntropyLoss()

    def calculate_loss(self, prediction, target):
        """Cross-entropy loss (multi-class)"""
        # loss = F.binary_cross_entropy_with_logits(prediction, target)
        loss = self.criterion(prediction, target)
        return loss

    def training_step(self, batch, batch_idx):
        input, target = batch
        prediction = self(input)
        loss = self.calculate_loss(prediction, target)
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss)
        return result

    def validation_step(self, batch, batch_idx):
        input, target = batch
        prediction = self(input)
        loss = self.calculate_loss(prediction, target)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss)
        result.log('val_acc', accuracy(prediction, target))
        return result

    def test_step(self, batch, batch_idx):
        input, target = batch
        prediction = self(input)
        loss = self.calculate_loss(prediction, target)
        result = pl.EvalResult()  # checkpoint_on=loss
        result.log('test_loss', loss)
        result.log('test_acc', accuracy(prediction, target))  # to_categorical()
        return result

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams["learning_rate"])


class MyModel(ModelTrainer, Linear3):
    def __init__(self, unfreeze_epoch=1, **kwargs):
        # arguments passed here are stored in self.hparams
        print(f"Init MyModel, hparams:\n{self.hparams}\n")
        super().__init__(unfreeze_epoch=unfreeze_epoch, **kwargs)  # unfreeze_epoch=unfreeze_epoch, **kwargs
        print(f"Init MyModel after, hparams:\n{self.hparams}\n")
        # print("hparams after init: ", self.hparams)

        # self.unfreeze_epoch = unfreeze_epoch
        # self.freeze()

    def forward(self, input, mixup_lambda=None):
        # unfreeze deep layers after unfreeze_epoch epochs
        # if self.current_epoch == self.unfreeze_epoch:
        #     self.unfreeze()

        x = F.relu_(self.input_layer(input))
        x = F.relu_(self.hidden_layer(x))
        output = self.output_layer(x)  # torch.sigmoid()
        return output


# DATA
class SimpleDataset(Dataset):
    def __init__(self, sample_rate=8000):
        self.sample_rate = sample_rate

    def __len__(self):
        return 16

    def __getitem__(self, idx):
        # 0, 1 or 2
        target = torch.randint(0, 3, size=(1, )).squeeze()
        # size 8000/16000 of 0.0, 0.5, or 1.0
        input = torch.full((self.sample_rate,), (target.float()/2).item())
        # torch.empty(self.sample_rate,).fill_(target.float()/2)
        return input, target


class SimpleDatamodule(pl.LightningDataModule):
    def setup(self, stage: str = None):
        pass

    def train_dataloader(self):
        return DataLoader(SimpleDataset(), batch_size=4)

    def val_dataloader(self):
        return DataLoader(SimpleDataset(), batch_size=4)
        # dataset = self._set_dataset_split("val")
        # return DataLoader(dataset, batch_size=self.hparams["batch_size"],
        #                   sampler=SubsetRandomSampler(dataset.indices), num_workers=4)

    def test_dataloader(self):
        return DataLoader(SimpleDataset(), batch_size=4)

if __name__ == '__main__':
    sr = 8000
    checkpoint_location = "example.ckpt"
    # network
    model = MyModel(sample_rate=8000, pretrained_hparams=True)
    print(f"After all init, hparams:\n{model.hparams}\n")
    # data
    dm = SimpleDatamodule()
    # train
    trainer = Trainer(max_epochs=4, deterministic=True)  # gpus=1,
    trainer.fit(model, dm)
    # save
    trainer.save_checkpoint(checkpoint_location)

    # check model contents
    print(f"\n\nModel save completed. Checking contents saved model...")
    checkpoint = torch.load(checkpoint_location)
    print(f"Checkpoint hyper parameters: {checkpoint['hyper_parameters']}")  # .keys()  # ['state_dict']

    # ERROR: load weights into new model
    print("\nContents check completed. Trying to restore model with checkpoint...")
    model2 = MyModel.load_from_checkpoint(checkpoint_location, pretrained_hparams=False)
    # KeyError: 'sample_rate'

Script output


$ python test_save_load.py 
Init MyModel, hparams:


Init ModelTrainer, hparams:


Init Linear3, hparams:


Init ModelBase, hparams:


Init ModelBase after, hparams:


Argument hparams:  "learning_rate":      0.001
"pretrained_hparams": True
"sample_rate":        8000
"unfreeze_epoch":     1
Init Linear3 after, hparams:
"classes_num":        3
"hlayer1":            400
"learning_rate":      0.001
"pretrained_hparams": True
"sample_rate":        8000
"unfreeze_epoch":     1

Init ModelTrainer after, hparams:
"classes_num":        3
"hlayer1":            400
"learning_rate":      0.001
"pretrained_hparams": True
"sample_rate":        8000
"unfreeze_epoch":     1

Init MyModel after, hparams:
"classes_num":        3
"hlayer1":            400
"learning_rate":      0.001
"pretrained_hparams": True
"sample_rate":        8000
"unfreeze_epoch":     1

After all init, hparams:
{self.hparams}

GPU available: True, used: False
TPU available: False, using: 0 TPU cores
/home/stefempath/anaconda3/envs/pytorchlit/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: GPU available but not used. Set the --gpus flag when calling the script.
  warnings.warn(*args, **kwargs)
/home/stefempath/anaconda3/envs/pytorchlit/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: Could not log computational graph since the `model.example_input_array` attribute is not set or `input_array` was not given
  warnings.warn(*args, **kwargs)

  | Name         | Type             | Params
--------------------------------------------------
0 | input_layer  | Linear           | 3 M   
1 | hidden_layer | Linear           | 51 K  
2 | output_layer | Linear           | 387   
3 | criterion    | CrossEntropyLoss | 0     
/home/stefempath/anaconda3/envs/pytorchlit/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
/home/stefempath/anaconda3/envs/pytorchlit/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
Epoch 3: 100%|████████████████Saving latest checkpoint..███████████████████████████| 8/8 [00:00<00:00, 103.70it/s, loss=1.651, v_num=28]
Epoch 3: 100%|█████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 102.18it/s, loss=1.651, v_num=28]


Model save completed. Checking contents saved model...
Checkpoint hyper parameters: "classes_num":        3
"hlayer1":            400
"learning_rate":      0.001
"pretrained_hparams": True
"sample_rate":        8000
"unfreeze_epoch":     1

Contents check completed. Trying to restore model with checkpoint...
Init MyModel, hparams:


Init ModelTrainer, hparams:


Init Linear3, hparams:


Init ModelBase, hparams:


Init ModelBase after, hparams:


Init Linear3 after, hparams:


Traceback (most recent call last):
  File "test_save_load.py", line 183, in <module>
    model2 = MyModel.load_from_checkpoint(checkpoint_location, pretrained_hparams=False)
  File "/home/stefempath/anaconda3/envs/pytorchlit/lib/python3.7/site-packages/pytorch_lightning/core/saving.py", line 153, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, strict=strict, **kwargs)
  File "/home/stefempath/anaconda3/envs/pytorchlit/lib/python3.7/site-packages/pytorch_lightning/core/saving.py", line 190, in _load_model_state
    model = cls(*cls_args, **cls_kwargs)
  File "test_save_load.py", line 111, in __init__
    super().__init__(unfreeze_epoch=unfreeze_epoch, **kwargs)  # unfreeze_epoch=unfreeze_epoch, **kwargs
  File "test_save_load.py", line 67, in __init__
    super().__init__(learning_rate=learning_rate, **kwargs)  # **kwargs
  File "test_save_load.py", line 50, in __init__
    self.input_layer = nn.Linear(self.hparams["sample_rate"], self.hparams["hlayer1"], bias=True)
KeyError: 'sample_rate'

Closer look at the issue

From the following code and its output:

checkpoint = torch.load(checkpoint_location)
print(f"Checkpoint hyper parameters: {checkpoint['hyper_parameters']}")

# Checkpoint hyper parameters:
# "classes_num":        3
# "hlayer1":            400
# "learning_rate":      0.001
# "pretrained_hparams": True
# "sample_rate":        8000
# "unfreeze_epoch":     1

we can see that training and saving worked as expected: sample_rate is present in the checkpoint.

However, the following code

model2 = MyModel.load_from_checkpoint(checkpoint_location, pretrained_hparams=False)

# self.input_layer = nn.Linear(self.hparams["sample_rate"], self.hparams["hlayer1"], bias=True)
# KeyError: 'sample_rate'

fails with a missing key. From the full script output we can also see that, before and after the call to super().__init__(), there are NO values stored in self.hparams.

My reading of the documentation was that you call self.save_hyperparameters() when you DON’T want to use the values stored in the checkpoint. The loading code passes pretrained_hparams=False, which means the following block is NOT executed:

# not called in loading, because pretrained_hparams is False
if pretrained_hparams:
    # save all arguments in self.hparams
    self.save_hyperparameters()
    print("Argument hparams: ", self.hparams)
    # needed hparams for non-lightning pre-trained weights
    self.set_pretrained_hparams()

Therefore the problem shouldn’t be that self.hparams is being overwritten.
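One way to read the traceback: its `model = cls(*cls_args, **cls_kwargs)` line indicates that load_from_checkpoint rebuilds the model by calling the class with keyword arguments. Assuming (not verified against the Lightning source) that those keyword arguments are the saved hyperparameters updated with whatever is passed to load_from_checkpoint, the merge looks roughly like this hypothetical helper, `build_constructor_kwargs`:

```python
def build_constructor_kwargs(saved_hparams, **overrides):
    """Hypothetical helper: saved hyperparameters updated with caller overrides."""
    merged = dict(saved_hparams)  # copy, so the saved dict is left untouched
    merged.update(overrides)      # values passed to load_from_checkpoint take precedence
    return merged


# Contents of checkpoint['hyper_parameters'] as printed above
saved = {"classes_num": 3, "hlayer1": 400, "learning_rate": 0.001,
         "pretrained_hparams": True, "sample_rate": 8000, "unfreeze_epoch": 1}

kwargs = build_constructor_kwargs(saved, pretrained_hparams=False)
print(kwargs["pretrained_hparams"], kwargs["sample_rate"])  # False 8000
# The model would then be constructed roughly as MyModel(**kwargs).
```

If that assumption holds, sample_rate does reach __init__ as an argument. The traceback supports this: Linear3.__init__ requires sample_rate, yet the error is raised inside its body rather than as a missing-argument TypeError. So the KeyError points at self.hparams never being filled, not at a missing value.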

Question

Why is self.hparams not properly restored when I use .load_from_checkpoint(), even though I deliberately skip self.save_hyperparameters() when loading?

System

  • Ubuntu 18.04
  • PyTorch: 1.6.0
  • PyTorch Lightning: 0.9.0 & 0.10.0 (tested both)

Reproduce conda environment

conda create --name pytorchlit10
conda activate pytorchlit10
conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
pip install pytorch-lightning  # 0.10.0

Let me see, but it looks like you are not calling self.save_hyperparameters() in your Module’s __init__ methods.
see: https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html?highlight=save_hyperparameters#save-hyperparameters
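To make this concrete: receiving a keyword argument in __init__ does not by itself put it into self.hparams; something has to record it. The sketch below is a hypothetical stand-in, not Lightning's implementation. `TinyModule`, `record_hparams` and `Model` are made-up names, and `record_hparams` takes its values explicitly, whereas the real save_hyperparameters() collects the __init__ arguments on its own:

```python
class TinyModule:
    """Hypothetical stand-in for a LightningModule's hparams handling."""

    def __init__(self):
        self.hparams = {}  # empty until something records the init arguments

    def record_hparams(self, **values):
        # Stand-in for save_hyperparameters(): copy init arguments into hparams
        self.hparams.update(values)


class Model(TinyModule):
    def __init__(self, sample_rate, pretrained_hparams, **kwargs):
        super().__init__()
        if pretrained_hparams:  # conditional recording, as in the original ModelBase
            self.record_hparams(sample_rate=sample_rate,
                                pretrained_hparams=pretrained_hparams, **kwargs)


# Reload scenario: the saved value arrives as an argument, but nothing records it.
m = Model(sample_rate=8000, pretrained_hparams=False)
print(m.hparams)  # {}
try:
    m.hparams["sample_rate"]
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'sample_rate'
```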

@jirka
That was indeed the issue! I deliberately avoided calling self.save_hyperparameters() when loading the model, because I thought it would overwrite the hparams saved in the .ckpt file. I assumed this from the function name and because the docs list that function right after:

But if you don’t want to use the values saved in the checkpoint, pass in your own here

class LitModel(LightningModule):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.save_hyperparameters()
        self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)

The following modification of my code solved the issue:

class ModelBase(pl.LightningModule):
    def __init__(self, pretrained_hparams: bool, **kwargs):  # **kwargs  # sample_rate
        print(f"Init ModelBase, hparams:\n{self.hparams}\n")
        super().__init__()
        self.save_hyperparameters()  # THIS LINE WAS MOVED HERE (out of the if-block below)
        # ^^^^^^^^^^^^^^^^^^^^^^^^^
        print(f"Init ModelBase after, hparams:\n{self.hparams}\n")

        # use PANNs.load_from_checkpoint when loading weights after transfer learning
        if pretrained_hparams:
            # self.save_hyperparameters()  # MOVED THIS LINE OF CODE ^
            # ^^^^^^^^^^^^^^^^^^^^^^^^^^^
            print("Argument hparams: ", self.hparams)
            # needed hparams for non-lightning pre-trained weights
            self.set_pretrained_hparams()
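The fix can be checked end to end with the same kind of plain-Python stand-in (hypothetical names again: `TinyModule` and `record_hparams` replace Lightning's machinery, and a dict plays the role of the .ckpt file's hyper_parameters entry). Arguments are recorded unconditionally, and only the derivation of the pretrained values stays behind the flag, so a reload with pretrained_hparams=False still sees sample_rate, hlayer1 and classes_num:

```python
class TinyModule:
    """Hypothetical stand-in for LightningModule hparams handling."""

    def __init__(self):
        self.hparams = {}

    def record_hparams(self, **values):
        # Stand-in for save_hyperparameters()
        self.hparams.update(values)


class ModelBase(TinyModule):
    def __init__(self, pretrained_hparams, **kwargs):
        super().__init__()
        # Fixed version: always record the init arguments
        self.record_hparams(pretrained_hparams=pretrained_hparams, **kwargs)
        if pretrained_hparams:
            self.set_pretrained_hparams()

    def set_pretrained_hparams(self):
        # Same mapping as in the original ModelBase
        if self.hparams["sample_rate"] == 8000:
            self.hparams["hlayer1"] = 400
        elif self.hparams["sample_rate"] == 16000:
            self.hparams["hlayer1"] = 800
        self.hparams["classes_num"] = 3


# "Training" run: derive the pretrained hparams, then "save" them
trained = ModelBase(pretrained_hparams=True, sample_rate=8000)
saved = dict(trained.hparams)

# "Reload": saved values are passed back as kwargs, with the flag overridden
restored = ModelBase(**{**saved, "pretrained_hparams": False})
print(restored.hparams)
# {'pretrained_hparams': False, 'sample_rate': 8000, 'hlayer1': 400, 'classes_num': 3}
```

In this sketch the derived hlayer1 and classes_num are not recomputed on reload; they come back from the saved dict because they were written into hparams before saving, which matches the checkpoint contents printed earlier.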

Thank you for checking!