Clarification
This is a great question! Do you think you could share some details about the training/validation/test process?
Is it something like this (in plain PyTorch)?
```python
datasets = [DatasetOne(), DatasetTwo(), DatasetThree()]
model = Model()
for current_idx in range(len(datasets)):
    for epoch in range(epochs_per_dataset):
        # train on the current dataset
        train_step(model, datasets[current_idx])
        # validate on previous datasets
        validation_step(model, datasets[:current_idx])
# test on all datasets
test_step(model, datasets)
```
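If it helps pin down the intended schedule, here is a runnable plain-PyTorch version of that loop. It is only a minimal sketch: the toy datasets (`make_dataset`), the linear model, the `evaluate` helper and the extra `optimizer` argument to `train_step` are all hypothetical stand-ins, and it validates on every dataset seen so far (current one included) rather than only the previous ones, since that slice would be empty while training on the first dataset.

```python
import torch
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

torch.manual_seed(0)


def make_dataset(offset, n=64):
    # Hypothetical stand-in for DatasetOne/DatasetTwo/DatasetThree: a toy
    # regression task whose targets are shifted by `offset`.
    x = torch.randn(n, 4)
    y = x.sum(dim=1, keepdim=True) + offset
    return TensorDataset(x, y)


def train_step(model, optimizer, dataset):
    # One epoch of plain SGD over a single dataset.
    model.train()
    for x, y in DataLoader(dataset, batch_size=16, shuffle=True):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()


@torch.no_grad()
def evaluate(model, datasets):
    # Mean squared error over the concatenation of several map-style datasets.
    model.eval()
    total, count = 0.0, 0
    for x, y in DataLoader(ConcatDataset(datasets), batch_size=64):
        total += nn.functional.mse_loss(model(x), y, reduction="sum").item()
        count += y.numel()
    return total / count


datasets = [make_dataset(offset) for offset in (0.0, 1.0, 2.0)]
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
epochs_per_dataset = 2

val_history = []
for current_idx in range(len(datasets)):
    for epoch in range(epochs_per_dataset):
        train_step(model, optimizer, datasets[current_idx])
        # Validate on every dataset seen so far (current one included).
        val_history.append(evaluate(model, datasets[: current_idx + 1]))

# Test on all datasets once training on the last one has finished.
test_mse = evaluate(model, datasets)
print("validation MSE per epoch:", [round(v, 4) for v in val_history])
print(f"test MSE on all datasets: {test_mse:.4f}")
```

Each entry of `val_history` corresponds to one epoch, so you can watch the error on earlier datasets change as training moves on to later ones.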
Potential Solution
I believe you would want to handle this logic in the `train_dataloader`, `val_dataloader` and `test_dataloader` methods of your `LightningModule`:
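```python
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import ConcatDataset, DataLoader


class ContinualLearner(LightningModule):
    def __init__(self, datasets, epochs_per_dataset):
        super().__init__()
        # datasets is a list of map-style torch.utils.data.Dataset objects
        self.datasets = datasets
        self.curr_index = 0
        self.epochs_per_dataset = epochs_per_dataset

    # training_step, validation_step, test_step and configure_optimizers omitted

    def train_dataloader(self):
        # train only on the current dataset
        return DataLoader(self.datasets[self.curr_index])

    def val_dataloader(self):
        # validate on all datasets seen so far (current included); the slice
        # [:curr_index] is empty at first, and ConcatDataset rejects an empty list
        return DataLoader(ConcatDataset(self.datasets[: self.curr_index + 1]))

    def test_dataloader(self):
        # test on all datasets
        return DataLoader(ConcatDataset(self.datasets))

    def on_train_epoch_end(self):
        # runs once per training epoch; current_epoch is 0-based, so add 1
        # and switch after every epochs_per_dataset epochs (never past the last)
        if (self.trainer.current_epoch + 1) % self.epochs_per_dataset == 0:
            self.curr_index = min(self.curr_index + 1, len(self.datasets) - 1)
```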
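On the choice of dataset wrapper: `ChainDataset` is designed for `IterableDataset`s, while map-style datasets (those implementing `__getitem__` and `__len__`, like most custom `Dataset` subclasses) are joined with `ConcatDataset`. Here is a quick check, assuming a reasonably recent PyTorch release; the two toy `TensorDataset`s are placeholders for your own datasets.

```python
import torch
from torch.utils.data import ChainDataset, ConcatDataset, TensorDataset

# Two toy map-style datasets (they implement __getitem__ and __len__).
first = TensorDataset(torch.arange(0, 3))
second = TensorDataset(torch.arange(3, 5))

# ConcatDataset joins map-style datasets and still supports len() and indexing,
# so a DataLoader can batch (and shuffle) it like any other dataset.
concat = ConcatDataset([first, second])
concat_values = [concat[i][0].item() for i in range(len(concat))]
print(len(concat), concat_values)  # 5 [0, 1, 2, 3, 4]

# ConcatDataset refuses an empty list, so a validation slice must hold
# at least one dataset.
try:
    ConcatDataset([])
    empty_rejected = False
except AssertionError:
    empty_rejected = True

# ChainDataset expects IterableDataset inputs; iterating it over
# map-style datasets trips an internal assertion.
try:
    list(ChainDataset([first, second]))
    chain_rejected = False
except AssertionError:
    chain_rejected = True

print(empty_rejected, chain_rejected)  # True True
```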
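To make sure the dataloaders are rebuilt every epoch (so they pick up the new `curr_index`), you will need to use the `reload_dataloaders_every_epoch` Trainer flag. In newer Lightning releases this flag was replaced by `reload_dataloaders_every_n_epochs=1`: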
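```python
# older Lightning versions: Trainer(reload_dataloaders_every_epoch=True, ...)
trainer = Trainer(max_epochs=len(datasets) * epochs_per_dataset, reload_dataloaders_every_n_epochs=1)
trainer.fit(model)
trainer.test(model)  # fit() does not run the test loop
```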
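To sanity-check how the epoch counter, the dataset-switching hook and the per-epoch reload interact, here is a Lightning-free simulation. It is a sketch under assumptions: dataloaders are rebuilt at the start of every epoch, the hook runs once after each training epoch, and it advances when `(current_epoch + 1) % epochs_per_dataset == 0`; `simulate_schedule` is a hypothetical helper, not part of Lightning.

```python
def simulate_schedule(num_datasets, epochs_per_dataset):
    """Return (epoch, train dataset index, validation dataset indices) per epoch."""
    curr_index = 0
    schedule = []
    for current_epoch in range(num_datasets * epochs_per_dataset):
        # With dataloaders reloaded every epoch, the train/val loaders read
        # curr_index at the start of the epoch.
        train_idx = curr_index
        val_idxs = list(range(curr_index + 1))
        schedule.append((current_epoch, train_idx, val_idxs))
        # End-of-epoch hook: advance after every epochs_per_dataset completed
        # epochs, but never past the last dataset.
        if (current_epoch + 1) % epochs_per_dataset == 0:
            curr_index = min(curr_index + 1, num_datasets - 1)
    return schedule


for epoch, train_idx, val_idxs in simulate_schedule(3, 2):
    print(f"epoch {epoch}: train on {train_idx}, validate on {val_idxs}")
```

A check of `current_epoch % epochs_per_dataset == 0` instead would already fire after epoch 0, because `current_epoch` starts at 0, so the first dataset would get only one epoch.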
Hope this sets you off in the right direction! Please do not hesitate to ask any other questions.