I have a pretrained
torch.nn.Module that my
LightningModule uses during training.
For the sake of example, assume it is a pretrained and frozen ResNet image model that I use for feature extraction.
What is the best way to use such a module from my LightningModule?
Simply storing it as a child module:
self.resnet = ResNet()
would register its parameters as part of the
LightningModule and inflate the checkpoint size. It also prevents a single ResNet instance from being shared by multiple LightningModules, which is a serious issue for huge models.
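To illustrate the checkpoint-size problem, here is a minimal sketch. TinyBackbone is a hypothetical stand-in for the pretrained ResNet, and a plain nn.Module stands in for the LightningModule (which inherits its state_dict() behaviour from nn.Module, so the registration effect is the same):

```python
import torch
from torch import nn

class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a pretrained & frozen ResNet."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)

class MyModel(nn.Module):  # stand-in for a LightningModule
    def __init__(self):
        super().__init__()
        self.resnet = TinyBackbone()  # assigned as an attribute -> registered as a child module
        self.head = nn.Linear(8, 2)

model = MyModel()
# The backbone's weights now show up in state_dict(), and hence in every checkpoint:
backbone_keys = [k for k in model.state_dict() if k.startswith("resnet.")]
print(backbone_keys)
```

For a real ResNet those keys amount to tens of millions of parameters duplicated into every checkpoint.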
An alternative is to pass the pretrained model in as a constructor argument:
class MyModel(LightningModule):
    def __init__(self, resnet: ResNet):
        super().__init__()
        self._resnet = [resnet]
With this approach the ResNet model is not really owned by the LightningModule; it is merely held as a reference. That allows model sharing and keeps it out of the checkpoint. The problem is device management: I need to move the ResNet manually via
.cuda(), and the problem gets even worse when training on multiple GPUs.
Is there a better option: one that stores the model as an attribute so device management is automatic, but that does not checkpoint its weights?