Save/load model for inference

I’m trying to understand how I should save and load my trained model for inference.

Lightning lets me save checkpoint files, but the problem is that they are quite large, because they contain a lot of state that is irrelevant for inference (optimizer state, LR schedulers, callback state, and so on).

Instead, I could do torch.save(model.state_dict(), "model.pt"), which I believe contains only the trained weights, and then load the model with:

import torch

model = FullModel()
model.load_state_dict(torch.load("model.pt"))
model.eval()  # switch dropout/batch-norm layers to inference behaviour

My problem here is that my FullModel class requires a config dict (which was used to tune hyperparameters during training), so FullModel() fails with:

TypeError: __init__() missing 1 required positional argument: 'config'

Is the way around this to save config to disk during training, and load it together with the weights at inference time (sketched below)? Or is there a more “correct” way of doing it?
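Something like this, assuming config is JSON-serializable (file names here are just illustrative):

import json
import torch

# at the end of training: save weights and config side by side
torch.save(model.state_dict(), "model.pt")
with open("config.json", "w") as f:
    json.dump(config, f)

# at inference time: rebuild the model from the saved config
with open("config.json") as f:
    config = json.load(f)
model = FullModel(config)
model.load_state_dict(torch.load("model.pt"))
model.eval()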

I could simply save the entire model with torch.save(model, ...) (and not just the state_dict), which really simplifies loading, but that file ends up almost as big as the checkpoint files.
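For completeness, that version looks like this (note that torch.save pickles the whole module, so the FullModel class still has to be importable at load time, and newer PyTorch versions may need weights_only=False in torch.load):

import torch

torch.save(model, "model_full.pt")  # pickles the entire module, not just the weights
model = torch.load("model_full.pt")  # FullModel must be importable here
model.eval()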

You can set save_weights_only=True in ModelCheckpoint, which will save the hparams and model.state_dict().
https://pytorch-lightning.readthedocs.io/en/latest/generated/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint
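For example (monitor assumes you log a "val_loss" metric, and the paths are placeholders):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# lightweight checkpoints: hparams + state_dict, no optimizer/scheduler state
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="best",
    monitor="val_loss",  # assumes you log "val_loss" in your LightningModule
    save_weights_only=True,
)
trainer = Trainer(callbacks=[checkpoint_callback])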

ahh thank you!

and that would then get loaded with FullModel.load_from_checkpoint(), with no need to save a separate config file to disk?

Yep, hparams are saved even when save_weights_only=True (as long as your LightningModule calls self.save_hyperparameters() in __init__).
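Loading then looks like this (the import path and checkpoint path are just placeholders):

from my_project import FullModel  # wherever your FullModel is defined

# load_from_checkpoint is a classmethod: it re-instantiates FullModel from the
# saved hparams (your config) and then loads the state_dict into it
model = FullModel.load_from_checkpoint("checkpoints/best.ckpt")
model.eval()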