Validation sanity check hangs after `all_gather`

I am trying to move all validation outputs to one process to calculate my metric. My code looks something like the following.

At each validation step

def validation_step(self, batch, batch_idx):
        """Run one validation batch and return what the epoch-level metric needs.

        Logs the batch cross-entropy under ``valid_loss`` and returns the tuple
        ``(instances, labels, scores)``, where ``scores`` is column 1 of the
        model's logits.
        """
        logits = self.forward(batch).logits
        targets = batch["rbd_labels"]
        # Reported once per epoch (not per step), on the progress bar and logger.
        self.log(
            "valid_loss",
            F.cross_entropy(logits, targets),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        # Column 1 of the logits serves as the per-example score.
        return (batch["instances"], batch["rbd_labels"], logits[:, 1])

Then I call `self.all_gather` on the collected outputs

def validation_epoch_end(self, validation_step_outputs):
        """Gather all ranks' validation outputs and log the score from rank 0.

        Args:
            validation_step_outputs: iterable of ``(instances, labels, scores)``
                tuples, one per validation step; each element must support
                ``.tolist()``.
        """
        # Step 1: collect instances and labels
        instances = []
        labels = []
        scores = []
        for i, l, s in validation_step_outputs:
            instances.extend(i.tolist())
            labels.extend(l.tolist())
            scores.extend(s.tolist())

        # Step 2: gather across processes. These calls sit outside the rank
        # check on purpose: every rank has to take part in the gather.
        out_instances = self.all_gather(instances)
        out_labels = self.all_gather(labels)
        out_scores = self.all_gather(scores)

        if dist.get_rank() == 0:
            print("dist rank: 0")
            # The gathered values come back as a sequence of tensors, so stack
            # them before converting to plain Python lists.
            out_instances = torch.stack(out_instances).cpu().tolist()
            out_labels = torch.stack(out_labels).cpu().tolist()
            out_scores = torch.stack(out_scores).cpu().tolist()
            # NOTE(review): compute_score() receives none of the gathered values
            # here -- presumably a placeholder; confirm its real arguments.
            score = compute_score()
            # This branch runs on rank 0 only, so self.log must be told so;
            # without rank_zero_only=True the program hangs after this hook.
            self.log("valid_score", score, rank_zero_only=True)

Then my program hangs right after the sanity check, after `validation_epoch_end` has run.
I would appreciate any help!

It turns out that one needs to specify `rank_zero_only=True` in `self.log` when it is called only on rank 0, as there is some synchronization going on in the backend that otherwise waits for the other processes.