How to switch dataloaders between epochs in Lightning

This question is taken from GitHub:

I have a labeled and an unlabeled dataset that I am using for a semi-supervised segmentation problem.

Here is what I want to implement in the training loop:

  1. Train an epoch (or, alternatively, a batch) on the labeled dataset.
  2. Train an epoch (or, alternatively, a batch) on the unlabeled dataset.

What is the best way to do it?
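For the per-batch variant in the list above, one framework-free option (plain Python, no Lightning) is to interleave the two sources with `zip`. The generator name `alternate_batches` is hypothetical, and stopping when the shorter dataset runs out is an assumption:

```python
from typing import Any, Iterable, Iterator, Tuple


def alternate_batches(
    labeled: Iterable[Any], unlabeled: Iterable[Any]
) -> Iterator[Tuple[str, Any]]:
    """Yield labeled and unlabeled batches in turn, tagged with their source.

    Hypothetical helper: stops as soon as the shorter iterable is exhausted.
    """
    for labeled_batch, unlabeled_batch in zip(labeled, unlabeled):
        yield "labeled", labeled_batch      # step 1: one labeled batch
        yield "unlabeled", unlabeled_batch  # step 2: one unlabeled batch
```

Wrapping the smaller dataset in `itertools.cycle` would instead keep going until the larger one is exhausted.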

Hi, it is currently not possible to return multiple dataloaders for training (this only works for validation).
A feature for this is in progress in #1959.

However, in your case, I think it is more elegant to do this:

Step 1:
Return the right dataloader in each epoch:

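```python
def train_dataloader(self):
    if self.current_epoch % 2 == 0:
        # Even epochs (0, 2, 4, ...): use the labeled data
        labeled_dataloader = ...
        return labeled_dataloader
    else:
        # Odd epochs (1, 3, 5, ...): use the unlabeled data
        unlabeled_dataloader = ...
        return unlabeled_dataloader
```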
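Step 1 can be tried without Lightning using a minimal stand-in class. This is a sketch under assumptions: `ToyModule` is hypothetical, plain lists stand in for `DataLoader` objects, and `current_epoch` is a writable attribute here, whereas in a real LightningModule the Trainer keeps it up to date. Building both loaders once and only choosing between them avoids recreating them every epoch:

```python
from typing import Any, List


class ToyModule:
    """Hypothetical stand-in for a LightningModule (no Lightning import)."""

    def __init__(self, labeled_loader: List[Any], unlabeled_loader: List[Any]) -> None:
        # Both "loaders" are built once; train_dataloader only picks one.
        self.labeled_loader = labeled_loader
        self.unlabeled_loader = unlabeled_loader
        # In Lightning the Trainer maintains this counter; here it is set by hand.
        self.current_epoch = 0

    def train_dataloader(self) -> List[Any]:
        # Even epochs -> labeled data, odd epochs -> unlabeled data.
        if self.current_epoch % 2 == 0:
            return self.labeled_loader
        return self.unlabeled_loader
```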

Step 2:
Modify your training_step like this:

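```python
def training_step(self, batch, batch_idx):
    if self.current_epoch % 2 == 0:
        # Labeled epoch: apply the supervised loss (uses the labels)
        loss = ...
    else:
        # Unlabeled epoch: apply the unsupervised loss
        loss = ...
    return loss
```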
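To make Step 2 concrete without Lightning or PyTorch, the sketch below uses a one-parameter logistic "model" on plain floats. The loss choices are assumptions for illustration only: binary cross-entropy on labeled batches, and entropy minimization (pushing predictions on unlabeled data toward confident values, one common unsupervised objective in semi-supervised learning) on unlabeled batches. `ToySegmenter` and the batch layout (`(inputs, targets)` versus bare `inputs`) are hypothetical:

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


class ToySegmenter:
    """Hypothetical one-weight model predicting per-pixel foreground probability."""

    def __init__(self, weight: float = 0.0) -> None:
        self.weight = weight
        self.current_epoch = 0  # maintained by the Trainer in real Lightning code

    def forward(self, inputs: Sequence[float]) -> List[float]:
        return [sigmoid(self.weight * x) for x in inputs]

    def training_step(self, batch, batch_idx: int) -> float:
        eps = 1e-12  # keeps log() away from zero
        if self.current_epoch % 2 == 0:
            # Labeled epoch: batch is (inputs, targets) with targets in {0, 1}.
            inputs, targets = batch
            probs = self.forward(inputs)
            losses = [
                -(t * math.log(p + eps) + (1 - t) * math.log(1 - p + eps))
                for p, t in zip(probs, targets)
            ]
        else:
            # Unlabeled epoch: batch is inputs only; minimize prediction entropy.
            probs = self.forward(batch)
            losses = [
                -(p * math.log(p + eps) + (1 - p) * math.log(1 - p + eps))
                for p in probs
            ]
        return sum(losses) / len(losses)
```

With `weight=0.0` every prediction is 0.5, so both losses come out at roughly `ln 2`; a more confident model lowers the unsupervised loss.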

Step 3:
Finally, tell the Trainer to call the train_dataloader method every epoch so that it switches to the other dataset. By default it is called only once, at the start of training.

```python
trainer = Trainer(..., reload_dataloaders_every_epoch=True)  # False by default
# Newer Lightning releases use reload_dataloaders_every_n_epochs=1 instead
```
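Why Step 3 matters can be shown with a toy loop. The sketch below is not Lightning's Trainer; it only imitates the dataloader-reload behavior (fetch the loader once, or again at every epoch), and all names are hypothetical. Without reloading, epoch 1 still iterates the labeled loader while `training_step` already applies the unsupervised loss:

```python
from typing import List, Tuple

Record = Tuple[int, str, str]  # (epoch, dataset used, loss applied)


class AlternatingModule:
    """Hypothetical module that alternates datasets and losses by epoch parity."""

    def __init__(self) -> None:
        self.current_epoch = 0
        self.records: List[Record] = []

    def train_dataloader(self) -> str:
        # A string label stands in for a real DataLoader.
        return "labeled" if self.current_epoch % 2 == 0 else "unlabeled"

    def training_step(self, loader_name: str) -> None:
        loss = "supervised" if self.current_epoch % 2 == 0 else "unsupervised"
        self.records.append((self.current_epoch, loader_name, loss))


def toy_fit(module: AlternatingModule, max_epochs: int, reload_every_epoch: bool) -> List[Record]:
    """Imitate only the dataloader-reload logic of a training loop."""
    loader = module.train_dataloader()  # always fetched once before epoch 0
    for epoch in range(max_epochs):
        module.current_epoch = epoch
        if reload_every_epoch and epoch > 0:
            loader = module.train_dataloader()  # re-query with the new epoch
        module.training_step(loader)  # one "batch" per epoch keeps it short
    return module.records
```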

I think self.current_epoch is not accessible if we use pl.LightningDataModule. Any suggestions?
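One possible approach, under an assumption worth checking against your installed version: in recent Lightning releases, a LightningDataModule that has been passed to the Trainer exposes it as `self.trainer`, so `train_dataloader` can read `self.trainer.current_epoch` instead of `self.current_epoch`. The framework-free sketch below imitates that; `AlternatingDataModule` is hypothetical, and a `SimpleNamespace` stands in for the attached Trainer:

```python
from types import SimpleNamespace
from typing import Any, List, Optional


class AlternatingDataModule:
    """Hypothetical stand-in; in real code this would subclass
    pl.LightningDataModule, and the Trainer would set `self.trainer`."""

    def __init__(self, labeled_loader: List[Any], unlabeled_loader: List[Any]) -> None:
        self.trainer: Optional[Any] = None  # attached by the Trainer in Lightning
        self.labeled_loader = labeled_loader
        self.unlabeled_loader = unlabeled_loader

    def train_dataloader(self) -> List[Any]:
        # Before a Trainer is attached there is no epoch yet; treat it as epoch 0.
        epoch = self.trainer.current_epoch if self.trainer is not None else 0
        return self.labeled_loader if epoch % 2 == 0 else self.unlabeled_loader


# Usage: fake an attached Trainer at epoch 3 (odd epoch -> unlabeled data).
dm = AlternatingDataModule(["labeled batch"], ["unlabeled batch"])
dm.trainer = SimpleNamespace(current_epoch=3)
chosen = dm.train_dataloader()
```

The reload flag from Step 3 is still needed so that the Trainer calls this method again at every epoch.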