What's the best practice for continual learning?

For continual learning, we have several tasks, each of which has one dataset. We need to train a model on these tasks sequentially. There are some requirements:

  1. During training, we train the model on the current dataset only.
  2. During validation, we check the performance on all previous datasets.
  3. During testing, we test the model on all of the datasets.

Any suggestions would be appreciated. Thanks!


This is a great question! Do you think you could share some details about the training/validation/test process?

Is it something like this (in plain PyTorch)?:

datasets = [DatasetOne(), DatasetTwo(), DatasetThree()]
model = Model()

for current_idx in range(len(datasets)):
    for epoch in range(epochs_per_dataset):

        # train on the current dataset
        train_step(model, datasets[current_idx])

        # validate on previous datasets
        validation_step(model, datasets[:current_idx])

# test on all datasets
test_step(model, datasets)

Potential Solution

I believe you would want to handle this logic in the _dataloader functions of your LightningModule:

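from pytorch_lightning import LightningModule
from torch.utils.data import ConcatDataset, DataLoader

class ContinualLearner(LightningModule):
    def __init__(self, datasets, epochs_per_dataset):
        super().__init__()

        # datasets is a list of torch.utils.data.Dataset objects
        self.datasets = datasets
        self.curr_index = 0
        self.epochs_per_dataset = epochs_per_dataset

    def train_dataloader(self):
        return DataLoader(self.datasets[self.curr_index])

    def val_dataloader(self):
        # ConcatDataset is used because ChainDataset only accepts IterableDataset.
        # Note: there are no previous datasets while curr_index == 0.
        return DataLoader(ConcatDataset(self.datasets[:self.curr_index]))

    def test_dataloader(self):
        return DataLoader(ConcatDataset(self.datasets))

    def on_epoch_end(self):
        # move to the next dataset after every epochs_per_dataset epochs
        # (current_epoch is zero-based, hence the + 1)
        if (self.trainer.current_epoch + 1) % self.epochs_per_dataset == 0:
            self.curr_index = min(self.curr_index + 1, len(self.datasets) - 1)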

To make sure you get the new dataloader every epoch, you will need to set the reload_dataloaders_every_epoch flag (newer Lightning releases replace it with reload_dataloaders_every_n_epochs):

trainer = Trainer(reload_dataloaders_every_epoch=True)
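To sanity-check the switching rule, here is a minimal pure-Python sketch of the schedule the module above is meant to follow. It assumes epochs are numbered from zero and that the index advances once every epochs_per_dataset epochs. The helper dataset_schedule is hypothetical and not part of Lightning:

```python
from typing import List, Tuple


def dataset_schedule(
    num_datasets: int, epochs_per_dataset: int
) -> List[Tuple[int, int, List[int]]]:
    """Return (epoch, train_index, validation_indices) for every epoch.

    Hypothetical helper for illustration only; not a Lightning API.
    """
    schedule = []
    for epoch in range(num_datasets * epochs_per_dataset):
        # Integer division maps a zero-based epoch to the active dataset.
        train_index = epoch // epochs_per_dataset
        # Validation covers only the datasets seen before the current one.
        val_indices = list(range(train_index))
        schedule.append((epoch, train_index, val_indices))
    return schedule


if __name__ == "__main__":
    for epoch, train_index, val_indices in dataset_schedule(3, 2):
        print(f"epoch {epoch}: train on {train_index}, validate on {val_indices}")
```

Note that the validation list is empty while the first dataset is being trained, so the module has to cope with having nothing to validate on during that phase.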

Hope this sets you off in the right direction! Please do not hesitate to ask any other questions.


Thanks a lot for replying! That’s almost what I want. There is one more thing I'd like to add: I want to check the performance on each dataset individually, not on the concatenated one. That is, I want the evaluation process to look like this:

def valid_or_test(datasets: List[Dataset]):
    results = []
    for dataset in datasets:
        result = evaluate_model_on_dataset(dataset)
        results.append(result)
    # return or do something with the result collection
    return results

In conclusion, I need the evaluation process to be performed on each dataset separately and provide me with the result on each one. I wonder if I can do it elegantly with pytorch-lightning? Looking forward to your suggestion soon!

This can be handled with Lightning’s support for multiple dataloaders. If you return a list of dataloaders, your results will be indexed by dataset:

    def val_dataloader(self):
        return [DataLoader(ds) for ds in self.datasets[:self.curr_index]]

    def test_dataloader(self):
        return [DataLoader(ds) for ds in self.datasets]

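In this case, validation_epoch_end and test_epoch_end will be passed a List[List[Any]]: the outer list has one entry per dataloader (that is, per dataset), and each inner list holds the step outputs for that dataset. Anything you return from validation_step/test_step will be accessible there. With multiple dataloaders, those step methods also receive a dataloader_idx argument identifying which dataset a batch came from.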
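As an illustration, the nested outputs could be reduced to one metric per dataset along these lines. The sketch is plain Python and assumes each step returns a dict with hypothetical correct and total counts. The helper name accuracy_per_dataset is also made up for this example:

```python
from typing import Dict, List


def accuracy_per_dataset(outputs: List[List[Dict[str, int]]]) -> Dict[int, float]:
    """Collapse per-step outputs into one accuracy value per dataloader index.

    Assumes every step output looks like {"correct": int, "total": int};
    these keys are an assumption of this sketch, not a Lightning convention.
    """
    results = {}
    for dataloader_idx, step_outputs in enumerate(outputs):
        correct = sum(out["correct"] for out in step_outputs)
        total = sum(out["total"] for out in step_outputs)
        # Guard against an empty dataloader to avoid dividing by zero.
        results[dataloader_idx] = correct / total if total else float("nan")
    return results


if __name__ == "__main__":
    fake_outputs = [
        [{"correct": 8, "total": 10}, {"correct": 9, "total": 10}],  # dataset 0
        [{"correct": 5, "total": 10}],  # dataset 1
    ]
    print(accuracy_per_dataset(fake_outputs))
```

Something like this could be called from validation_epoch_end or test_epoch_end, with the resulting dict logged one entry per dataset.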


Oh, I overlooked this functionality in the docs. Thanks a lot for reminding me!

No worries! Happy to help