On_batch_end callback distributed printing

In a distributed setting, how would you print the batch-end results from all processes, but only on rank zero?

Notice how each process has a different copy of the results. I am using the spawn backend.

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only


class MyCallback(pl.Callback):
    def __init__(self):
        self.results = {}

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self.results[batch_idx] = batch_idx
        print("BATCH END", rank_zero_only.rank, self.results)
        """
        BATCH END 1 {3: 3, 1: 1}
        BATCH END 0 {2: 2, 4: 4, 0: 0} 
        """

    @rank_zero_only
    def on_test_end(self, trainer, pl_module):
        print("TEST END", rank_zero_only.rank, self.results)
        """
        TEST END 0 {2: 2, 4: 4, 0: 0} 
        
        but I want:

        TEST END 0 {0: 0, 1: 1, 2: 2, 3: 3, 4: 4} 
        """

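To make the desired output concrete, here is a minimal sketch of what I imagine the merge step would look like. The helper names `merge_rank_results` and `gather_results` are hypothetical (not part of Lightning), and the gather step assumes `torch.distributed` has already been initialized by the trainer; I have not confirmed this is the recommended approach.

```python
def merge_rank_results(per_rank_results):
    """Merge one dict per rank into a single dict ordered by batch index."""
    merged = {}
    for rank_results in per_rank_results:
        merged.update(rank_results)
    return dict(sorted(merged.items()))


def gather_results(local_results):
    """Collect every rank's dict. This is a collective: call it on ALL ranks."""
    # Imported lazily so merge_rank_results can be used without torch installed.
    import torch.distributed as dist

    if not (dist.is_available() and dist.is_initialized()):
        return [local_results]  # single-process run: nothing to gather
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_results)
    return gathered


# The two per-rank dicts from the output above, merged as rank zero would see them.
example = merge_rank_results([{2: 2, 4: 4, 0: 0}, {3: 3, 1: 1}])
print("TEST END", 0, example)
```

If this is on the right track, on_test_end would presumably need to drop @rank_zero_only (otherwise the other ranks never join the gather and rank zero waits forever), call `merge_rank_results(gather_results(self.results))` on every rank, and only print when `trainer.is_global_zero` is true.
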
I used to print inside on_test_batch_end on every process, but some lines were getting joined (note the 42 in the output below, which should be 4 and 2 on separate lines). I assume this is a race condition, but I haven't been able to reproduce it with toy examples.

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        print(batch_idx)

Output:

0
3
42
1

For some reason, there is no issue with print(str(batch_idx) + '\n', end=''). Absolutely no idea why :exploding_head:
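
One guess, which I have not verified in the actual distributed run: on CPython, print() hands the text and the trailing newline to the stream as two separate write() calls, so another process could slip its own output in between them. Building the newline into the string makes it a single write. The sketch below only shows the difference in write() calls; `RecordingStream` is a hypothetical helper, and whether this is really what happens with the spawn backend is an assumption.

```python
import io


class RecordingStream(io.StringIO):
    """Text stream that remembers each non-empty write() call."""

    def __init__(self):
        super().__init__()
        self.calls = []

    def write(self, text):
        if text:  # ignore empty writes such as end=""
            self.calls.append(text)
        return super().write(text)


# Default print: the value and the newline reach the stream separately.
default_print = RecordingStream()
print(4, file=default_print)

# Newline baked into the string: the stream sees one piece of text.
single_write = RecordingStream()
print(str(4) + "\n", end="", file=single_write)

print("default:", default_print.calls)
print("single: ", single_write.calls)
```

Both streams end up with the same contents; only the number of writes differs.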