What is the best way to train in stages?

What is currently the best (or correct) way to train in stages? For example, I first train the model with one loss and, after some time, switch to a different one. Or I change the augmentations partway through.

There is an issue about this, but it is stale: https://github.com/PyTorchLightning/pytorch-lightning/issues/2006

I see three possible ways:

  1. Write a callback that checks the training status in on_epoch_start (or some other hook) and changes parameters of the LightningModule. But I'm not sure this is safe. For example, is it okay to change the dataloaders from such a callback?
  2. Wait until training is finished, then load the checkpoint and continue. But I'm not sure how to change the optimizers, dataloaders, and other things in this case.
  3. Save the model directly (a PyTorch .pth file), then create a new LightningModule instance with the new parameters, load the weights from the file, and train.
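Option 1 could look roughly like the sketch below, assuming the module reads a `stage` attribute in its `training_step`. `StageSwitcher` and its `milestones` argument are hypothetical names, not part of Lightning; the sketch only relies on the `Callback.on_train_epoch_start(trainer, pl_module)` hook and `trainer.current_epoch`.

```python
from bisect import bisect_right

try:
    from pytorch_lightning import Callback
except ImportError:  # fallback so the stage logic can be tried without Lightning
    Callback = object


class StageSwitcher(Callback):
    """Hypothetical callback that sets `pl_module.stage` from the current epoch.

    `milestones` are the epochs at which a new stage begins: [10, 20] means
    epochs 0-9 -> stage 1, 10-19 -> stage 2, 20 onwards -> stage 3.
    """

    def __init__(self, milestones):
        self.milestones = sorted(milestones)

    def stage_for_epoch(self, epoch):
        # number of milestones already reached, plus one
        return bisect_right(self.milestones, epoch) + 1

    def on_train_epoch_start(self, trainer, pl_module):
        new_stage = self.stage_for_epoch(trainer.current_epoch)
        if getattr(pl_module, "stage", None) != new_stage:
            # training_step can branch on pl_module.stage (loss, augmentations, ...)
            pl_module.stage = new_stage
```

It would be registered with `Trainer(callbacks=[StageSwitcher([10])])`. This only flips a flag that the module reads; as far as I know, `configure_optimizers` runs once at the start of `fit`, so swapping optimizers is not covered by this approach.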

Which of these approaches is better, or is there a different, better way to do this?

I think the best option as of now is 2. Model.load_from_checkpoint() accepts additional keyword arguments that override the hyperparameters saved in the checkpoint, so you could do something like:

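from pytorch_lightning import LightningModule, Trainer

class Model(LightningModule):
    def __init__(self, stage=1):
        super().__init__()
        self.save_hyperparameters()  # saves `stage` in the checkpoint
        self.stage = stage

    # anything you would like to change can depend on self.stage
    def configure_optimizers(self):
        if self.stage == 1:
            return ...  # stage-1 optimizer(s)
        elif self.stage == 2:
            return ...  # stage-2 optimizer(s)

model = Model(stage=1)
trainer = Trainer()
trainer.fit(model)  # stage 1
trainer.save_checkpoint("stage1.ckpt")

model.stage = 2  # option A: switch the stage in place
model = Model.load_from_checkpoint("stage1.ckpt", stage=2)  # option B: reload, overriding the saved `stage`
trainer = Trainer()  # fresh Trainer so stage 2 starts its own epoch count
trainer.fit(model)  # stage 2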
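Option 2 generalizes to more than two stages. A minimal sketch, assuming the model class accepts a `stage` argument and stores it via `save_hyperparameters()`; `train_in_stages`, `make_trainer` and the `stageN.ckpt` naming are hypothetical. Only `fit`, `save_checkpoint` and `load_from_checkpoint` are real Lightning calls.

```python
def train_in_stages(model_cls, make_trainer, stages, ckpt_pattern="stage{}.ckpt"):
    """Hypothetical driver: one fresh Trainer per stage, each stage starting from
    the previous stage's checkpoint with the `stage` hyperparameter overridden."""
    model = model_cls(stage=stages[0])
    for i, stage in enumerate(stages):
        if i > 0:
            prev_ckpt = ckpt_pattern.format(stages[i - 1])
            # extra kwargs are passed to __init__, overriding the saved hyperparameters
            model = model_cls.load_from_checkpoint(prev_ckpt, stage=stage)
        trainer = make_trainer(stage)  # e.g. a different max_epochs per stage
        trainer.fit(model)
        trainer.save_checkpoint(ckpt_pattern.format(stage))
    return model
```

Usage would be something like `train_in_stages(Model, lambda stage: Trainer(max_epochs=10), [1, 2])`. Building a new Trainer per stage keeps each stage's epoch counter and optimizer setup independent.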

Hopefully this helps! If this doesn't seem like the best way, perhaps you can share your exact use case and we can figure it out 🙂


Should it be .train or .fit?

My bad, fixed. Thanks for pointing it out!

Regarding changing the dataloaders: if you use a DataModule, you could add stages besides 'fit' and 'test' and then call dm.setup(stage='fit_stage2') from a callback when you reach the appropriate point in your training procedure 🤔
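A related variant is to let the DataModule's `train_dataloader()` read a stage attribute instead of relying on extra `setup()` stages. A minimal sketch, assuming one dataset object per stage (for example, the same data with different augmentations); `StagedDataModule`, `datasets_by_stage` and `train_stage` are hypothetical names.

```python
try:
    from pytorch_lightning import LightningDataModule
except ImportError:  # fallback so the stage logic can be tried without Lightning
    LightningDataModule = object


class StagedDataModule(LightningDataModule):
    """Hypothetical DataModule whose training data depends on `self.train_stage`."""

    def __init__(self, datasets_by_stage, batch_size=32):
        super().__init__()
        self.datasets_by_stage = datasets_by_stage  # e.g. {1: plain_ds, 2: augmented_ds}
        self.batch_size = batch_size
        # named train_stage to avoid confusion with setup()'s `stage` ("fit", "test", ...)
        self.train_stage = 1

    def current_train_dataset(self):
        return self.datasets_by_stage[self.train_stage]

    def train_dataloader(self):
        # imported here so the sketch can be loaded without torch installed
        from torch.utils.data import DataLoader

        return DataLoader(self.current_train_dataset(), batch_size=self.batch_size, shuffle=True)
```

By default Lightning builds the train dataloader once per `fit`, so for the switch to take effect the Trainer would also need `reload_dataloaders_every_n_epochs=1` (older releases used `reload_dataloaders_every_epoch=True`). Flipping `trainer.datamodule.train_stage` in a callback's `on_train_epoch_end` lets the next epoch's reload pick up the new dataset.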


Wouldn't the trainer call datamodule.setup('fit') again in that case?