Cannot load hyperparameters properly from a checkpoint

Hi, I am new to PyTorch Lightning, and I am testing checkpointing because I cannot finish a training session before my GPU allocation times out (12 hours).

I am logging accuracy, loss, and learning rate using TensorBoardLogger.

From TensorBoard, I found that my code does not properly restore the model or trainer state from the checkpoint. In the screenshots, version_0 (orange) is the trace of the initial training, and version_1 (blue) is the trace of the resumed training. As you can see, the resumed training starts from epoch 0, and the previous learning rate is not restored; it is re-initialized to 0.01. If I understand correctly, all of this state should be stored in the checkpoint and restored as it was saved, but that is not what happens in my case.

Here is my source code with example save paths that I am using in it:

file_name = get_filename(hparams)        
save_path = hparams.save_path or os.path.join(os.getcwd(), 'logs', hparams.dataset)
ckpt_path = None if hparams.v_num is None else os.path.join(save_path, file_name, 'version_' + str(hparams.v_num),'checkpoints')
ckpt_file = None if hparams.v_num is None else os.path.join(ckpt_path, os.listdir(ckpt_path)[0])
statedict_path = os.path.join(os.getcwd(), 'trained_models', hparams.dataset + '_' + hparams.arch + '.pt')

''' 
### Path Examples:
file_name = MODEL_NAME
save_path = HOME_DIR/logs/cifar10
ckpt_path = HOME_DIR/logs/cifar10/MODEL_NAME/version_0/checkpoints
ckpt_file = HOME_DIR/logs/cifar10/MODEL_NAME/version_0/checkpoints/epoch=6-step=1243.ckpt
statedict_path = HOME_DIR/trained_models/MODEL_NAME.pt
'''
    
if hparams.v_num is None:
    model = MODEL(hparams,
                  num_classes=dm.num_classes,
                  train_size=len(dm.train_dataloader().dataset))
else:
    model = MODEL.load_from_checkpoint(ckpt_file,
                                       num_classes=dm.num_classes,
                                       train_size=len(dm.train_dataloader().dataset))

logger = TensorBoardLogger(save_path, name=file_name)
lr_monitor = LearningRateMonitor(logging_interval='step')

trainer = Trainer(default_root_dir=save_path,
                  gpus=hparams.gpus,
                  max_epochs=hparams.epochs,
                  resume_from_checkpoint=ckpt_file,
                  distributed_backend=hparams.distributed_backend,
                  num_nodes=hparams.num_nodes,
                  logger=logger,
                  callbacks=[lr_monitor],
                  deterministic=deterministic)

trainer.fit(model, dm)

# Save weights from checkpoint
torch.save(model.model.state_dict(), statedict_path)

FYI, hparams.v_num is specified as an int value only when resuming from checkpoint.

I am sure that I am not using the methods properly, and the save paths might also be incorrect. However, I could not find a good example that addresses my issue.

Please let me know if you spot any mistakes in my source code…
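
One detail in the snippet above that may be worth checking: os.listdir returns entries in arbitrary order, so os.listdir(ckpt_path)[0] is not guaranteed to be the most recent checkpoint when the directory holds more than one file. A minimal sketch of an alternative, assuming the newest file by modification time is the one to resume from (the helper name latest_checkpoint is hypothetical, not part of Lightning):

```python
import os
from typing import Optional


def latest_checkpoint(ckpt_dir: str) -> Optional[str]:
    """Return the most recently modified .ckpt file in ckpt_dir, or None if there is none."""
    # Keep only checkpoint files; other files in the directory are ignored.
    candidates = [
        os.path.join(ckpt_dir, name)
        for name in os.listdir(ckpt_dir)
        if name.endswith(".ckpt")
    ]
    # Assumption: "latest" means newest modification time.
    return max(candidates, key=os.path.getmtime) if candidates else None
```

With this, ckpt_file could be set to latest_checkpoint(ckpt_path) when hparams.v_num is given. This only makes the file selection predictable; it does not by itself explain the epoch and learning-rate reset.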

Hey, I’m very new to PyTorch Lightning, but have you checked out this part of the docs? It seems to go into some detail about checkpointing.

However, what I do not understand is why you need to load the model separately from the checkpoint. Doesn’t passing the resume_from_checkpoint flag to the Trainer load all the states (e.g., step, epoch, optimizer state, model state)? According to the docs, it should. I would try manually querying all the optimizer states before and after loading to see what is not being restored correctly, and maybe raise an issue if a particular state is not being saved or restored.
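
As a rough way to do that check, here is a minimal sketch that summarizes the training-progress fields of a checkpoint once it has been loaded into a plain dict (for example with torch.load(ckpt_file, map_location="cpu")). The key names used here (epoch, global_step, optimizer_states, lr_schedulers, hyper_parameters) are what I understand Lightning checkpoints to contain, but they may differ between versions, so treat them as assumptions. The function name summarize_checkpoint and the sample dict are hypothetical.

```python
from typing import Any, Dict


def summarize_checkpoint(ckpt: Dict[str, Any]) -> Dict[str, Any]:
    """Collect the training-progress fields stored in a checkpoint dict."""
    # Learning rate of every param group of every saved optimizer state.
    learning_rates = [
        [group.get("lr") for group in opt_state.get("param_groups", [])]
        for opt_state in ckpt.get("optimizer_states", [])
    ]
    return {
        "epoch": ckpt.get("epoch"),
        "global_step": ckpt.get("global_step"),
        "learning_rates": learning_rates,
        "num_lr_schedulers": len(ckpt.get("lr_schedulers", [])),
        "has_hyper_parameters": "hyper_parameters" in ckpt,
    }


if __name__ == "__main__":
    # Hypothetical stand-in for: ckpt = torch.load(ckpt_file, map_location="cpu")
    sample = {
        "epoch": 6,
        "global_step": 1243,
        "optimizer_states": [{"state": {}, "param_groups": [{"lr": 0.001}]}],
        "lr_schedulers": [{}],
    }
    print(summarize_checkpoint(sample))
```

If the summary shows the expected epoch and learning rate, the checkpoint itself is fine and the problem is more likely in how it is restored; if not, the state was never saved in the first place.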

Would you mind reproducing this issue with https://colab.research.google.com/drive/1HvWVVTK8j2Nj52qU4Q4YCyzOm0_aLQF3?usp=sharing?

Thank you for the comment! I’m new to checkpointing since I only recently started training a large model. I will definitely take a look at the documentation and figure out how to do the checkpointing correctly. By the way, it seems that my issue is already known to the community: https://github.com/PyTorchLightning/pytorch-lightning/issues/730

I hope this issue can be resolved in the next milestone.

I guess reproducing it in the Colab will help the developers iron it out quickly!