I am looking for a way to reinitialize my DataModule with different parameters. I currently pass the height of my images as an argument to my datamodule, and I want to change this height at some point during training. The simple way is to call trainer.fit multiple times with different datamodules, but I am wondering: is there a way to do this in a callback, the same way you change the optimizer or lr_scheduler?
Hello, in such a case you need to force an update of the used dataloaders. I would consider adding a reset/update method to your data module, which would eventually be called from a model hook or callback…
Or shall we also add more hooks to the data module, like the model has? @nate @teddy
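A minimal sketch of that reset/update idea, with plain-Python stand-ins so it runs without torch/lightning installed. The names `PhaseDataModule`, `set_phase`, and `SizeScheduler` are illustrative, not part of any Lightning API:

```python
class PhaseDataModule:
    """Stand-in for a LightningDataModule holding a mutable image size."""

    def __init__(self, size=224):
        self.size = size

    def set_phase(self, phase: dict):
        # Only overwrite attributes the phase dict actually provides.
        self.size = phase.get("size", self.size)
        # A real datamodule would also rebuild its transforms/dataset here.


class SizeScheduler:
    """Stand-in for a Callback that changes the image size at set epochs."""

    def __init__(self, schedule: dict):
        self.schedule = schedule  # e.g. {10: 224} -> switch to 224 at epoch 10

    def on_epoch_end(self, datamodule, current_epoch):
        next_epoch = current_epoch + 1
        if next_epoch in self.schedule:
            datamodule.set_phase({"size": self.schedule[next_epoch]})


dm = PhaseDataModule(size=128)
scheduler = SizeScheduler({10: 224})
scheduler.on_epoch_end(dm, current_epoch=9)  # epoch 10 is next
print(dm.size)  # -> 224
```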
I have done this using a callback:
```python
class Scheduler(pl.Callback):
    def _prepare_epoch(self, trainer, model, epoch):
        phase = ...
        trainer.datamodule.set_phase(phase)

    def on_epoch_end(self, trainer, model):
        self._prepare_epoch(trainer, model, trainer.current_epoch + 1)


class Data(pl.LightningDataModule):
    def set_phase(self, phase: dict):
        self.size = phase.get("size", self.size)
        train_transforms = T.Compose(
            [
                T.RandomResizedCrop(self.size, scale=(self.min_scale, 1.0)),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                normalize,
            ]
        )
        self.train_ds = ImageFolder(self.train_dir, transform=train_transforms)

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )
```
It's important to note:
- You can access your datamodule from a callback via `trainer.datamodule`.
- In order to have `train_dataloader()` called every epoch, you must set `reload_dataloaders_every_epoch=True` in your trainer.
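Putting it together, the wiring would look something like the fragment below (a sketch, assuming `Scheduler` and `Data` are the classes from the snippet above, and a Lightning version where `reload_dataloaders_every_epoch` still exists; in newer releases it was replaced by `reload_dataloaders_every_n_epochs`):

```python
import pytorch_lightning as pl

dm = Data()
trainer = pl.Trainer(
    max_epochs=30,
    reload_dataloaders_every_epoch=True,  # re-calls train_dataloader() each epoch
    callbacks=[Scheduler()],
)
trainer.fit(model, datamodule=dm)
```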