I created a LightningModule whose constructor takes the argparse arguments (`args`) as well as a neural network and a loss function. After training, I want to load a checkpoint. Since I already saved the hyperparameters (contained in `args`), I thought I only needed to pass the checkpoint, the neural network and the loss function to `load_from_checkpoint()`. However, it throws an error unless I also pass `args`.
```python
import pytorch_lightning as pl

args = parser.parse_args()


class SSLModel(pl.LightningModule):
    def __init__(self, args, network, loss):
        """
        args: parsed argparse Namespace, required parameters for training the model
        network: nn.Module, consists of an encoder and the projection head
        loss: nn.Module, loss used for training the model
        """
        super().__init__()
        # Store the contents of the argparse Namespace as hyperparameters
        self.save_hyperparameters(args)
        self.network = network
        self.loss = loss
```
I trained it with the following code:
```python
from pytorch_lightning import Trainer

learning_module = SSLModel(args=args, network=network, loss=loss)
trainer = Trainer.from_argparse_args(
    args,
    logger=logger,
    callbacks=[early_stopping, checkpoint_callback, lr_monitor, custom_callback],
)
trainer.fit(learning_module, train_dataloader, valid_dataloader)
```
After training, I load a checkpoint as shown below, but it throws an error ('args missing') unless I also pass `args`:
```python
model = SSLModel.load_from_checkpoint(ckpt_file, network=network, loss=loss)
```
I assumed that I only need to pass `network` and `loss`, since they are not saved hyperparameters.
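One plausible explanation, offered as an assumption that may depend on the Lightning version, has three parts. First, `save_hyperparameters(args)` may store the *fields* of the Namespace (for example `lr`, `batch_size`) rather than a single entry named `args`. Second, at load time those saved values are passed back to `__init__` as keyword arguments. Third, keys that `__init__` does not accept are dropped. If all three hold, nothing fills the `args` parameter. The toy model below imitates that round trip in plain Python. `ToySSLModel` and `toy_load_from_checkpoint` are hypothetical stand-ins, not Lightning APIs. The code compares a "flattened" checkpoint with one that stores the Namespace under the parameter name `args`.

```python
import inspect
from argparse import Namespace


class ToySSLModel:
    """Hypothetical stand-in for SSLModel; no Lightning involved."""

    def __init__(self, args, network, loss):
        self.args = args
        self.network = network
        self.loss = loss


def toy_load_from_checkpoint(cls, checkpoint, **overrides):
    # Feed the saved hyperparameters back into __init__ as keyword
    # arguments; explicitly passed overrides take precedence.
    kwargs = {**checkpoint["hyper_parameters"], **overrides}
    # Drop keys that __init__ does not accept (assumed behaviour when
    # __init__ has no **kwargs).
    accepted = inspect.signature(cls.__init__).parameters
    kwargs = {k: v for k, v in kwargs.items() if k in accepted}
    return cls(**kwargs)


args = Namespace(lr=1e-3, batch_size=256)

# Case 1: only the Namespace's fields were saved, so lr and batch_size
# are filtered out and nothing is left to fill the `args` parameter.
flat_ckpt = {"hyper_parameters": vars(args).copy()}
try:
    toy_load_from_checkpoint(ToySSLModel, flat_ckpt, network="net", loss="loss")
    flat_error = None
except TypeError as err:
    flat_error = str(err)

# Case 2: the Namespace was saved under the parameter name `args`,
# so passing only network and loss is enough.
named_ckpt = {"hyper_parameters": {"args": args}}
model = toy_load_from_checkpoint(ToySSLModel, named_ckpt, network="net", loss="loss")

print(flat_error)
print(model.args.lr)
```

In recent Lightning releases, one way to get the second behaviour is to call `self.save_hyperparameters(ignore=["network", "loss"])` instead of `self.save_hyperparameters(args)`. That call records all `__init__` arguments except the ignored ones, so `args` is kept under its own name. Treat this as a suggestion to verify against your installed version.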
Also, is there a way to avoid passing even `network` and `loss`, for example by saving them as hyperparameters?
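A common way to avoid passing the modules at load time is to pass only plain, serializable settings to the constructor, such as an encoder name and a loss name. The modules are then built inside `__init__`, so everything needed to rebuild the model lives in the saved hyperparameters. The trained weights do not need to be hyperparameters, because `load_from_checkpoint` restores the modules' parameters from the checkpoint's `state_dict` after constructing the model. The sketch below shows the pattern in plain Python with a toy loader. `NETWORKS`, `LOSSES`, `ToySSLModel`, the `encoder`/`loss_name` fields and the placeholder classes are all hypothetical. In Lightning, the analogous setup would call `self.save_hyperparameters()` in such an `__init__` and then load with `SSLModel.load_from_checkpoint(ckpt_file)` alone.

```python
class ToyEncoder:
    """Placeholder for an nn.Module with an encoder and projection head."""

    def __init__(self, dim):
        self.dim = dim


class ToyContrastiveLoss:
    """Placeholder for an nn.Module loss."""

    def __init__(self, temperature):
        self.temperature = temperature


# Hypothetical registries mapping plain names to constructors.
NETWORKS = {"toy_encoder": ToyEncoder}
LOSSES = {"toy_contrastive": ToyContrastiveLoss}


class ToySSLModel:
    def __init__(self, encoder="toy_encoder", dim=128,
                 loss_name="toy_contrastive", temperature=0.1):
        # Only plain values are recorded, so they can go into a checkpoint.
        self.hparams = {"encoder": encoder, "dim": dim,
                        "loss_name": loss_name, "temperature": temperature}
        # The modules are rebuilt from those values instead of being passed in.
        self.network = NETWORKS[encoder](dim)
        self.loss = LOSSES[loss_name](temperature)


def toy_load_from_checkpoint(cls, checkpoint):
    # Rebuild the model purely from the saved hyperparameters.
    return cls(**checkpoint["hyper_parameters"])


trained = ToySSLModel(dim=256, temperature=0.2)
checkpoint = {"hyper_parameters": dict(trained.hparams)}

# No network or loss objects are needed to restore the model.
restored = toy_load_from_checkpoint(ToySSLModel, checkpoint)
print(type(restored.network).__name__, restored.network.dim, restored.loss.temperature)
```

With the original `args`-based signature, the same idea would mean adding fields to `args` (for example a hypothetical `--encoder` option) and constructing `network` and `loss` from them inside `__init__`.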