My training loss is fine but my validation loss is exploding

Hi Folks,

I recently moved my code from PyTorch to PyTorch Lightning to make distributed data parallel (DDP) training easier to implement.

As the title says, my validation loss (L1) explodes to the order of 10e+10 at around epoch 20, while the training loss does not show this behaviour. I am confused because I also report PSNR on my validation set, and those values are not great but acceptable (~30 dB). My model is a U-Net (encoder-decoder) with ~58M parameters.
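One way the two metrics can disagree, assuming `test_psnr` returns a per-batch PSNR of `10 * log10(MAX**2 / MSE)` and both metrics are averaged over batches at epoch end: PSNR is logarithmic in the error, so a single batch with an enormous error barely moves the mean PSNR while it dominates the mean L1 loss. The helper `psnr_from_mse` and all numbers below are hypothetical, chosen only to illustrate the effect.

```python
import math


def psnr_from_mse(mse: float, max_val: float = 1.0) -> float:
    """PSNR in dB, assuming pixel values in [0, max_val] (hypothetical helper)."""
    return 10.0 * math.log10(max_val ** 2 / mse)


# Hypothetical per-batch validation metrics: 99 ordinary batches and one
# outlier batch whose per-pixel error is ~1e10 (so L1 ~ 1e10, MSE ~ 1e20).
l1_per_batch = [0.02] * 99 + [1e10]
mse_per_batch = [1e-3] * 99 + [1e20]

# Plain (unweighted) means over batches, like an epoch-level average.
mean_l1 = sum(l1_per_batch) / len(l1_per_batch)
psnr_per_batch = [psnr_from_mse(m) for m in mse_per_batch]
mean_psnr = sum(psnr_per_batch) / len(psnr_per_batch)

print(f"mean L1:   {mean_l1:.3e}")       # the outlier dominates the mean
print(f"mean PSNR: {mean_psnr:.2f} dB")  # still close to the ordinary 30 dB
```

Here the ordinary batches sit at 0.02 L1 and 30 dB; the single outlier pushes the mean L1 to about 1e8 but only pulls the mean PSNR down to about 27.7 dB. Looking at per-batch values rather than epoch means would show whether a few images are responsible.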

Is it possible that it is overfitting to my training data? I have previously trained models a fraction of this size without issue, and they achieved better results. My validation set is a subset of my training set. I have also run an integrity check on my validation images, and none of them are corrupted.
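On the overfitting question: a validation set drawn from the training images cannot reveal overfitting, because the model has already been fitted to those images, so its loss would be expected to track the training loss. A held-out split is usually kept disjoint from the training data. The sketch below shows the idea in plain Python as a seeded shuffle of indices cut into two non-overlapping parts. It mirrors what `torch.utils.data.random_split` does with a seeded generator, but it is an illustration rather than the PyTorch implementation, its permutation will differ, and `split_indices` is a hypothetical helper.

```python
import random


def split_indices(n, train_fraction=0.95, seed=42):
    """Seeded shuffle of range(n), cut into disjoint (train, val) index lists."""
    rng = random.Random(seed)  # fixed seed -> the same split on every run
    indices = list(range(n))
    rng.shuffle(indices)
    n_train = int(train_fraction * n)
    return indices[:n_train], indices[n_train:]


# Hypothetical dataset size, for illustration only.
train_idx, val_idx = split_indices(1000)
overlap = set(train_idx) & set(val_idx)
print(len(train_idx), len(val_idx), len(overlap))  # overlap is empty by construction
```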

Below is my Lightning code:

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

# resunetplus and test_psnr come from the project's own modules (not shown).


class LightningResunet(pl.LightningModule):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.model = resunetplus(3)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        noise_img, clean_img, _, _ = batch
        denoised_img = self.model(noise_img)
        loss = F.l1_loss(denoised_img, clean_img)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.0001)
        # Stepped once per epoch (Lightning's default interval): LR x0.1 at epochs 20, 150, 500
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[20, 150, 500], gamma=0.1)
        return [optimizer], [scheduler]

    def validation_step(self, batch, batch_idx):
        noisy, clean = batch
        denoised = self.model(noisy)
        val_loss = F.l1_loss(denoised, clean)
        psnr = test_psnr(clean, denoised)
        self.log("validation_loss", val_loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True, logger=True)
        self.log("validation_psnr", psnr, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True, logger=True)
        return {'val_loss': val_loss, 'val_psnr': psnr}

    def validation_epoch_end(self, outputs):
        # Manual epoch averages; the on_epoch=True logs above already record a per-epoch mean
        avg_loss = torch.stack([torch.as_tensor(x['val_loss']) for x in outputs]).mean()
        avg_psnr = torch.stack([torch.as_tensor(x['val_psnr']) for x in outputs]).mean()
        self.log('avg_val_loss', avg_loss, on_epoch=True, sync_dist=True, prog_bar=True)
        self.log('avg_val_psnr', avg_psnr, on_epoch=True, sync_dist=True, prog_bar=True)
```

```python
from datetime import datetime

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, random_split

# Syn_noisemaps, Real, just_gaussian and test_my_mixed_set are project-specific datasets (not shown).

if __name__ == '__main__':
    now = datetime.now()
    current_time = now.strftime("%H_%M_%S")
    path = "/home/bledc/denoiser/models/Apr14_singleconnection_{}".format(current_time)
    text_path = path + "/" + current_time + ".txt"

    train_dataset = Syn_noisemaps('/home/bledc/datasets/Syn_train/', 800, 96) + Real('/home/bledc/datasets/SIDD_train/', 320, 96) \
        + Syn_noisemaps('/home/bledc/datasets/mit_all/', 3500, 96) + just_gaussian('/home/bledc/datasets/gaussian/', 500, 96)
    train_size = int(0.95 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    # Seeded 95/5 split; the 5% part is discarded (assigned to _)
    train_dataset, _ = random_split(
        train_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

    test_set = test_my_mixed_set("/home/bledc/datasets/my_test_set/", 96, patch_size=96)
    test_loader = DataLoader(
        test_set, batch_size=1,
        shuffle=False, num_workers=8,
        pin_memory=True, drop_last=True)

    data_loader = DataLoader(
        train_dataset, batch_size=16,
        shuffle=True, num_workers=8,
        pin_memory=True, drop_last=False)

    chk_path = "/home/bledc/denoiser/models/Apr14_noconnections_resunet_02_30_56/model_E_epoch=2657-validation_psnr=30.16.ckpt"
    # Assumed arguments, inferred from the checkpoint filename above
    checkpoint_callback = ModelCheckpoint(
        dirpath=path,
        filename="model_E_{epoch}-{validation_psnr:.2f}",
        monitor="validation_psnr", mode="max")

    trainer = pl.Trainer(
        max_epochs=5000, gpus=4, strategy="ddp",
        callbacks=[checkpoint_callback],
        accumulate_grad_batches=2, gradient_clip_val=0.2,
        resume_from_checkpoint=chk_path)
    model = LightningResunet()
    # Assumption: test_loader (the only other loader defined) is used for validation
    trainer.fit(model, train_dataloaders=data_loader, val_dataloaders=test_loader)
```
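The learning-rate schedule in `configure_optimizers` can be checked without a GPU. Below is a plain-Python sketch of `MultiStepLR`'s closed-form rule (multiply the base rate by `gamma` once for every milestone already reached), assuming the scheduler is stepped once per epoch, which is Lightning's default `interval`. The helper `multistep_lr` is hypothetical, not part of PyTorch.

```python
from bisect import bisect_right


def multistep_lr(base_lr, epoch, milestones, gamma):
    """Learning rate after `epoch` scheduler steps, following MultiStepLR's
    closed form: base_lr * gamma ** (number of milestones <= epoch)."""
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


# Settings copied from configure_optimizers: lr=1e-4, milestones [20, 150, 500], gamma=0.1.
for epoch in (0, 19, 20, 149, 150, 499, 500):
    print(epoch, multistep_lr(1e-4, epoch, [20, 150, 500], 0.1))
```

With these settings the first drop, from 1e-4 to 1e-5, happens after 20 scheduler steps, i.e. at epoch 20 (0-based), which is roughly where the validation loss is reported to explode; whether the two are related is not established here. A run resumed from the epoch-2657 checkpoint would already be past all three milestones (1e-7), assuming the scheduler state is restored along with it.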

Any suggestions are welcome! Thanks folks!

I suggest going back to PyTorch :slight_smile:
I have had the same problem since the last update (1.6), and I also got overfitting with every change I made to the code. I am not sure what is going on, but when I went back to PyTorch, my code worked just fine :slight_smile:

I am seeing the same behavior. With identical forward passes, my model in PyTorch shows nicely decaying training and validation losses, but when I train the exact same model in Lightning, both the training and validation losses explode by a factor of 1e15 after approximately 20 epochs.