Iterable dataset with PyTorch Lightning

Hi
I have a dataset of type tf.data.Dataset. It is iterable, but it does not support random access to individual elements. I read the torch.utils.data documentation for PyTorch 1.12, but it was not clear to me how to write my own DataLoader with an iterable dataset, so I would appreciate some examples. My questions are about the tutorial example from that page, which is included at the end of this post:

  1. With PyTorch Lightning, do I need to set worker_info myself, or is it set automatically?
  2. If I do need it, how should __iter__ be written? In this example it returns iter(range(iter_start, iter_end)), and I am not sure what the equivalent is for an iterable dataset.
  3. According to the tutorial, the dataset needs to be split across multiple workers. How can I know how many workers are available in each case (TPU, multiple GPUs, …)?

Thanks!
```python
import math

import torch


class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super().__init__()  # the original super(MyIterableDataset) call never initialized the base class
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else:  # in a worker process
            # split the workload into contiguous, near-equal chunks, one per worker
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))
```
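The contiguous split in `__iter__` can be checked without starting any workers. The sketch below copies the same arithmetic into a plain function, `shard_bounds` (a hypothetical helper, not part of PyTorch), and shows that the per-worker ranges cover `range(start, end)` exactly once, with surplus workers receiving empty ranges.

```python
import math


def shard_bounds(start, end, worker_id, num_workers):
    """Same arithmetic as the tutorial's __iter__ (hypothetical helper)."""
    per_worker = int(math.ceil((end - start) / float(num_workers)))
    iter_start = start + worker_id * per_worker
    iter_end = min(iter_start + per_worker, end)
    return iter_start, iter_end


def shards(start, end, num_workers):
    # Materialize what each worker would yield.
    return [list(range(*shard_bounds(start, end, w, num_workers)))
            for w in range(num_workers)]


print(shards(3, 7, 2))   # → [[3, 4], [5, 6]]
print(shards(0, 10, 3))  # → [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]  (last worker gets fewer)
print(shards(0, 2, 4))   # → [[0], [1], [], []]  (more workers than items)
```

Note that this scheme needs `start` and `end` up front, which is exactly what a stream such as a tf.data.Dataset may not provide.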

Hi Rabeeh,

I am not quite sure what you are trying to do here, but several things could be causing problems.

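  1. PyTorch Lightning does not set num_workers for you; you choose it when you instantiate the DataLoader, and I typically recommend os.cpu_count(). You never set worker_info yourself: torch.utils.data.get_worker_info() returns None in the main process and is filled in automatically inside each DataLoader worker process.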

  2. The __iter__ method of your IterableDataset should return a Python iterator (or be a generator) that yields one element of your dataset at a time.

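  3. The dataset can be split across workers with torch.utils.data.get_worker_info(), exactly as in plain PyTorch. Remember that Lightning is mostly a lightweight wrapper around PyTorch, so most things are done just as they would be in normal PyTorch. When training on multiple devices (e.g. several GPUs or TPU cores), Lightning splits map-style datasets across processes by inserting a DistributedSampler; an IterableDataset cannot use a sampler, so each process iterates over the full stream unless you also shard it by process rank.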
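Putting points 1 to 3 together, here is a minimal sketch of an IterableDataset over a stream without random access. `StreamDataset` and its `make_source` argument are hypothetical names, not part of PyTorch or Lightning. Because the length of a stream is usually unknown, it uses round-robin sharding with `itertools.islice` instead of the tutorial's contiguous ranges, and it also shards by `torch.distributed` rank, assuming the multi-GPU strategy initializes a process group (as DDP does).

```python
import itertools

import torch.distributed as dist
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class StreamDataset(IterableDataset):
    """Round-robin sharding of a stream with no random access (sketch)."""

    def __init__(self, make_source):
        super().__init__()
        # make_source: zero-argument callable returning a fresh iterator over the
        # data (hypothetical interface). For a tf.data.Dataset one could pass
        # `lambda: tf_ds.as_numpy_iterator()`.
        self.make_source = make_source

    def __iter__(self):
        # Process rank/count: (0, 1) unless torch.distributed is initialized.
        if dist.is_available() and dist.is_initialized():
            rank, world_size = dist.get_rank(), dist.get_world_size()
        else:
            rank, world_size = 0, 1
        # Worker id/count: get_worker_info() is None in the main process
        # (num_workers=0) and filled in automatically inside each worker.
        info = get_worker_info()
        worker_id, num_workers = (0, 1) if info is None else (info.id, info.num_workers)
        # Global shard index and shard count across all processes and workers.
        shard = rank * num_workers + worker_id
        num_shards = world_size * num_workers
        # Keep every num_shards-th element, starting at this shard's offset.
        return itertools.islice(self.make_source(), shard, None, num_shards)


ds = StreamDataset(lambda: iter(range(10)))
loader = DataLoader(ds, batch_size=4, num_workers=0)
print([batch.tolist() for batch in loader])  # → [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

With Lightning you would return such a DataLoader from your `train_dataloader` hook and pick `num_workers` there (for example `os.cpu_count()`); worker and rank information is read at iteration time, so nothing has to be set by hand. Each shard still reads the whole stream and skips the elements it does not own, which is simple but wastes I/O on large datasets. Ranks can also end up with different numbers of batches, which distributed training may not tolerate; trimming the stream to a multiple of the shard count is one way around that.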

Hope this helps :)

@Rabeeh_Karimi, did you get this to work? I'm also using an iterable dataset with PyTorch Lightning, and I suspect the issue I'm having here is related.