Validation sanity check hangs after `all_gather`

I am trying to move all validation outputs to one process to calculate my metric. My code looks something like the following.

At each validation step:

def validation_step(self, batch, batch_idx):
    # forward
    outputs = self.forward(batch)
    y_pred = outputs.logits
    y_true = batch["rbd_labels"]
    # loss
    loss = F.cross_entropy(y_pred, y_true)
    self.log("valid_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
    instances = batch["instances"]
    labels = batch["rbd_labels"]
    scores = y_pred[:, 1]
    return (instances, labels, scores)

Then I call self.all_gather on the outputs:

def validation_epoch_end(self, validation_step_outputs):
    # Step 1: collect instances and labels
    instance_label_score_map = defaultdict(list)
    instances = []
    labels = []
    scores = []
    for i, l, s in validation_step_outputs:
        instances += i.tolist()
        labels += l.tolist()
        scores += s.tolist()

    out_instances = self.all_gather(instances)
    out_labels = self.all_gather(labels)
    out_scores = self.all_gather(scores)

    if dist.get_rank() == 0:
        print("dist rank: 0")
        # Note: they stack in a weird way so we need to convert it back
        out_instances = torch.stack(out_instances).cpu().tolist()
        out_labels = torch.stack(out_labels).cpu().tolist()
        out_scores = torch.stack(out_scores).cpu().tolist()

        score = compute_score()
        self.log("valid_score", score)

My program then hangs at the sanity check, right after validation_epoch_end.
Appreciate any help!

It turns out one needs to pass rank_zero_only=True to self.log, as there is some synchronization going on in the backend.
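For anyone landing here, a minimal sketch of that fix. The helper name `log_score_from_rank_zero` and its arguments are made up for illustration; the only Lightning API it relies on is the `rank_zero_only` argument of `self.log`. The assumption is that the hang comes from rank 0 making a log call that the other ranks never make, so any sync on that key waits forever.

```python
def log_score_from_rank_zero(module, name, score, rank):
    """Log `score` from rank 0 only (hypothetical helper, not part of Lightning).

    `module` is expected to be a LightningModule (anything with a compatible
    `log` method works); `rank` is the global rank of the calling process.
    Returns True if this process logged the value.
    """
    if rank != 0:
        # Other ranks skip the call entirely; they have nothing to log.
        return False
    # rank_zero_only=True tells Lightning that only rank 0 makes this call,
    # so it should not try to synchronize the value with the other ranks.
    module.log(name, score, rank_zero_only=True)
    return True
```

Inside validation_epoch_end this would be called as `log_score_from_rank_zero(self, "valid_score", score, self.global_rank)`, with the all_gather calls still executed on every rank before it. One likely trade-off: a value logged this way exists only on rank 0, so it may not be usable as a monitor for callbacks such as early stopping.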

That is not a solution. We should be able to calculate metrics from all processes. In my experience, only the validation sanity check fails: when training with DDP, the whole training and eval loops work fine. I just have to turn off the sanity check, because only that seems to fail; it just hangs on the logging sync.
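A rough sketch of the workaround described above, turning the sanity check off. `num_sanity_val_steps` is a real `Trainer` argument; the `TRAINER_KWARGS` dict and the `make_trainer` wrapper are just illustrative names, and any other Trainer settings (devices, strategy, and so on) are left to the caller.

```python
# Keyword arguments for the Trainer; num_sanity_val_steps=0 skips the
# validation sanity check that runs before training starts.
TRAINER_KWARGS = {"num_sanity_val_steps": 0}


def make_trainer(**overrides):
    """Build a Trainer with the sanity check disabled (illustrative wrapper)."""
    # Imported here so the kwargs above can be inspected without Lightning installed.
    import pytorch_lightning as pl

    # Caller-supplied overrides take precedence over the defaults above.
    return pl.Trainer(**{**TRAINER_KWARGS, **overrides})
```

This only sidesteps the hang during the sanity check; it does not change how the metric is logged in the regular validation loop.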