Saving and loading optimizer state

I am trying to load and fine-tune Google's ViT model using Hugging Face, with torch.optim.AdamW as the optimizer.

How do I save and load the state_dict of the optimizer defined in the configure_optimizers hook?

Thank you.

Hi @thevishnupradeep, the model's state_dict is available through the lit_model.state_dict() method. You can save it with torch.save(...) and restore it with lit_model.load_state_dict(torch.load(...)).

Here is an example:

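import torch
from torch import nn
import pytorch_lightning as pl

class BoringModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.lr = 1e-3
        # Stand-in for your own network (e.g. NeuralNet); any nn.Module works here
        self.model = nn.Linear(32, 2)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

model = BoringModel()

# the model's parameters and buffers as a dict of tensors
model.state_dict()

# save the state
torch.save(model.state_dict(), "state.pkl")

# load the state back into the model (torch.load alone only returns the dict)
model.load_state_dict(torch.load("state.pkl"))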
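The snippet above only round-trips the model weights. Optimizers from torch.optim, AdamW included, expose the same state_dict()/load_state_dict() pair, so a common plain-PyTorch pattern is to save both in one file and restore them into freshly built objects. This is a minimal sketch, assuming a tiny nn.Linear in place of the ViT model; the helper names and the "model"/"optimizer" keys are illustrative choices, not a Lightning API.

```python
import os
import tempfile

import torch
from torch import nn


def make_model_and_optimizer():
    # Placeholder network standing in for the ViT model (assumption).
    model = nn.Linear(4, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    return model, optimizer


def save_checkpoint(path, model, optimizer):
    # Bundle both state dicts in one file; the key names are arbitrary.
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict()}, path)


def load_checkpoint(path, model, optimizer):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])


# Take one training step so AdamW has per-parameter state (exp_avg, exp_avg_sq, step).
model, optimizer = make_model_and_optimizer()
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
save_checkpoint(path, model, optimizer)

# Rebuild fresh objects (e.g. in a new process) and restore both states into them.
new_model, new_optimizer = make_model_and_optimizer()
load_checkpoint(path, new_model, new_optimizer)
```

The new optimizer has to be constructed over the new model's parameters before load_state_dict is called, because the saved optimizer state refers to parameters by position, not by object.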

Also, we have deprecated this forum, so please use our GitHub Discussions for a quicker response. 🙂

This doesn't answer the question. As I read it, the question is how to save and load the optimizer's state dict, i.e. how to restore the optimizer itself.
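In Lightning, the Trainer builds the optimizer from configure_optimizers, so the usual route is to let Lightning checkpoint it: checkpoints written by trainer.save_checkpoint(...) or the ModelCheckpoint callback store each optimizer's state_dict() in a list under the "optimizer_states" key, and trainer.fit(model, ckpt_path=...) restores that state when resuming. If you only want the optimizer state (for example, to copy it into an optimizer you built yourself), a sketch like the following reads it out of the checkpoint. The helper name and its index argument are hypothetical, and the demo file below only mimics the "optimizer_states" key of a real .ckpt file.

```python
import os
import tempfile

import torch
from torch import nn


def load_optimizer_from_lightning_ckpt(optimizer, ckpt_path, index=0):
    """Hypothetical helper: copy the index-th optimizer state out of a Lightning checkpoint."""
    # Lightning checkpoints can hold non-tensor objects (e.g. hyperparameters),
    # so weights_only=False is used here; only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    # One entry per optimizer returned from configure_optimizers, in the same order.
    optimizer.load_state_dict(checkpoint["optimizer_states"][index])
    return optimizer


# Demo: a stand-in file that only mimics the "optimizer_states" key of a real .ckpt.
model = nn.Linear(4, 2)  # placeholder for the fine-tuned network
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

ckpt_path = os.path.join(tempfile.mkdtemp(), "fake.ckpt")
torch.save({"optimizer_states": [optimizer.state_dict()]}, ckpt_path)

# A freshly built optimizer over the same parameters picks up the saved AdamW state.
fresh_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
load_optimizer_from_lightning_ckpt(fresh_optimizer, ckpt_path)
```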