I am trying to write a few LightningModules for general tasks such as image classification and image classification with fine-tuning (multiple ConvNet backbones and classifiers, each implemented as a torch.nn.Module or a LightningModule). My understanding is that PyTorch Lightning offers two ways to do this:
One is inheritance: I write all the boilerplate classification code in a base module and subclass it for each model, overriding only the forward function and keeping the rest the same.
The other option is to create something like a ClassificationTask (as shown here) and pass a torch.nn.Module into it as an argument.
Which of these approaches is preferable, and why? My hunch is that the latter gives more flexibility for reusing existing torch.nn modules but might complicate saving and loading models, while the former fits better into the PyTorch Lightning framework but makes it harder to plug in other modules. I am not sure, though.
I am also having trouble figuring out how to handle hyperparameters in the latter case. I assume I could add a staticmethod to the torch.nn.Module, but then I cannot easily save the hyperparameters.
Any guidance would be appreciated.