Changing Datamodule during training

I am looking for a way to reinitialize my DataModule with different parameters. I currently pass the height of my images as an argument to my datamodule, and I want to change this height at some point during training. The simple way is to call trainer.fit multiple times with different datamodules, but I am wondering whether there is a way to do this in a callback, in the same way as you would when changing the optimizer or lr_scheduler?

Hello, in such a case you need to force an update of the logger in use. I would consider adding a reset/update method to your data module, which would eventually be called from a model hook or a callback…
Or shall we also add more hooks to the data module, as the model has? @nate @teddy

I have done this using a callback:

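import pytorch_lightning as pl
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

class Scheduler(pl.Callback):
    def _prepare_epoch(self, trainer, model, epoch):
        # Choose the settings (e.g. image size) for the upcoming epoch.
        phase = ...
        trainer.datamodule.set_phase(phase)

    def on_epoch_end(self, trainer, model):
        # Prepare the datamodule for the next epoch before its dataloaders are reloaded.
        self._prepare_epoch(trainer, model, trainer.current_epoch + 1)

class Data(pl.LightningDataModule):
    # self.size, self.min_scale, self.train_dir, self.batch_size, self.num_workers
    # and `normalize` are assumed to be defined elsewhere (not shown here).
    def set_phase(self, phase: dict):
        # Keep the current size if the phase does not specify one.
        self.size = phase.get("size", self.size)
        train_transforms = T.Compose(
            [
                T.RandomResizedCrop(self.size, scale=(self.min_scale, 1.0)),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                normalize,
            ]
        )
        # Rebuild the dataset so the new transforms take effect.
        self.train_ds = ImageFolder(self.train_dir, transform=train_transforms)

    def train_dataloader(self):
        train_dl = DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )
        return train_dl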

It's important to note:

  1. You can access your datamodule from a callback using trainer.datamodule
  2. In order to have train_dataloader() and val_dataloader() called every epoch, you must set reload_dataloaders_every_epoch=True in your trainer.
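
The `phase = ...` line in the callback above is left unspecified. As a rough sketch only, here is one hypothetical way to map an epoch number to a phase dict. The `SCHEDULE` table, its epoch boundaries and sizes, and the `phase_for_epoch` helper are all made up for illustration and are not part of the original solution or of Lightning:

```python
from bisect import bisect_right

# Hypothetical schedule: (first_epoch, phase) pairs, sorted by first_epoch.
# The epoch boundaries and image sizes are illustrative values only.
SCHEDULE = [
    (0, {"size": 128}),
    (10, {"size": 192}),
    (20, {"size": 224}),
]


def phase_for_epoch(epoch: int, schedule=SCHEDULE) -> dict:
    """Return the phase of the last schedule entry that starts at or before `epoch`."""
    starts = [start for start, _ in schedule]
    # bisect_right counts the entries with start <= epoch; subtract 1 for the index.
    idx = bisect_right(starts, epoch) - 1
    if idx < 0:
        raise ValueError(f"no phase defined for epoch {epoch}")
    return schedule[idx][1]
```

With something like this, `_prepare_epoch` could use `phase = phase_for_epoch(epoch)`. Also note that newer Lightning releases replaced `reload_dataloaders_every_epoch=True` with `reload_dataloaders_every_n_epochs=1`, so the flag name depends on your installed version.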

Is the proposed solution still valid?

In my codebase, trainer.datamodule is None.