Loading checkpoints when models are built using a 'setup' block

Hello! I seem to have run into a situation where the recommended methods for loading a model checkpoint fail. In my specific use case I need to use a setup block to avoid building the model immediately. However, this makes loading PTL checkpoints tricky in general. I’ve found two especially troublesome issues. First, the load_from_checkpoint classmethod fails. Second, when I manually load weights by first invoking setup and then calling load_state_dict on my new model, the trainer wipes the loaded weights when fit is called (I expect by re-running setup).

I’ve come up with workarounds for all of this, but none of them are as pretty as the recommended PTL interfaces, and it seems a shame not to be able to use them. Otherwise, I totally love the Lightning library!

To give you a more concrete example, consider a LightningModule implemented with a setup block:

import torch
import torch.nn as nn
import pytorch_lightning as ptl
from pytorch_lightning import Trainer


class MyNet(ptl.LightningModule):
    
    def __init__(self, conv1_width=6, conv2_width=16, 
                 fc1_width=120, fc2_width=84, 
                 dropout1=0.5, dropout2=0.5, 
                 learning_rate=1e-3, **kwargs):
        super().__init__()
                
        self.conv1_width = conv1_width
        self.conv2_width = conv2_width
        self.fc1_width = fc1_width
        self.fc2_width = fc2_width
        self.dropout1 = dropout1
        self.dropout2 = dropout2
        self.learning_rate = learning_rate
        
        self.unused_kwargs = kwargs
        self.save_hyperparameters()
        
    def setup(self, step):
        self.conv1 = nn.Conv2d(3, self.conv1_width, 5)
        self.pool  = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(self.conv1_width, self.conv2_width, 5)
        self.fc1   = nn.Linear(self.conv2_width * 5 * 5, self.fc1_width)
        self.drop1 = nn.Dropout(p=self.dropout1)
        self.fc2   = nn.Linear(self.fc1_width, self.fc2_width)
        self.drop2 = nn.Dropout(p=self.dropout2)
        self.fc3   = nn.Linear(self.fc2_width, 10)
        
        self.criterion = nn.CrossEntropyLoss()

Then, running:

model = MyNet()
data  = MyDataModule(data_dir='/tmp/pytorch-example/cifar-10-data')

trainer = Trainer(gpus=1, max_epochs=1)
trainer.fit(model, data)

will train a model as expected. However, when I try to save and reload it like this:

trainer.save_checkpoint('/tmp/pytorch-example/model.ckpt')
new_model = MyNet.load_from_checkpoint(checkpoint_path='/tmp/pytorch-example/model.ckpt')

I get a RuntimeError about unexpected keys in the state_dict, which makes sense: the layers are only created in the setup block, so the freshly constructed model has no parameters for the checkpoint’s weights to load into.
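
For reference, a quick sanity check (just illustrative, using the MyNet class above) shows that the plain constructor produces an empty state_dict until setup runs, so every checkpoint key is "unexpected":

fresh = MyNet()
print(list(fresh.state_dict().keys()))   # [] -- no layers, no parameters yet
fresh.setup('train')
print(list(fresh.state_dict().keys()))   # ['conv1.weight', 'conv1.bias', ...]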

This is pretty easy to solve by doing:

my_checkpoint = torch.load('/tmp/pytorch-example/model.ckpt')
new_model = MyNet(**my_checkpoint['hyper_parameters'])
new_model.setup('train')
new_model.load_state_dict(my_checkpoint['state_dict'])
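
For anyone else who hits this, the same steps are easy to wrap in a small helper. load_setup_checkpoint is just a name I made up for this sketch, not a Lightning API:

def load_setup_checkpoint(cls, checkpoint_path, stage='train'):
    """Rebuild a setup()-based LightningModule from a checkpoint by hand."""
    checkpoint = torch.load(checkpoint_path)
    model = cls(**checkpoint['hyper_parameters'])
    model.setup(stage)                               # build the layers first
    model.load_state_dict(checkpoint['state_dict'])  # then restore the weights
    return model

new_model = load_setup_checkpoint(MyNet, '/tmp/pytorch-example/model.ckpt')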

However, running

trainer.fit(new_model, data)

resets the model weights by making a fresh call to setup.

My workaround is to modify the __init__ and setup methods in my LightningModule as follows, adding an is_built flag to self:

class MyNet(ptl.LightningModule):

    def __init__(self, ..., **kwargs):
        super().__init__()
        ...
        self.is_built = False

    def setup(self, step):
        # only build the layers the first time setup() is called
        if not self.is_built:
            ...
            self.is_built = True

but I’m not really sure how this is going to behave down the line when I want to start using some of the cooler Lightning features, like parallelization.
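
For completeness, an equivalent guard that avoids the extra attribute is to check whether the layers already exist. This is only a sketch of the same idea, not anything Lightning provides:

def setup(self, step):
    # skip rebuilding if the layers are already there, either because setup()
    # already ran or because the weights were loaded by hand
    if hasattr(self, 'conv1'):
        return
    self.conv1 = nn.Conv2d(3, self.conv1_width, 5)
    ...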

Hmm, that looks more like a bug. Do you have a full example of the above snapshot, ideally in Colab, and could you open an issue? (pls link the new issue here :slight_smile:)

Hey @jirka, thanks for taking a look. I fleshed out the example above in this Colab:

Please let me know if everything makes sense there. Note that cell [5] errors as expected, showing that the recommended load method fails.