Best practices for implementing large datasets with DDP

I am attempting to improve data transfer speed for a GNN, currently implemented using PyTorch Lightning. Our total dataset is large: 15 million graphs, ~60 GB on disk, spread across ~200 files.

I was able to get this working as expected using an IterableDataset, by having each worker either load all of its designated files at initialization (using persistent workers) or load a set of files each time __iter__ was called (roughly as in the sketch below). But I wasn't able to get this to work correctly across multiple GPUs; using more than 1 GPU just resulted in a significant slowdown of training.
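For reference, the single-GPU version looks roughly like this. It's a simplified sketch: the file paths and the assumption that each file is a torch.save'd list of graphs are placeholders for my actual setup.

```python
import torch
from torch.utils.data import IterableDataset, get_worker_info

class GraphFileDataset(IterableDataset):
    """Streams graphs from a list of files, splitting the files across DataLoader workers."""

    def __init__(self, file_paths):
        self.file_paths = sorted(file_paths)

    def __iter__(self):
        worker = get_worker_info()
        if worker is None:
            # Single-process data loading: iterate over every file.
            files = self.file_paths
        else:
            # Split the file list across DataLoader workers by striding.
            files = self.file_paths[worker.id::worker.num_workers]
        for path in files:
            graphs = torch.load(path)  # each file holds a chunk of graphs (placeholder format)
            yield from graphs
```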
I also tried creating a normal map-style Dataset where all files were loaded into a dictionary/array (tried both), but with multiple workers this resulted in huge memory usage that was unsustainable (the docs suggest total memory usage tends toward num_workers * dataset_size, which is roughly what I experienced). A sketch of that attempt is below.
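This is approximately what that looked like. Since each DataLoader worker is a separate process and Python reference counting touches every object, copy-on-write doesn't help, so each worker effectively ends up with its own copy of the in-memory data:

```python
import torch
from torch.utils.data import Dataset

class InMemoryGraphDataset(Dataset):
    """Loads every file into RAM up front; with num_workers > 0 each worker
    process ends up holding roughly a full copy of self.graphs."""

    def __init__(self, file_paths):
        self.graphs = []
        for path in file_paths:
            self.graphs.extend(torch.load(path))  # placeholder: one file = list of graphs

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx]
```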

The final thing I tried was manually passing the world size and current device rank into my iterable dataset and sharding the files on those (sketched below). But when I added some debug code, it looked as though every worker was only ever seeing rank=0.
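This is roughly the sharding I was aiming for, except that instead of passing rank/world_size into the constructor I've written it here to query torch.distributed lazily inside __iter__ (my understanding is that the dataset can be constructed/pickled before DDP assigns ranks, which might explain why I only ever saw rank=0; this is an assumption, not something I've verified):

```python
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, get_worker_info

class ShardedGraphFileDataset(IterableDataset):
    """Shards files across (DDP ranks x DataLoader workers) so every global
    worker streams a disjoint slice of the file list."""

    def __init__(self, file_paths):
        self.file_paths = sorted(file_paths)

    def __iter__(self):
        # Look up rank/world_size at iteration time, after DDP is (presumably) initialised.
        if dist.is_available() and dist.is_initialized():
            rank, world_size = dist.get_rank(), dist.get_world_size()
        else:
            rank, world_size = 0, 1

        worker = get_worker_info()
        worker_id = worker.id if worker is not None else 0
        num_workers = worker.num_workers if worker is not None else 1

        # Global shard index over all ranks and all workers per rank.
        shard_id = rank * num_workers + worker_id
        num_shards = world_size * num_workers

        for path in self.file_paths[shard_id::num_shards]:
            yield from torch.load(path)  # placeholder: one file = list of graphs
```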

Is there a way to get DDP working correctly with iterable datasets? It seems it's not directly supported, but I wondered if there was a workaround.