Decoupling of Tasks and Models (Use Inheritance or Tasks?)


I am trying to write a few LightningModules for general tasks such as image classification and image classification with fine-tuning (multiple ConvNet backbones and classifiers, each implemented as a torch.nn.Module or a LightningModule). My understanding is that PyTorch Lightning offers two ways to do this:

One is inheritance: I write all the boilerplate classification code in a base module and extend it with subclasses that define their own forward functions while keeping everything else the same.

The other is to create something like a ClassificationTask (as shown here) and pass a torch.nn.Module to it as an argument.

Which of these approaches is preferable, and why? My hunch is that the task approach gives more flexibility for reusing existing torch.nn modules but might complicate saving and loading models, while the inheritance approach fits the PyTorch Lightning framework better but makes it harder to plug in other modules. I am not sure, though.

I am also having trouble figuring out how to handle hyperparameters in the task approach. I assume I could add a staticmethod to the torch.nn module, but then I cannot easily save the hyperparameters.

Any guidance would be appreciated.



I have actually been working on a prototype via inheritance.

My prototype is Lightning-hydra-seed on GitHub. However, it seems I can't inject a mixin into a LightningModule.


I’m using the following setup:

import pytorch_lightning as pl

class BaseNetwork(pl.LightningModule):
    # network code
    pass

class ModelTrainer(pl.LightningModule):
    # model training code
    pass

class MyModel(ModelTrainer, BaseNetwork):
    def __init__(self, **kwargs):
        super().__init__()
        # new code

    # anything that needs to be overridden

Here is a gist demonstrating this setup, which I made for this issue: Hparams not restored when using load_from_checkpoint (default argument values are the problem?)

Make sure you apply the self.save_hyperparameters() fix.