Passing args when loading checkpoint


I created a model to which I pass the argparse arguments (args) as well as a neural network and a loss function. After training, I try to load a checkpoint. I thought I only needed to pass the checkpoint file, the neural network, and the loss function to load_from_checkpoint(), since I had already saved the hyperparameters (contained in args). However, it throws an error unless I also pass args.

import pytorch_lightning as pl

args = parser.parse_args()

class SSLModel(pl.LightningModule):
    def __init__(self, args, network, loss):
        """
        args: argparse.Namespace, required parameters for training the model
        network: nn.Module, consists of an encoder and the projection head
        loss: nn.Module, loss used for training the model
        """
        super().__init__()
        self.save_hyperparameters(args)
        self.network = network
        self.loss = loss

I trained it using the following commands:

from pytorch_lightning import Trainer
learning_module = SSLModel(args=args, network=network, loss=loss)
trainer = Trainer.from_argparse_args(args, logger=logger, callbacks=[early_stopping, checkpoint_callback, lr_monitor, custom_callback])
trainer.fit(learning_module, train_dataloader, valid_dataloader)

After training, I try to load a checkpoint as shown below, but it throws an error ('args missing') unless I also pass args as input:

model = SSLModel.load_from_checkpoint(ckpt_file, network=network, loss=loss)

I assumed that I only need to pass network and loss, since they are not saved as hyperparameters.
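To see how an 'args missing' error can arise at all, here is a toy sketch in plain Python. It is NOT Lightning's implementation: `toy_load_from_checkpoint`, `Model`, and the dict standing in for a checkpoint are all hypothetical. It assumes a loader that rebuilds the constructor kwargs from the saved hyperparameters, keeps only names matching `__init__` parameters, and lets explicitly passed kwargs override them. Under that assumption, if the namespace was stored as its individual fields (lr, batch_size, ...) rather than under the name `args`, nothing fills the `args` parameter.

```python
import argparse
import inspect


def toy_load_from_checkpoint(cls, checkpoint, **kwargs):
    # Hypothetical, simplified loader: NOT Lightning's implementation.
    # Start from the saved hyperparameters...
    saved = dict(checkpoint.get("hyper_parameters", {}))
    params = inspect.signature(cls.__init__).parameters
    # ...keep only entries whose names match an __init__ parameter...
    ctor_kwargs = {k: v for k, v in saved.items() if k in params}
    # ...and let explicitly passed kwargs (network, loss) take priority.
    ctor_kwargs.update(kwargs)
    return cls(**ctor_kwargs)


class Model:
    def __init__(self, args, network, loss):
        self.args, self.network, self.loss = args, network, loss


# Case 1: hyperparameters stored as the namespace's individual fields.
flat_ckpt = {"hyper_parameters": {"lr": 0.1, "batch_size": 32}}
try:
    toy_load_from_checkpoint(Model, flat_ckpt, network="net", loss="loss")
    flat_failed = False
except TypeError:  # missing 1 required positional argument: 'args'
    flat_failed = True

# Case 2: hyperparameters stored under the constructor's own parameter name.
named_ckpt = {"hyper_parameters": {"args": argparse.Namespace(lr=0.1, batch_size=32)}}
model = toy_load_from_checkpoint(Model, named_ckpt, network="net", loss="loss")
```

Which of the two layouts a real checkpoint uses may depend on the Lightning version, so inspecting the checkpoint's saved hyperparameters is a quick way to tell which case applies.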

Also, is there a way to avoid passing network and loss altogether, for example by saving them as hyperparameters?
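One common pattern for this is to pass only plain, serializable settings to `__init__` and build the network and loss inside it. In Lightning, calling `self.save_hyperparameters()` with no arguments records the `__init__` arguments, so the checkpoint alone would then be enough to rebuild the model. The sketch shows the idea in plain Python without Lightning or PyTorch; `SSLModelSketch`, `Encoder`, `ContrastiveLoss`, `LOSSES`, and every parameter name are hypothetical stand-ins, and `self.hparams` here is an ordinary dict mimicking what would be saved.

```python
class Encoder:
    """Hypothetical stand-in for an encoder + projection head nn.Module."""

    def __init__(self, dim):
        self.dim = dim


class ContrastiveLoss:
    """Hypothetical stand-in for a loss nn.Module."""

    def __init__(self, temperature):
        self.temperature = temperature


# Hypothetical registry mapping a config string to a loss constructor.
LOSSES = {"contrastive": ContrastiveLoss}


class SSLModelSketch:
    def __init__(self, encoder_dim=128, loss_name="contrastive", temperature=0.5):
        # Only plain values are recorded, mimicking what
        # save_hyperparameters() would store in a checkpoint.
        self.hparams = {
            "encoder_dim": encoder_dim,
            "loss_name": loss_name,
            "temperature": temperature,
        }
        # The components are rebuilt from those values, so nothing
        # extra has to be passed when restoring.
        self.network = Encoder(encoder_dim)
        self.loss = LOSSES[loss_name](temperature)


trained = SSLModelSketch(encoder_dim=256, temperature=0.1)
# Simulate restoring from the saved hyperparameters alone.
restored = SSLModelSketch(**trained.hparams)
```

The trade-off is that every architecture choice must be expressible as a serializable value. If the network and loss really have to be constructed outside the model, passing them to load_from_checkpoint() as keyword arguments, as in the question, remains the straightforward route.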

hey @anupsingh15

Would you mind sharing a reproducible script and raising an issue here?

Also, we have moved the discussions to GitHub Discussions. You might want to check that out instead to get a quick response. The forums will be marked read-only soon.

Thank you