Proper way to log things when using DDP

Hi, I was wondering what the proper way of logging metrics is when using DDP. I noticed that anything I print inside validation_epoch_end is printed twice when using 2 GPUs. I was expecting validation_epoch_end to be called only on rank 0 and to receive the outputs from all GPUs, but I am no longer sure this is correct. So I have two questions:

  1. validation_epoch_end(self, outputs): when using DDP, does every subprocess receive only the data processed on its own GPU, or the data processed on all GPUs? In other words, does the outputs parameter contain the outputs for the entire validation set, from all GPUs?
  2. If outputs is GPU/process-specific, what is the proper way to calculate a metric over the entire validation set in validation_epoch_end when using DDP?
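One common way to get dataset-level precision, recall and F1 under DDP is to have each process count its own true positives, false positives and false negatives at a fixed threshold, sum those counts across processes, and only then compute the metrics. The sketch below simulates two ranks in a single process with plain Python; in a real DDP run the summing step would be a `torch.distributed.all_reduce` (default op is sum) on a tensor of counts. The helper names `confusion_counts` and `prf_from_counts` and the toy data are hypothetical.

```python
from typing import List, Tuple


def confusion_counts(preds: List[int], labels: List[int]) -> Tuple[int, int, int]:
    """Return (tp, fp, fn) for binary predictions on one rank's shard."""
    tp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(preds, labels) if p == 0 and y == 1)
    return tp, fp, fn


def prf_from_counts(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    """Precision, recall and F1 from (summed) counts; 0.0 when undefined."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Toy shards: what each of two processes would see in its own `outputs`.
shards = {
    0: ([1, 1, 0, 0], [1, 0, 0, 1]),  # (predictions, labels) on rank 0
    1: ([1, 1, 1, 0], [1, 1, 1, 0]),  # (predictions, labels) on rank 1
}

# Per-rank counts; summing them element-wise stands in for all_reduce(SUM).
per_rank = {r: confusion_counts(p, y) for r, (p, y) in shards.items()}
total = tuple(sum(c[k] for c in per_rank.values()) for k in range(3))
global_p, global_r, global_f1 = prf_from_counts(*total)

# For comparison: averaging each rank's own F1 gives a different number.
mean_f1 = sum(prf_from_counts(*c)[2] for c in per_rank.values()) / len(per_rank)

print(f"global F1 = {global_f1:.3f}, mean of per-rank F1 = {mean_f1:.3f}")
```

With these toy shards the summed counts give F1 = 0.8, while averaging the per-rank F1 values gives 0.75, so the two approaches are not interchangeable for ratio metrics like F1.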

I understand that I can fix the duplicate printing by checking self.global_rank == 0 and printing/logging only in that case; however, I want to understand what exactly I am printing/logging when I do that.
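Every DDP process runs the same hook on its own shard of the data, which is why an unguarded `print` appears once per GPU. The plain-Python sketch below imitates that by calling the same hook once per rank and recording what each call would print; with the `global_rank == 0` guard only one process emits the line, which is also the idea behind Lightning's `rank_zero_only` decorator. The function `epoch_end_hook` and the sequential loop standing in for parallel processes are hypothetical.

```python
from typing import List


def epoch_end_hook(global_rank: int, epoch: int, log: List[str]) -> None:
    """Stand-in for validation_epoch_end; appends to `log` instead of printing."""
    # Unguarded: every process executes this line.
    log.append(f"[rank {global_rank}] unguarded: Validation Epoch: {epoch}")
    # Guarded: only the process with global rank 0 executes this line.
    if global_rank == 0:
        log.append(f"[rank {global_rank}] guarded: Validation Epoch: {epoch}")


world_size = 2  # e.g. two GPUs, one process each
log: List[str] = []
for rank in range(world_size):  # real DDP runs these in parallel processes
    epoch_end_hook(rank, epoch=3, log=log)

for line in log:
    print(line)
```

Note that the guard only hides the duplicate output: the rank-0 line still reflects rank 0's own shard unless the values were gathered or reduced across processes first.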

Here is a code snippet from my use case. I would like to report F1, precision and recall on the entire validation dataset, and I am wondering what the correct way of doing that is when using DDP.

from typing import Any, Dict, List, Tuple

import torch


def _process_epoch_outputs(self,
                           outputs: List[Dict[str, Any]]
                           ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Creates and returns tensors containing all predictions and labels.

    Goes over the outputs accumulated from every batch, detaches the
    necessary tensors and stacks them together.

    Args:
        outputs (List[Dict]): the per-batch dicts returned by
            validation_step, each holding 'labels' and 'predictions'.
    """
    all_labels = []
    all_predictions = []

    for output in outputs:
        # Iterate over the batch dimension so each sample is kept separately.
        for labels in output['labels'].detach():
            all_labels.append(labels)
        for predictions in output['predictions'].detach():
            all_predictions.append(predictions)

    all_labels = torch.stack(all_labels).long().cpu()
    all_predictions = torch.stack(all_predictions).cpu()

    return all_predictions, all_labels

def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> None:
    """Logs f1, precision and recall on the validation set."""

    if self.global_rank == 0:
        print(f'Validation Epoch: {self.current_epoch}')

    predictions, labels = self._process_epoch_outputs(outputs)
    for i, name in enumerate(self.label_columns):
        # metrics is a project-specific helper module (not shown here).
        f1, prec, recall, t = metrics.get_f1_prec_recall(predictions[:, i],
                                                         labels[:, i])

        if self.global_rank == 0:
            print((f'F1: {f1}, Precision: {prec}, '
                   f'Recall: {recall}, Threshold {t}'))
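Because `get_f1_prec_recall` in the snippet also returns a threshold, it presumably searches over the predicted scores, and that search needs every prediction rather than summed counts. A minimal sketch, assuming the per-rank scores and labels have already been collected on each process (in real DDP that is the job of `torch.distributed.all_gather`, or `self.all_gather` in recent Lightning versions): concatenate the shards, then pick the threshold with the best F1 over the whole validation set. The helpers `gather_shards`, `f1_at` and `best_f1_threshold` and the toy scores are hypothetical and are not the author's `metrics` module.

```python
from typing import List, Sequence, Tuple


def gather_shards(shards: Sequence[Tuple[List[float], List[int]]]
                  ) -> Tuple[List[float], List[int]]:
    """Concatenate (scores, labels) from every rank; stands in for all_gather."""
    scores: List[float] = []
    labels: List[int] = []
    for s, y in shards:
        scores.extend(s)
        labels.extend(y)
    return scores, labels


def f1_at(scores: List[float], labels: List[int], t: float) -> float:
    """F1 when predicting positive for score >= t (2tp / (2tp + fp + fn))."""
    tp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 1)
    fp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 0)
    fn = sum(1 for s, y in zip(scores, labels) if s < t and y == 1)
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0


def best_f1_threshold(scores: List[float], labels: List[int]) -> Tuple[float, float]:
    """Try each observed score as a threshold; return (best_f1, threshold)."""
    return max((f1_at(scores, labels, t), t) for t in sorted(set(scores)))


# Toy (scores, labels) shards as two processes would hold them.
rank0 = ([0.9, 0.4, 0.35, 0.2], [1, 1, 0, 0])
rank1 = ([0.8, 0.7, 0.3, 0.1], [1, 0, 1, 0])

scores, labels = gather_shards([rank0, rank1])
f1, t = best_f1_threshold(scores, labels)
print(f"F1 on the full validation set: {f1:.3f} at threshold {t}")
```

Here the gathered data gives F1 = 0.8 at threshold 0.3, while running the same search on rank 0's shard alone would report F1 = 1.0 at threshold 0.4, which illustrates how per-process results can differ from the dataset-level answer. Also note that PyTorch's `DistributedSampler` pads by repeating a few samples when the dataset size is not divisible by the number of processes, so gathered results may contain a handful of duplicates.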