How to avoid loading the complete in-memory dataset in every process during DDP training

I’m working in an environment that has regular HDDs, shared amongst many users. I/O performance is too poor to simply read and parse data on the fly, so I have to load my data in memory.

I have a single node with 4 GPUs (node resources are not shared, but the underlying storage is). When training in DDP mode, each process loads the entire dataset into memory. That works for my current dataset, but it won’t work for larger ones, and I’d like to avoid it anyway: each process uses only a subset of the data, so the rest is redundant. The dataset preparation is done directly in the LightningModule (without explicitly using a LightningDataModule).

From my understanding of things, the following should solve this problem:

  • Stop the Trainer from adding a DistributedSampler by setting replace_sampler_ddp=False.
  • Pass the local rank to the Dataset so that it loads only its own shard.

So my questions are:

  1. How do I achieve the above, i.e. getting rank information of the process in the LightningModule and passing it on to my dataset object?
  2. Is there a better way to do this using existing pytorch-lightning components?

Solved. All relevant information can be found in the environment variables set by pytorch-lightning when the DDP processes are launched.

Specifically, do the following in your dataset/lightning module definition:

import os

env_cp = os.environ.copy()

# Environment variable values are strings, so convert the numeric ones
# before using them in shard arithmetic.
node_rank = int(env_cp['NODE_RANK'])
local_rank = int(env_cp['LOCAL_RANK'])
world_size = int(env_cp['WORLD_SIZE'])

is_in_ddp_subprocess = env_cp['PL_IN_DDP_SUBPROCESS']
pl_trainer_gpus = env_cp['PL_TRAINER_GPUS']

Using this info I was able to write shard logic that loaded a specific subset of data in memory for each DDP process.
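
For anyone wondering what such shard logic might look like, here is a minimal sketch, not the exact code I used. It assumes a single node (so LOCAL_RANK alone identifies the process), a dataset that is split across a list of files, and strided assignment of files to processes. The helper names rank_and_world_size, shard_indices and load_shard, as well as the load_fn argument, are made up for illustration.

```python
import os


def rank_and_world_size(env=None):
    """Read the shard index and shard count from the environment.

    Assumes a single node, so LOCAL_RANK identifies the process uniquely.
    Falls back to (0, 1) when the variables are absent, e.g. outside DDP.
    """
    env = os.environ if env is None else env
    rank = int(env.get("LOCAL_RANK", 0))
    world_size = int(env.get("WORLD_SIZE", 1))
    return rank, world_size


def shard_indices(num_items, rank, world_size):
    """Return the strided range of item indices owned by one process."""
    return range(rank, num_items, world_size)


def load_shard(paths, rank, world_size, load_fn):
    """Load only this process's files into memory.

    load_fn is a hypothetical callable that reads and parses one file.
    """
    return [load_fn(paths[i]) for i in shard_indices(len(paths), rank, world_size)]
```

With a strided split the shards can differ in length by one item. If that leads to a different number of batches per process, DDP may stall at the end of an epoch, so it is probably worth trimming or padding the shards to equal length.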

Leaving this here in case it helps someone.