Logging images during validation using the TensorBoard logger

Hi,

I have successfully implemented logging images to the TensorBoard logger, but I quickly run out of GPU memory. I accumulate images over every validation_step and, at the end of the validation round, randomly select a few of them to log. This is clearly not the best way to do it. Can someone point me to how to do this properly without consuming so much memory? Thanks.

This is how my validation step looks:

def validation_step(self, batch, batch_idx):
    imgs, y_true = batch
    y_pred = self(imgs)
    val_loss = self.nn_criterion(y_pred, y_true)
    self.log("val_loss", val_loss)

    return {"val_loss": val_loss,
            "images": imgs,
            "masks_pred": y_pred,
            "true_masks": y_true}

As you can see, every dict returned by validation_step is kept until the end of the epoch, so the image tensors accumulate across all validation steps. Since I am working with a very large dataset, I run out of memory very quickly. Thanks in advance.
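Since the end goal is only a random handful of images, one way to keep memory bounded is reservoir sampling (Algorithm R), a swapped-in technique that is not mentioned elsewhere in this thread. The minimal sketch below assumes a hypothetical `ImageReservoir` helper and uses plain integers in place of image tensors so it runs without PyTorch: the first `k` items fill the reservoir, and each later item replaces a random slot with probability k/seen, so at most `k` items are ever stored while every item has the same chance of being kept.

```python
import random
from typing import Any, List, Optional


class ImageReservoir:
    """Hypothetical helper: uniform random sample of at most k items from a stream.

    Implements reservoir sampling (Algorithm R), so memory stays O(k) no matter
    how many validation images pass through.
    """

    def __init__(self, k: int, seed: Optional[int] = None) -> None:
        self.k = k
        self.seen = 0  # total number of items offered so far
        self.items: List[Any] = []
        self._rng = random.Random(seed)

    def add(self, item: Any) -> None:
        self.seen += 1
        if len(self.items) < self.k:
            # Fill the reservoir with the first k items
            self.items.append(item)
        else:
            # j is uniform in [0, seen); j < k happens with probability k / seen
            j = self._rng.randrange(self.seen)
            if j < self.k:
                self.items[j] = item

    def reset(self) -> None:
        # Call once per validation epoch so old samples are released
        self.seen = 0
        self.items.clear()


# Simulate 1,000 validation images arriving one at a time
reservoir = ImageReservoir(k=4, seed=0)
for image_id in range(1000):
    reservoir.add(image_id)

print(len(reservoir.items), reservoir.seen)  # 4 1000
```

In a real validation_step you would presumably call `reservoir.add(...)` once per image with tensors already moved off the GPU via `.detach().cpu()`, return only the loss, and then log `reservoir.items` and call `reset()` at the end of the epoch.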

Actually, if I could somehow access a single validation batch inside validation_epoch_end, that should solve the problem.

You can save the batch in a state variable (an attribute on the module) and access it there:

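def validation_step(self, batch, batch_idx):
    # Overwritten on every step, so only the most recent batch stays in memory
    self.val_batch = batch
    ...  # return only the loss here, not the image tensors

def validation_epoch_end(self, outputs):
    imgs, y_true = self.val_batch  # the last validation batch; log its images here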
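Keeping the last batch means the images you see change from epoch to epoch, and the last batch may be smaller than the others. A variant is to keep one fixed batch index so the same samples can be compared across epochs. The sketch below is plain Python with strings standing in for tensors; `FixedBatchCache`, `maybe_store`, and `pop` are hypothetical names, not part of Lightning.

```python
from typing import Any, Optional


class FixedBatchCache:
    """Hypothetical helper: remember exactly one validation batch per epoch.

    In a LightningModule you would call `maybe_store` from validation_step and
    `pop` from the epoch-end hook; plain Python values stand in for tensors here.
    """

    def __init__(self, batch_to_keep: int = 0) -> None:
        self.batch_to_keep = batch_to_keep
        self._batch: Optional[Any] = None

    def maybe_store(self, batch: Any, batch_idx: int) -> None:
        # Keep only the chosen batch; references to all other batches are dropped
        if batch_idx == self.batch_to_keep:
            self._batch = batch

    def pop(self) -> Optional[Any]:
        # Hand the batch over and clear the reference so it does not linger
        batch, self._batch = self._batch, None
        return batch


cache = FixedBatchCache(batch_to_keep=0)
batches = [("imgs0", "masks0"), ("imgs1", "masks1"), ("imgs2", "masks2")]
for batch_idx, batch in enumerate(batches):
    cache.maybe_store(batch, batch_idx)

print(cache.pop())  # ('imgs0', 'masks0')
print(cache.pop())  # None
```

As with the reply above, validation_step should then return only the loss, otherwise `outputs` still holds every batch. In Lightning 2.x the epoch-end hook is `on_validation_epoch_end`, which takes no `outputs` argument.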

Also, we have moved the discussions to GitHub Discussions. You might want to check that out instead to get a quick response. The forums will be marked read-only after some time.

Thank you

Hi @goku,

You can solve this issue quite easily by logging to TensorBoard directly in validation_step, so no image tensors need to be kept until the end of the epoch:

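# Module-level imports:
from typing import Any

from pytorch_lightning import loggers as pl_loggers

# Methods of your LightningModule:
    def validation_step(self, batch: Any, batch_idx: int):
        imgs, y_true = batch
        y_pred = self(imgs)
        val_loss = self.nn_criterion(y_pred, y_true)
        self.log("val_loss", val_loss)

        if batch_idx % 10 == 0:  # Log every 10th batch
            self.log_tb_images((imgs, y_true, y_pred), batch_idx)

        # Return only the loss so no image tensors are accumulated across steps
        return val_loss

    def log_tb_images(self, viz_batch, batch_idx: int) -> None:
        # Get the underlying TensorBoard SummaryWriter
        tb_logger = None
        for logger in self.trainer.loggers:
            if isinstance(logger, pl_loggers.TensorBoardLogger):
                tb_logger = logger.experiment
                break

        if tb_logger is None:
            raise ValueError("TensorBoard Logger not found")

        # Log the images under distinct names. add_image expects CHW tensors by
        # default; pass dataformats="HW" for single-channel 2-D masks.
        # Using global_step keeps images from different epochs at different steps.
        for img_idx, (image, y_true, y_pred) in enumerate(zip(*viz_batch)):
            tb_logger.add_image(f"Image/{batch_idx}_{img_idx}", image, self.global_step)
            tb_logger.add_image(f"GroundTruth/{batch_idx}_{img_idx}", y_true, self.global_step)
            tb_logger.add_image(f"Prediction/{batch_idx}_{img_idx}", y_pred, self.global_step)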
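The code above logs every image of every 10th batch, which can still produce many images on a large validation set. The framework-free sketch below isolates the two decisions involved: which batches to log (note the explicit `== 0`; a bare `batch_idx % 10` is truthy for every batch except multiples of 10) and how many images per batch to tag. `should_log_batch`, `image_tags`, and `max_images` are hypothetical names chosen for illustration.

```python
from typing import List


def should_log_batch(batch_idx: int, every_n: int = 10) -> bool:
    # True for batches 0, every_n, 2 * every_n, ...
    return batch_idx % every_n == 0


def image_tags(batch_idx: int, batch_size: int, max_images: int = 4) -> List[List[str]]:
    """Build one [image, ground truth, prediction] tag triple per logged image.

    Only the first `max_images` samples of the batch are tagged, which bounds
    how many images end up in the event file.
    """
    tags = []
    for img_idx in range(min(batch_size, max_images)):
        suffix = f"{batch_idx}_{img_idx}"
        tags.append([f"Image/{suffix}", f"GroundTruth/{suffix}", f"Prediction/{suffix}"])
    return tags


logged = [i for i in range(35) if should_log_batch(i)]
print(logged)  # [0, 10, 20, 30]
print(image_tags(batch_idx=20, batch_size=16, max_images=2))
# [['Image/20_0', 'GroundTruth/20_0', 'Prediction/20_0'], ['Image/20_1', 'GroundTruth/20_1', 'Prediction/20_1']]
```

Inside log_tb_images you could then zip these tags with `imgs[:max_images]`, `y_true[:max_images]`, and `y_pred[:max_images]` instead of iterating over the whole batch.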