Why is trainer not included in the module?

The big benefit of Lightning is that it packages everything needed to train a model in one place. So why is the Trainer kept separate, so that there are actually two things required rather than one? The explanation in the docs says the engineering code lives in the Trainer and the model lives in the module. However, this seems like an artificial distinction and it isn't applied consistently, e.g. training_step is in the model yet the training loop is in the Trainer.

If there is some reason to keep them separate, then why is the model passed to trainer.fit? Is there some circumstance where one might fit the same trainer to multiple models? If not, why is the model not passed to the Trainer in init?

So why is the Trainer kept separate, so that there are actually two things required rather than one?

Abstraction. Keep the engineering part in the Trainer and the research part in the LightningModule. training_step defines what your model should do, while the training loop is the engineering part that handles how it should be done.
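
To make the split concrete, here is a minimal sketch (not from the original reply) of how the two halves divide the work: the LightningModule only describes what one training step computes, while the Trainer owns the loop, device placement, checkpointing and so on. The train_loader in the commented-out fit call is assumed to exist.

import torch
from torch import nn
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)

    def training_step(self, batch, batch_idx):
        # Research code: what a single training step computes.
        x, y = batch
        return nn.functional.cross_entropy(self.layer(x.view(x.size(0), -1)), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Engineering code: the Trainer runs the loop, handles devices, logging, checkpoints.
trainer = pl.Trainer(max_epochs=1)
# trainer.fit(LitModel(), train_loader)  # train_loader is assumed to exist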

Is there some circumstance where one might fit the same trainer to multiple models?

yes.


Thanks. When would you fit the same trainer to multiple models?

Not the same trainer instance, but possibly the same trainer configuration. For example, you can define your trainer configuration once and load it every time you want to try out a different model, say for an image classification task. It doesn't make sense to put the model in the Trainer init in that case.
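
For illustration, a rough sketch of what reusing a trainer configuration (rather than a trainer instance) could look like; LitClassifier, train_loader and val_loader are placeholder names, not part of the original reply:

# Reuse one trainer *configuration*, building a fresh Trainer per model.
trainer_kwargs = dict(max_epochs=10)

for model in (LitClassifier(hidden_dim=64, learning_rate=1e-3),
              LitClassifier(hidden_dim=256, learning_rate=3e-4)):
    trainer = pl.Trainer(**trainer_kwargs)  # new instance, same configuration
    trainer.fit(model, train_loader, val_loader)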

In that case wouldn't you create a new trainer instance each time? So why not pass the model to init rather than having to pass it every time you call fit? I can't see why you would call fit multiple times with different models, because the state would already be set from the previous model.

Here is my attempt to adapt the LightningModule in this example to sklearn estimator-like semantics:

I added set_trainer, fit, and evaluate interfaces to the LightningModule, and they are used like this:

# In function cli_main
# Initialization
model = LitClassifier(args.hidden_dim, args.learning_rate)
trainer = pl.Trainer.from_argparse_args(args)
model.set_trainer(trainer)  # trainer is loosely attached to model so I wouldn't pass it in the init method

# Training + Validation
model.fit(train_loader, val_loader)    

# Test      
result = model.evaluate(test_loader)

These LightningModule methods are implemented as follows:

def set_trainer(self, trainer):
    self.trainer = trainer

def fit(self, train_loader, val_loader):
    self.trainer.fit(self, train_loader, val_loader)

def evaluate(self, test_loader):
    result = self.trainer.test(test_dataloaders=test_loader)
    return result

The modified code actually works.

I am actually quite interested in what could go wrong with this implementation. I am worried the twisted logic of the fit method might create some issues. I don't think it will form a loop, because the trainer is unaware that the model has a fit method, but I don't know whether it is also safe in other situations. For example, when saving the model, will the state of the trainer also be saved?

I had a similar workaround which worked previously, but as with any workaround it may suddenly break, and it did in the latest version with a pickling error. Also, maybe I am missing some hidden feature where there is a benefit to passing a different model to each fit call. If so, I would like to know what it is.

import pytorch_lightning as pl


class Trainer(pl.Trainer):
    def __init__(self, model, **kwargs):
        super().__init__(**kwargs)
        self.model = model

    def fit(self, *args, **kwargs):
        # If the model was passed explicitly, fall back to the normal signature;
        # otherwise prepend the stored model.
        if args and args[0] is self.model:
            super().fit(*args, **kwargs)
        else:
            super().fit(self.model, *args, **kwargs)

Agreed, I guess I didn't think it through at that time. Yeah, it won't make sense to call trainer.fit multiple times. Let's just say it's a design choice to put the model in trainer.fit along with the dataloaders/datamodule.