Loading checkpoints when models are built using a 'setup' block

Hello! I seem to have run into a situation where the recommended methods for loading a model checkpoint fail. In my specific use case I need to use a setup block to avoid building the model immediately, but this makes loading PTL checkpoints tricky in general. I’ve found two especially troublesome issues. First, the load_from_checkpoint classmethod fails. Second, when I manually load weights by first calling setup and then load_state_dict on the new model, the trainer wipes the loaded weights when fit is called (I suspect by re-running setup).

I’ve come up with workarounds to all of this, but none of them are as pretty as the recommended PTL interfaces, and it seems like a shame not to be able to use them. Otherwise, totally love the Lightning library!

To give you a more concrete example, consider a LightningModule implemented with a setup block:

import pytorch_lightning as ptl
import torch.nn as nn

class MyNet(ptl.LightningModule):
    
    def __init__(self, conv1_width=6, conv2_width=16, 
                 fc1_width=120, fc2_width=84, 
                 dropout1=0.5, dropout2=0.5, 
                 learning_rate=1e-3, **kwargs):
        super().__init__()
                
        self.conv1_width = conv1_width
        self.conv2_width = conv2_width
        self.fc1_width = fc1_width
        self.fc2_width = fc2_width
        self.dropout1 = dropout1
        self.dropout2 = dropout2
        self.learning_rate = learning_rate
        
        self.unused_kwargs = kwargs
        self.save_hyperparameters()
        
    def setup(self, step):
        self.conv1 = nn.Conv2d(3, self.conv1_width, 5)
        self.pool  = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(self.conv1_width, self.conv2_width, 5)
        self.fc1   = nn.Linear(self.conv2_width * 5 * 5, self.fc1_width)
        self.drop1 = nn.Dropout(p=self.dropout1)
        self.fc2   = nn.Linear(self.fc1_width, self.fc2_width)
        self.drop2 = nn.Dropout(p=self.dropout2)
        self.fc3   = nn.Linear(self.fc2_width, 10)
        
        self.criterion = nn.CrossEntropyLoss()

Then, running:

model = MyNet()
data  = MyDataModule(data_dir='/tmp/pytorch-example/cifar-10-data')

trainer = Trainer(gpus=1, max_epochs=1)
trainer.fit(model, data)

will train a model. However, when I try to save and load the model like this:

trainer.save_checkpoint('/tmp/pytorch-example/model.ckpt')
new_model = MyNet.load_from_checkpoint(checkpoint_path='/tmp/pytorch-example/model.ckpt')

I get a RuntimeError complaining about unexpected keys in the state_dict, obviously because the layers are only created in the setup block and don’t exist on a freshly constructed model.
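Roughly speaking, load_from_checkpoint does something like the following (a simplified sketch, not the actual Lightning internals), which is why the keys come up as unexpected:

# Simplified sketch of what load_from_checkpoint does (not the real internals):
ckpt = torch.load('/tmp/pytorch-example/model.ckpt')
model = MyNet(**ckpt['hyper_parameters'])   # only __init__ runs; setup() is never called
model.load_state_dict(ckpt['state_dict'])   # RuntimeError: conv1.*, fc1.*, ... are unexpected keys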

This is pretty easy to solve by doing:

my_checkpoint = torch.load('/tmp/pytorch-example/model.ckpt')
new_model = MyNet(**my_checkpoint['hyper_parameters'])
new_model.setup('train')
new_model.load_state_dict(my_checkpoint['state_dict'])

However, running

trainer.fit(new_model, data)

resets the model weights by making a fresh call to setup.
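One way to confirm that setup is re-run is to note that fit replaces the layer objects themselves (a quick, hypothetical check):

# Hypothetical check that fit() re-runs setup() and rebuilds the layers:
fc1_before = new_model.fc1
trainer.fit(new_model, data)
print(new_model.fc1 is fc1_before)   # False -- setup() created a fresh fc1, wiping the loaded weights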

My workaround is to modify the __init__ and setup methods in my LightningModule by adding an is_built flag to self, as follows:

class MyNet(ptl.LightningModule):
    
    def __init__(self, ...,**kwargs):
        super().__init__()     
        ...
        self.is_built = False
        
    def setup(self, step):
        if not self.is_built:
            ...
            self.is_built = True
        else:
            pass

but I’m not really sure how this is going to behave down the line when I want to start using some of the cooler Lightning features, like parallelization.
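For completeness, with that guard in place the manual load sequence from above survives fit (a sketch assuming the guarded MyNet):

# With the is_built guard, setup() becomes a no-op on the second call,
# so the manually loaded weights survive trainer.fit:
new_model = MyNet(**my_checkpoint['hyper_parameters'])
new_model.setup('train')                                 # builds the layers, sets is_built = True
new_model.load_state_dict(my_checkpoint['state_dict'])   # restore the trained weights
trainer.fit(new_model, data)                             # setup() does nothing this time, weights intact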

Hmm, that looks more like a bug. Do you have a full example of the above snippet, ideally in Colab? Could you also open an issue? (Please link the new issue here :slight_smile:)

Hey @jirka, thanks for taking a look. I fleshed out the example above in this Colab:

Please let me know if everything makes sense here. Note that cell [5] errors as expected, showing that the recommended load method fails.

Hi, has there been a solution to this? I’m having the same issue. I try to use ModelHooks.setup to initialize the model dynamically, but I get the same error as above (RuntimeError: Error(s) in loading state_dict) when loading from a checkpoint.

There is an issue pending here for this thread. If you have any comments or insight, posting there may help the issue get more traction.

I’m also having problems with this feature, and it’s even more confusing because it is advertised in the setup() docstring:

Called at the beginning of fit (train + validate), validate, test, and predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
Example:
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = else

    def setup(stage):
        data = Load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)

But when I try to load the checkpoint using LitModel.load_from_checkpoint, I get, for example:

RuntimeError: Error(s) in loading state_dict for LitModel:
	Unexpected key(s) in state_dict: "l1.weight"

OK, I found a solution: the setup() docstring should reference the on_load_checkpoint() hook, because the following works:

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        checkpoint["data"] = self.data

    def on_load_checkpoint(self, checkpoint):
        data = checkpoint["data"]
        self.l1 = nn.Linear(28, data.num_classes)

    def setup(self, stage):
        if self.l1 is None:  # training
            data = Load_data(...)
            self.l1 = nn.Linear(28, data.num_classes)
            self.data = data

I use this to load a vocabulary for an LM.
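With those hooks in place, the standard loading path works again, because on_load_checkpoint rebuilds self.l1 before the state_dict is restored (a quick sketch):

# Sketch: on_load_checkpoint() recreates self.l1 from checkpoint["data"],
# so load_state_dict() no longer sees unexpected keys.
model = LitModel.load_from_checkpoint('path/to/model.ckpt')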


This solution worked for me! Thank you.

Not only was I dynamically building my model during setup, I also had a bunch of peripheral/auxiliary attributes from setup that I needed in order to load the model again later. on_save_checkpoint and its loading counterpart were exactly what I needed. In save I passed whatever I needed into the checkpoint dictionary, and in load I could unpack it from the dictionary and set the attributes before building the model again.
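For anyone reading later, a minimal sketch of that pattern (the attribute names and the build_model helper here are made up):

    def on_save_checkpoint(self, checkpoint):
        # stash whatever setup() would normally compute (hypothetical attributes)
        checkpoint['vocab'] = self.vocab
        checkpoint['num_special_tokens'] = self.num_special_tokens

    def on_load_checkpoint(self, checkpoint):
        # restore the auxiliary attributes, then rebuild the layers that depend on them
        self.vocab = checkpoint['vocab']
        self.num_special_tokens = checkpoint['num_special_tokens']
        self.build_model()   # hypothetical helper shared with setup()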
