How to use batch_sampler in distributed training?

Hello, I have a custom batch sampler and want to use it in distributed training. The problem is that I have no way to know the current epoch to seed shuffling, because pl (PyTorch Lightning) only calls .set_epoch() on samplers, not on batch samplers. At the same time, I can't pass DataLoader a proxy sampler that stores the epoch number somewhere, because sampler and batch_sampler are mutually exclusive.

Use batch_sampler → DataLoader requires sampler=None → DataLoader creates a SequentialSampler to populate its .sampler attribute → in distributed mode the SequentialSampler is replaced with a DistributedSampler, and my batch sampler is re-created with a fixed set of keyword arguments that its constructor does not accept. This raises an error:

    # The batch sampler is re-created from its class with only these arguments:
    batch_sampler = type(batch_sampler)(
        sampler,
        batch_size=batch_sampler.batch_size,
        drop_last=(False if is_predicting else batch_sampler.drop_last),
    )
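To make the failure concrete, here is a hypothetical bucketing batch sampler (`BucketBatchSampler`, not from the original post) whose constructor needs an extra `num_buckets` argument. Re-creating it the way the quoted snippet does, with only `sampler`, `batch_size` and `drop_last`, cannot supply that argument, so Python raises a `TypeError`. `rebuild_like_lightning` is an illustrative helper that mimics the snippet; it is not Lightning's actual function.

```python
class BucketBatchSampler:
    """Hypothetical batch sampler: sorts indices by length, splits them into
    buckets, and batches within each bucket."""

    def __init__(self, lengths, num_buckets, batch_size, drop_last=False):
        self.lengths = lengths
        self.num_buckets = num_buckets
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        order = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
        # Ceiling division; max(1, ...) avoids a zero step for empty input.
        bucket_size = max(1, -(-len(order) // self.num_buckets))
        for start in range(0, len(order), bucket_size):
            bucket = order[start:start + bucket_size]
            for i in range(0, len(bucket), self.batch_size):
                batch = bucket[i:i + self.batch_size]
                if len(batch) == self.batch_size or not self.drop_last:
                    yield batch


def rebuild_like_lightning(batch_sampler, sampler, is_predicting=False):
    """Illustrative helper mimicking the re-instantiation quoted above."""
    return type(batch_sampler)(
        sampler,
        batch_size=batch_sampler.batch_size,
        drop_last=(False if is_predicting else batch_sampler.drop_last),
    )


bs = BucketBatchSampler(lengths=[5, 1, 3, 2], num_buckets=2, batch_size=2)
print(list(bs))  # [[1, 3], [2, 0]]
try:
    rebuild_like_lightning(bs, sampler=range(4))
except TypeError as err:
    print("TypeError:", err)  # num_buckets cannot be supplied
```

Any batch sampler whose `__init__` cannot be called as `cls(sampler, batch_size=..., drop_last=...)` fails the same way; one that does accept that signature would instead silently receive the sampler in place of its own first argument.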

What can I do to use a custom BatchSampler in a distributed setting?
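One possible workaround is to make the batch sampler itself distributed-aware, so it no longer depends on a DistributedSampler: it shuffles with a seed derived from the epoch, cuts the permutation into batches, and gives each rank every `num_replicas`-th batch. The sketch below assumes Lightning's automatic sampler replacement is switched off (`use_distributed_sampler=False` on the `Trainer` in Lightning 2.x, `replace_sampler_ddp=False` in 1.x) so the batch sampler is not rebuilt, and that `set_epoch` is called at the start of every epoch, e.g. from an epoch-start hook if your Lightning version does not call it on batch samplers. The class name and parameters are hypothetical; `rank` and `num_replicas` would normally come from `torch.distributed.get_rank()` and `torch.distributed.get_world_size()`.

```python
import random


class DistributedShuffledBatchSampler:
    """Hypothetical batch sampler that shards whole batches across ranks."""

    def __init__(self, num_samples, batch_size, num_replicas, rank,
                 shuffle=True, seed=0, drop_last=False):
        if not 0 <= rank < num_replicas:
            raise ValueError("rank must be in [0, num_replicas)")
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.drop_last = drop_last
        self.epoch = 0

    def set_epoch(self, epoch):
        # Call once per epoch on every rank to get a fresh, shared permutation.
        self.epoch = epoch

    def _all_batches(self):
        indices = list(range(self.num_samples))
        if self.shuffle:
            # Same seed + epoch on every rank -> identical permutation everywhere.
            random.Random(self.seed + self.epoch).shuffle(indices)
        batches = [indices[i:i + self.batch_size]
                   for i in range(0, len(indices), self.batch_size)]
        if self.drop_last and batches and len(batches[-1]) < self.batch_size:
            batches.pop()
        return batches

    def __iter__(self):
        batches = self._all_batches()
        # Drop leftover batches so every rank runs the same number of steps.
        usable = len(batches) - len(batches) % self.num_replicas
        return iter(batches[self.rank:usable:self.num_replicas])

    def __len__(self):
        n, b = self.num_samples, self.batch_size
        num_batches = n // b if self.drop_last else -(-n // b)
        return num_batches // self.num_replicas
```

It would be passed as `DataLoader(dataset, batch_sampler=...)`; DataLoader only needs an iterable of index lists there. Unlike DistributedSampler, which pads by repeating indices, this sketch drops the leftover batches instead.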

Things which might work:

  • Create a sampler-like wrapper around BatchSampler(batch_size=n) that flattens its output ([[0, 1, 2], [3, 4, 5]] → [0, 1, 2, 3, 4, 5]); the default batch_sampler would then re-form those indices into batches.
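The flattening idea from the bullet above can be sketched as a small sampler-like wrapper. `FlattenedBatchSampler` is a hypothetical name; it yields the inner batch sampler's indices one at a time and forwards `set_epoch` to it, so an epoch set on the sampler reaches the batch logic. `rebatch` is a pure-Python stand-in for the regrouping that `torch.utils.data.BatchSampler` performs.

```python
class FlattenedBatchSampler:
    """Hypothetical wrapper: turns a batch sampler into a flat index sampler."""

    def __init__(self, batch_sampler):
        self.batch_sampler = batch_sampler

    def set_epoch(self, epoch):
        # Forward the epoch to the wrapped batch sampler if it supports it.
        if hasattr(self.batch_sampler, "set_epoch"):
            self.batch_sampler.set_epoch(epoch)

    def __iter__(self):
        for batch in self.batch_sampler:
            yield from batch

    def __len__(self):
        # Counts by iterating the inner sampler once; assumes that is cheap
        # and has no side effects (e.g. does not advance a global RNG).
        return sum(len(batch) for batch in self.batch_sampler)


def rebatch(indices, batch_size, drop_last=False):
    """Stand-in for BatchSampler: groups consecutive indices into batches."""
    batch = []
    for idx in indices:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_last:
        yield batch


inner = [[0, 1, 2], [3, 4, 5], [6]]
flat = FlattenedBatchSampler(inner)
print(list(flat))              # [0, 1, 2, 3, 4, 5, 6]
print(list(rebatch(flat, 3)))  # [[0, 1, 2], [3, 4, 5], [6]]
```

With PyTorch this would be `DataLoader(dataset, sampler=FlattenedBatchSampler(inner), batch_size=n)`. Regrouping only reproduces the original batches if `n` equals the inner batch size and every inner batch except the last is full, so variable-size batches get mixed. Under DDP, whether Lightning keeps this sampler or replaces or wraps it depends on its sampler-replacement settings, and any per-rank splitting of the flat index stream can also break batch boundaries.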