Decoupling of Tasks and Models (Use Inheritance or Tasks?)


I am trying to write a few Lightning modules for general tasks such as image classification and image classification with fine-tuning (multiple ConvNet backbones and classifiers, each implemented as a torch.nn or Lightning module). My understanding is that PyTorch Lightning provides two options for this:

One is inheritance: I write all the boilerplate classification code in a base module and extend it with model-specific forward functions, keeping everything else the same.

The other option is to create something like a ClassificationTask (as shown here) and pass in a torch.nn.Module as input.

Which of these methods is preferable, and why? My hunch is that the latter gives more flexibility to reuse existing torch.nn modules but might complicate saving and loading models, while the former fits better into the PyTorch Lightning framework but makes flexible integration of other modules harder. I am not sure, though.

Also, I am having some difficulty figuring out how to handle hyperparameters in the latter case. I'm assuming I can add a staticmethod to the torch.nn module, but then I cannot easily save the hyperparameters.

Any guidance will be appreciated.



I have actually been working on a prototype via inheritance.

It's Lightning-hydra-seed on GitHub, but it seems I can't inject a mixin into a LightningModule.


I’m using the following setup:

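import pytorch_lightning as pl

class BaseNetwork(pl.LightningModule):
    # network code: layers and forward()
    ...

class ModelTrainer(pl.LightningModule):
    # model training code: training_step(), configure_optimizers(), ...
    ...

class MyModel(ModelTrainer, BaseNetwork):
    def __init__(self, **kwargs):
        super().__init__()
        # new code

    # anything that needs to be overridden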

Here is a gist demonstrating the setup, which I made for this issue: Hparams not restored when using load_from_checkpoint (default argument values are the problem?)

Make sure you apply the self.save_hyperparameters() fix.


I am thinking about a similar design. Any luck injecting a task mixin into the LightningModule, or maybe injecting a model mixin into the LightningModule?

Thanks @tchaton for your response. Could you please highlight what issues you ran into while using the mixin method?

I was able to train the model using a mixin, but the built-in self.save_hyperparameters function failed because it was unable to analyze the init function when inheriting from multiple classes. In general, I struggled with hyperparameter-related issues.

Thanks @NumesSanguis, this looks useful. I tried a similar method but was unable to save hyperparameters in a child class that inherits from multiple parent classes.

Hey @usmanshahid @milLII3yway,

Have a look at our new task framework, Lightning Flash (PyTorchLightning/lightning-flash on GitHub), a collection of tasks for fast prototyping, baselining, finetuning and solving problems with deep learning.
We would love for you to contribute new tasks.