How frequently are train_dataloader and val_dataloader called?

How frequently are train_dataloader and val_dataloader called? Are they recreated every epoch? If so, this is a problem when epochs are short and data loading is slow: every time a dataloader is recreated, you have to wait synchronously for its first batch to load before a single model step can run.

It appears that the validation loader gets created once for the initial check and then again for the actual validation, whereas the training loader only gets created once.

This is correct. The validation loader is created for the initial checks and later recreated once the actual training starts. The train loader is created for training only (because that is usually far more data to load than for validation).

EDIT: Unfortunately there is no way around creating the validation loader twice, since otherwise we cannot guarantee that all validation batches are used in the first epoch, and we still need some data to run all the checks. The data can also come in various forms, so we just use the data from val_loader, which is user-defined and thus should match the expected format and type.

Just for completeness:
There is also the reload_dataloaders_every_epoch flag, which does exactly what its name implies when set.

https://pytorch-lightning.readthedocs.io/en/stable/trainer.html#reload-dataloaders-every-epoch
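To make the call pattern described above concrete, here is a toy model in plain Python (not Lightning's actual fit loop) that only counts how often each loader factory is invoked. It assumes the behaviour stated in this thread: the validation loader is built once for the sanity check and once more for real validation, and the train loader is built once. It additionally assumes that with reloading enabled both loaders are rebuilt at the start of every subsequent epoch. `count_loader_calls` and its parameters are hypothetical names. (Newer Lightning releases replace the flag with `reload_dataloaders_every_n_epochs`.)

```python
from collections import Counter


def count_loader_calls(num_epochs: int, reload_every_epoch: bool = False) -> Counter:
    """Toy model of how often a trainer builds its dataloaders.

    Assumption (from the discussion above, not from Lightning's source):
    one val loader for the sanity check, one val loader and one train loader
    for the fit loop, and, if reloading is enabled, both rebuilt at the start
    of every epoch after the first.
    """
    calls = Counter()

    def train_dataloader():
        calls["train"] += 1
        return range(4)  # stand-in for a real DataLoader

    def val_dataloader():
        calls["val"] += 1
        return range(2)

    # Sanity check: a separate val loader is created just for these steps.
    for _ in val_dataloader():
        pass

    train_loader, val_loader = train_dataloader(), val_dataloader()
    for epoch in range(num_epochs):
        if reload_every_epoch and epoch > 0:
            # With a real DataLoader, each rebuild means waiting for a first batch again.
            train_loader, val_loader = train_dataloader(), val_dataloader()
        for _ in train_loader:
            pass
        for _ in val_loader:
            pass
    return calls


print(count_loader_calls(num_epochs=5))
print(count_loader_calls(num_epochs=5, reload_every_epoch=True))
```

With five short epochs, the reload variant builds the train loader five times instead of once, which is exactly where the first-batch latency from the question would be paid repeatedly.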

One way to avoid repeating long data loading is to use a data module that loads the datasets once in __init__, like this, and later creates the loaders on the fly, each one just wrapping an already existing dataset:

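```python
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader


class MyFancyDataModule(LightningDataModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # The (potentially slow) dataset creation happens exactly once, here.
        self.train_ds, self.valid_ds, self.test_ds = self.create_datasets(*args, **kwargs)

    def create_datasets(self, *args, **kwargs):
        # DO YOUR DATASET CREATION LOGIC HERE and return the three datasets
        raise NotImplementedError("return train_ds, valid_ds, test_ds")

    def train_dataloader(self):
        # Cheap: only wraps the already existing dataset in a new loader.
        return DataLoader(self.train_ds)

    def val_dataloader(self):
        return DataLoader(self.valid_ds)

    def test_dataloader(self):
        return DataLoader(self.test_ds)
```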

This has the advantage that you load your data only once, but the disadvantage that you also load your entire training set before running the checks (annoying during debugging).

You can, however, overcome this issue by caching each dataset the first time it is loaded and reusing it afterwards. Just make sure not to reuse the data loader itself.
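One possible reading of the caching suggestion, as a minimal sketch: build each dataset lazily on first access with `functools.cached_property`, keep it, and return a brand-new loader around the cached dataset on every call. `expensive_load` and `make_loader` are hypothetical stand-ins for your own dataset-creation logic and for `DataLoader(...)`; the class is plain Python so it runs without Lightning, but the same properties should carry over to a `LightningDataModule`.

```python
from functools import cached_property


def expensive_load(split: str) -> list:
    # Hypothetical stand-in for slow dataset creation (disk I/O, preprocessing, ...).
    expensive_load.calls += 1
    return [f"{split}-{i}" for i in range(3)]


expensive_load.calls = 0  # counts how often the slow path actually ran


def make_loader(dataset: list):
    # Hypothetical stand-in for DataLoader(dataset): a fresh, cheap iterator each time.
    return iter(dataset)


class CachedDataModule:
    @cached_property
    def train_ds(self) -> list:
        # Built on first access only, so checks that never touch the train set skip it.
        return expensive_load("train")

    @cached_property
    def valid_ds(self) -> list:
        return expensive_load("valid")

    def train_dataloader(self):
        # New loader on every call, but the dataset itself is reused.
        return make_loader(self.train_ds)

    def val_dataloader(self):
        return make_loader(self.valid_ds)


dm = CachedDataModule()
sanity = list(dm.val_dataloader())  # only the validation set is loaded at this point
first = list(dm.val_dataloader())   # reuses the cached validation set
epochs = [list(dm.train_dataloader()) for _ in range(3)]
print(expensive_load.calls)
```

Each dataset is loaded once regardless of how many loaders are created, and the training set is not touched until something actually asks for it.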

Another tip I found browsing the GitHub issues is to prevent the worker processes from being respawned every epoch; with short epochs this makes a big difference.

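```python
# Originally proposed by PetrochukM in https://github.com/pytorch/pytorch/issues/15849#issuecomment-518126031
# Modified by monoelh in https://github.com/PyTorchLightning/pytorch-lightning/issues/2875#issuecomment-673355304
import torch


class _RepeatSampler:
    """Sampler that repeats forever.

    Args:
        sampler (Sampler): the sampler (here: the batch sampler) to cycle through.
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        # Re-iterate the wrapped sampler on every pass, so shuffling changes per pass.
        while True:
            yield from iter(self.sampler)


class ContinuousDataLoader(torch.utils.data.DataLoader):
    """DataLoader whose worker processes are started once and reused across epochs.

    Assumes automatic batching (batch_size is not None), so batch_sampler exists.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # DataLoader refuses to reassign batch_sampler after __init__; temporarily
        # clear its (name-mangled) initialized flag to swap in the repeating sampler.
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # One persistent iterator (and therefore one set of workers) for the loader's lifetime.
        self.iterator = super().__iter__()

    def __len__(self):
        # Number of batches per epoch, taken from the original batch sampler.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        # Yield exactly one epoch's worth of batches from the shared iterator.
        for _ in range(len(self)):
            yield next(self.iterator)
```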

See https://github.com/pytorch/pytorch/issues/15849 and https://github.com/PyTorchLightning/pytorch-lightning/issues/2875 for more details.
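Stripped of the PyTorch internals, the trick keeps one long-lived iterator over an endlessly repeated sampler and hands out epoch-sized slices of it, so whatever is expensive to start (the worker processes) starts only once. The plain-Python toy below models just that mechanism; `ContinuousLoader` and `_spawn_workers` are hypothetical names, with `_spawn_workers` standing in for worker start-up. As a built-in alternative, `torch.utils.data.DataLoader` has accepted `persistent_workers=True` since PyTorch 1.7, which keeps workers alive between epochs without subclassing (it requires `num_workers > 0`).

```python
from typing import Iterator, List


class ContinuousLoader:
    """Toy model of ContinuousDataLoader: one persistent iterator, epoch-sized views."""

    def __init__(self, batches: List[int]):
        self.batches = batches
        self.spawn_count = 0
        # Created once, like self.iterator in ContinuousDataLoader.__init__.
        self._iterator = self._spawn_workers()

    def _spawn_workers(self) -> Iterator[int]:
        # Hypothetical stand-in for starting worker processes (the expensive part).
        self.spawn_count += 1
        return self._repeat_forever()

    def _repeat_forever(self) -> Iterator[int]:
        # Mirrors _RepeatSampler: walk the data again and again, never stopping.
        while True:
            yield from self.batches

    def __len__(self) -> int:
        return len(self.batches)

    def __iter__(self) -> Iterator[int]:
        # One epoch = exactly len(self) items pulled from the shared iterator.
        for _ in range(len(self)):
            yield next(self._iterator)


loader = ContinuousLoader([0, 1, 2])
epochs = [list(loader) for _ in range(3)]
print(epochs, loader.spawn_count)
```

Three epochs are served while the "workers" are spawned only once. One consequence of sharing a single iterator, which the real `ContinuousDataLoader` shares: if an epoch is cut short (for example by `limit_train_batches` or an early `break`), the next epoch resumes mid-pass instead of at the start.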


While there is definitely a point in this, IMO it is cleaner to have the worker processes killed at the end of each epoch, since otherwise you end up with twice the number of processes running (assuming the same number of workers for training and evaluation). But I also understand your point.