Do a full pass on the training dataset without grad, and without stepping through the scheduler

Hi,

I’m working in the following sequential learning scenario. I have a learner that receives large chunks of data (datasets) in phases. For each phase, the learner trains on the incoming dataset for a few epochs before moving on to the next one. Currently, I have implemented this scenario by:

  1. setting my training dataset to be cat([ds_phase_1, ..., ds_phase_n]) and using a custom sampler to only yield the indices of the current phase

  2. incrementing my phase_id when (current_epoch + 1) % epochs_in_phase == 0
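The two steps above can be sketched with a small phase-aware sampler. This is a minimal sketch, not the code from this post: `PhaseSampler` and `set_phase_from_epoch` are hypothetical names, and it assumes the phases are concatenated in order, as `torch.utils.data.ConcatDataset` does. `DataLoader` accepts any iterable with `__len__` as its `sampler=` argument, so the class needs only the standard library.

```python
import itertools
import random
from typing import Iterator, List, Optional, Sequence


class PhaseSampler:
    """Yields only the indices of the current phase inside a dataset built as
    cat([ds_phase_1, ..., ds_phase_n]) (same index layout as ConcatDataset)."""

    def __init__(self, phase_sizes: Sequence[int], shuffle: bool = True,
                 seed: Optional[int] = None) -> None:
        self.phase_sizes: List[int] = list(phase_sizes)
        # offsets[i] is the global index of the first sample of phase i
        self.offsets: List[int] = [0] + list(itertools.accumulate(self.phase_sizes))[:-1]
        self.shuffle = shuffle
        self.phase_id = 0
        self._rng = random.Random(seed)

    def set_phase_from_epoch(self, epoch: int, epochs_in_phase: int) -> None:
        # Same schedule as incrementing phase_id whenever
        # (epoch + 1) % epochs_in_phase == 0; clamped to the last phase.
        self.phase_id = min(epoch // epochs_in_phase, len(self.phase_sizes) - 1)

    def __iter__(self) -> Iterator[int]:
        start = self.offsets[self.phase_id]
        indices = list(range(start, start + self.phase_sizes[self.phase_id]))
        if self.shuffle:
            self._rng.shuffle(indices)
        return iter(indices)

    def __len__(self) -> int:
        return self.phase_sizes[self.phase_id]


sampler = PhaseSampler([4, 3, 5], shuffle=False)
sampler.set_phase_from_epoch(epoch=3, epochs_in_phase=2)
print(sampler.phase_id, list(sampler))  # 1 [4, 5, 6]
```

Deriving `phase_id` from the epoch number, instead of incrementing a counter, keeps the sampler consistent if training is resumed from a checkpoint mid-phase.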

This works fine for now; however, I’m unable to do the following: every time I start a new phase, I want to compute my loss on the full training set without doing any updates. This is just like measuring the validation loss, the only difference being that it is computed on my training dataset (and not on my validation dataset, which I also have).
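In plain PyTorch, such a pass could look like the sketch below. It assumes the dataset yields `(input, target)` pairs, the model is already on `device`, and `loss_fn` returns the batch mean (the default reduction of losses such as `nn.MSELoss`). `full_pass_loss` is a hypothetical helper, not a PyTorch or Lightning API. It never touches an optimizer or scheduler, runs in eval mode under `torch.no_grad()`, and can typically afford a larger batch size than training because no activations are kept for a backward pass.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


def full_pass_loss(model: nn.Module, dataset: Dataset, loss_fn,
                   batch_size: int = 1024, device: str = "cpu") -> float:
    """Mean loss over every sample of `dataset`, with no parameter updates."""
    # Plain sequential loader over the whole dataset (no phase sampler)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    was_training = model.training
    model.eval()  # dropout off, BatchNorm uses its running statistics
    total, count = 0.0, 0
    with torch.no_grad():  # no autograd graph is built
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            n = y.shape[0]
            # loss_fn returns a batch mean, so weight it by the batch size;
            # this keeps a smaller final batch from skewing the average
            total += loss_fn(model(x), y).item() * n
            count += n
    model.train(was_training)  # restore whatever mode the model was in
    return total / count
```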

Potential solutions

  1. I could just skip the updates in my training_step when (current_epoch + 1) % epochs_in_phase == 1; however, I think my step scheduler would still call step(), and I want to avoid this. Moreover, I would like to use a larger batch size and be in eval mode to save GPU memory.

  2. Somehow temporarily overwrite my validation dataset with my training dataset and call something like self.validate()? Is there such a method that takes in a dataloader and loops through it?
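On the epoch check in option 1: for `epochs_in_phase >= 2`, `(current_epoch + 1) % epochs_in_phase == 1` is equivalent to `current_epoch % epochs_in_phase == 0`, but when `epochs_in_phase == 1` the first form is never true, because any integer modulo 1 is 0. A small sketch with hypothetical helper names; following the later solution, which only evaluates from the second round on, it skips epoch 0.

```python
def is_phase_start(epoch: int, epochs_in_phase: int) -> bool:
    """True on the first epoch of every phase except the very first one."""
    # epoch % k == 0 marks the first epoch of each phase for any k >= 1
    return epoch > 0 and epoch % epochs_in_phase == 0


def original_condition(epoch: int, epochs_in_phase: int) -> bool:
    # The check from option 1; never true when epochs_in_phase == 1
    return (epoch + 1) % epochs_in_phase == 1


print([e for e in range(7) if is_phase_start(e, 3)])      # [3, 6]
print([e for e in range(4) if is_phase_start(e, 1)])      # [1, 2, 3]
print([e for e in range(4) if original_condition(e, 1)])  # []
```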

Please let me know the easiest way to proceed 🙂

This is the solution I went with.
Instead of piggybacking on the trainer to do different things at different epochs, I wrapped the FitLoop to handle several rounds:

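    from collections import OrderedDict

    import torch
    # Import paths assume the PyTorch Lightning 1.6/1.7 loop API
    from pytorch_lightning.loops import FitLoop
    from pytorch_lightning.loops.epoch import EvaluationEpochLoop
    from pytorch_lightning.loops.utils import _reset_progress
    from pytorch_lightning.utilities.fetching import DataFetcher


    class MetaFitLoop(FitLoop):
        # Wraps the regular fit loop to handle data dumps arriving over time (rounds)
        def __init__(self, trainer, cfg):
            super().__init__(max_epochs=cfg.num_epochs)

            self.cfg = cfg
            self.trainer = trainer
            self.trainer.current_round = 0

            # Used to compute the online accuracy on the training set
            self.eval_loop = EvaluationEpochLoop()
            self.train_ds_fetcher = DataFetcher()

        def run(self):
            while self.trainer.current_round < self.cfg.num_rounds:
                if self.trainer.current_round > 0:
                    # Start each new round with an online eval pass over the training set;
                    # train_dl was created during round 0 below
                    with torch.no_grad():
                        self.eval_loop.run(self.train_ds_fetcher, len(train_dl), OrderedDict())
                else:
                    # First round: build the training dataloader once and reuse it afterwards
                    train_dl = self.trainer.datamodule.train_dataloader()
                    self.train_ds_fetcher.setup(train_dl)

                # Run the regular fit epochs for this round
                super(FitLoop, self).run()
                # Reset the epoch progress so the next round starts again from epoch 0
                _reset_progress(self.epoch_progress)
                self.trainer.current_round += 1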

Note that no matter which dataset I pass to the EvaluationEpochLoop, it always ends up being treated as a training dataset and training_step gets called (this is why I had to disable gradients manually). I would greatly appreciate feedback or any suggestions on how to make this better and more pytorch-lightning-y 🙂

Hi @pclucas14, I think you can disable automatic optimization and update the optimizer in the training_epoch_end hook.

Here is a link that might help you - Optimization — PyTorch Lightning 1.7.4 documentation

https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#training

Also, we recently migrated community questions and answers to GitHub Discussions.
