Is Lightning more memory-intensive than regular PyTorch?

Hi all.

I am trying Lightning because it makes using DDP simpler than in regular PyTorch. However, when I launch training, I get out-of-memory errors with the same batch size that worked before. Does Lightning use more memory than vanilla PyTorch, or am I missing something?
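To put numbers on it, I'm thinking of logging per-GPU memory from my training_step (the full module is below). A minimal sketch, assuming CUDA is available; torch.cuda.memory_allocated / max_memory_allocated are plain PyTorch calls, and the metric names are placeholders I made up:

# Sketch: training_step augmented to log this process's GPU memory (MiB).
# "mem_mib" / "peak_mem_mib" are metric names I made up.
def training_step(self, batch, batch_idx):
    noise_img, clean_img, _, _ = batch
    loss = F.l1_loss(self(noise_img), clean_img)
    self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
    if torch.cuda.is_available():
        # memory_allocated: tensors currently on this process's GPU;
        # max_memory_allocated: high-water mark since the process started.
        self.log("mem_mib", torch.cuda.memory_allocated() / 2**20, prog_bar=True)
        self.log("peak_mem_mib", torch.cuda.max_memory_allocated() / 2**20, prog_bar=True)
    return loss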

Below is my code:


from datetime import datetime

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

# resunext and the dataset classes used below (Syn_noisemaps, Real,
# just_gaussian, test_my_mixed_set) come from my own modules.


class LightningResunet(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = resunext(8, 29, 10, 64)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        noise_img, clean_img, _, _ = batch
        # self(...) routes through forward(), equivalent to self.model(...)
        denoised_img = self(noise_img)
        loss = F.l1_loss(denoised_img, clean_img)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        # Drop the learning rate by 10x at epochs 20, 200, and 500.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[20, 200, 500], gamma=0.1)
        return [optimizer], [scheduler]

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.l1_loss(y_hat, y)
        # sync_dist averages the logged value across DDP processes.
        self.log("validation_loss", loss, on_step=True, on_epoch=True, sync_dist=True)


if __name__ == '__main__':
    torch.set_printoptions(linewidth=120)
    now = datetime.now()
    current_time = now.strftime("%H_%M_%S")
    path = "/home/bledc/my_remote_folder/denoiser/models/Apr4_resunet_custom_notpretrain_continue1_{}".format(current_time)
    text_path = path + "/" + current_time + ".txt"

    train_dataset = (
        Syn_noisemaps('/home/bledc/dataset/syn_train2/Syn_train/', 800, 128)
        + Real('/home/bledc/dataset/SIDD_train/', 320, 128)
        + Syn_noisemaps('/home/bledc/dataset/MIT/MIT/MIT/mit_all/', 3500, 128)
        + just_gaussian('/home/bledc/dataset/MIT/MIT/MIT/gaussian/', 500, 128)
    )
    train_size = int(0.95 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, _ = torch.utils.data.random_split(train_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

    test_set = test_my_mixed_set("/home/bledc/dataset/my_test_set", 128, patch_size=128)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=1,
        shuffle=False, num_workers=8,
        pin_memory=True, drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=24,
        shuffle=True, num_workers=8,
        pin_memory=True, drop_last=True)

    # DDP launches one process per GPU; the DataLoader batch_size above is per process.
    trainer = pl.Trainer(max_epochs=1000, gpus=4, strategy="ddp")
    model = LightningResunet()
    trainer.fit(model, train_dataloaders=data_loader,
                val_dataloaders=test_loader)
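
In case it helps with debugging, here is a Trainer variant I'm considering: DeviceStatsMonitor to log per-device stats (including GPU memory), and precision=16 to reduce activation memory. A sketch assuming the same Lightning version that accepts gpus=/strategy= as above; I haven't confirmed it avoids the OOM:

from pytorch_lightning.callbacks import DeviceStatsMonitor

trainer = pl.Trainer(
    max_epochs=1000,
    gpus=4,
    strategy="ddp",
    precision=16,                      # native mixed precision
    callbacks=[DeviceStatsMonitor()],  # log per-device stats each step
)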

Thank you!