`.detach()` cannot stop backprop in `training_step`

When I was implementing a GAN with Lightning, I trained the discriminator first and wanted to reuse the generated fake images to train the generator. So I used `.detach()` in the first training step, like this:

import torch
import pytorch_lightning as pl

class DCGAN(pl.LightningModule):
    ...
    def training_step(self, batch, batch_idx, optimizer_idx):
        imgs, _ = batch
        if optimizer_idx == 0:
            # Train the discriminator
            z = torch.randn(imgs.shape[0], self.hparams.latent_dim, device=self.device)
            self.generated_imgs = self(z)  # generate fake images, kept for the generator step
            disc_real = self.disc(imgs)
            # detach so that d_loss.backward() does not reach (and free) the generator's graph
            disc_fake = self.disc(self.generated_imgs.detach())
            real_loss = self.criterion(disc_real, torch.ones_like(disc_real))
            fake_loss = self.criterion(disc_fake, torch.zeros_like(disc_fake))
            d_loss = (real_loss + fake_loss) / 2
            self.log('loss/disc', d_loss, on_epoch=True, prog_bar=True)
            return d_loss

        if optimizer_idx >= 1:
            # Train the generator, reusing the fake images from the step above
            disc_fake = self.disc(self.generated_imgs)
            g_loss = self.criterion(disc_fake, torch.ones_like(disc_fake))
            self.log('loss/gen', g_loss, on_epoch=True, prog_bar=True)
            return g_loss

But I get

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I thought I had found the problem: the graph built by `generated_imgs = self(z)` was also released in the first training step, even with `.detach()`. (Edit: according to @goku, `.detach()` is not the point here; rather, the generator's gradients are not computed at all when `optimizer_idx == 0`.)

However, if I use plain PyTorch or Lightning's manual optimization instead (see the "Optimization" page of the PyTorch Lightning 1.1.5 documentation), everything works fine.
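For comparison, here is a minimal plain-PyTorch sketch of the same update. The toy linear `gen`/`disc` networks, `BCEWithLogitsLoss` and the sizes are assumptions, not the original DCGAN. Because plain PyTorch never changes `requires_grad` behind your back, the fake batch produced once can be fed to the discriminator detached and then reused, undetached, for the generator step:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
latent_dim, img_dim, batch = 8, 16, 4

# Toy stand-ins for the real generator and discriminator (assumptions).
gen = nn.Sequential(nn.Linear(latent_dim, img_dim), nn.Tanh())
disc = nn.Linear(img_dim, 1)
criterion = nn.BCEWithLogitsLoss()
opt_g = torch.optim.Adam(gen.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-3)

imgs = torch.randn(batch, img_dim)  # stand-in for a batch of real images
z = torch.randn(batch, latent_dim)
generated_imgs = gen(z)  # the generator graph is built once

# Discriminator step: detach() cuts the generator out of d_loss's graph,
# so d_loss.backward() neither reaches nor frees the generator's graph.
disc_real = disc(imgs)
disc_fake = disc(generated_imgs.detach())
d_loss = (criterion(disc_real, torch.ones_like(disc_real))
          + criterion(disc_fake, torch.zeros_like(disc_fake))) / 2
opt_d.zero_grad()
d_loss.backward()
opt_d.step()

# Generator step: reuse the same generated_imgs, whose graph is still intact.
disc_fake = disc(generated_imgs)
g_loss = criterion(disc_fake, torch.ones_like(disc_fake))
opt_g.zero_grad()
g_loss.backward()
opt_g.step()

print(all(p.grad is not None for p in gen.parameters()))  # True
```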

So what happened in the first case?

The reason is that when `training_step` is called with `optimizer_idx=0`, Lightning by default turns off gradients (`requires_grad=False`) for the parameters that belong to the other optimizers (in this case, the one with `optimizer_idx=1`).

Check this: the "lightning" page of the PyTorch Lightning 1.1.5 documentation.

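So when you run `self.generated_imgs = self(z)` in that step, the generator's parameter gradients are turned off and `generated_imgs` carries no graph. In the `optimizer_idx=1` step the discriminator's parameters are turned off in turn, so `g_loss` ends up with no `grad_fn` at all, which is the `RuntimeError` you see.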
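The failure can be reproduced in plain PyTorch by doing by hand what the reply describes: freezing the generator during the discriminator step and the discriminator during the generator step. The toy networks and the `set_requires_grad` helper are hypothetical illustrations of the default behaviour, not Lightning's own code:

```python
import torch
import torch.nn as nn

gen = nn.Linear(8, 16)
disc = nn.Linear(16, 1)

def set_requires_grad(module: nn.Module, flag: bool) -> None:
    # Hypothetical helper mimicking the default toggling described above.
    for p in module.parameters():
        p.requires_grad_(flag)

# "optimizer_idx == 0": only the discriminator's parameters are trainable.
set_requires_grad(gen, False)
set_requires_grad(disc, True)
generated_imgs = gen(torch.randn(4, 8))
print(generated_imgs.requires_grad)  # False: no graph was recorded

# "optimizer_idx == 1": only the generator's parameters are trainable.
set_requires_grad(gen, True)
set_requires_grad(disc, False)
g_loss = disc(generated_imgs).mean()  # neither input nor weights require grad

error_message = None
try:
    g_loss.backward()
except RuntimeError as err:
    # element 0 of tensors does not require grad and does not have a grad_fn
    error_message = str(err)
print(error_message)
```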

You can either store `z` in the `optimizer_idx=0` step and compute `self.disc(self(z))` again in the generator step, or change the default behaviour by overriding `toggle_optimizer`.
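A minimal sketch of the first workaround, again simulating the toggling with a hypothetical `set_requires_grad` helper and toy networks: `z` is kept from the discriminator step, and the fake batch is regenerated in the generator step, when the generator's parameters are trainable again. The cost is one extra generator forward pass per batch.

```python
import torch
import torch.nn as nn

gen = nn.Linear(8, 16)
disc = nn.Linear(16, 1)

def set_requires_grad(module: nn.Module, flag: bool) -> None:
    # Hypothetical helper mimicking the default toggling.
    for p in module.parameters():
        p.requires_grad_(flag)

# Discriminator step: generator frozen; remember z rather than the images.
set_requires_grad(gen, False)
set_requires_grad(disc, True)
z = torch.randn(4, 8)
d_loss = disc(gen(z)).mean()  # placeholder for the real/fake BCE terms
d_loss.backward()

# Generator step: discriminator frozen; rebuild the fake batch from the stored z.
set_requires_grad(gen, True)
set_requires_grad(disc, False)
g_loss = disc(gen(z)).mean()
g_loss.backward()
print(all(p.grad is not None for p in gen.parameters()))  # True
```

Inside the `LightningModule` this would correspond to saving `z` as an attribute (a hypothetical `self.z`) in the `optimizer_idx == 0` branch and computing `self.disc(self(self.z))` in the generator branch.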


Thanks a lot, that message is really helpful.

I also tried updating the discriminator several times and the generator once per training step by returning the same optimizer multiple times, like this:

def configure_optimizers(self):
    opt_d = ...
    opt_g = ...
    # k entries for the discriminator (optimizer_idx 0..k-1), then one for the generator
    return [opt_d] * self.hparams.k + [opt_g], []
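To make the intended schedule explicit: in the 1.x automatic-optimization API used here, Lightning calls `training_step` once per entry of the optimizer list for each batch, so indices `0 .. k-1` would be discriminator updates and index `k` the generator update. Note also that `[opt_d] * k` repeats a reference to one optimizer object rather than creating `k` independent optimizers. A small sketch with a placeholder `k = 3` and toy single-parameter optimizers (assumptions):

```python
import torch

k = 3  # placeholder for self.hparams.k
w_d = torch.nn.Parameter(torch.zeros(1))
w_g = torch.nn.Parameter(torch.zeros(1))
opt_d = torch.optim.SGD([w_d], lr=0.1)
opt_g = torch.optim.SGD([w_g], lr=0.1)

optimizers = [opt_d] * k + [opt_g]
roles = ["discriminator" if idx < k else "generator"
         for idx, _ in enumerate(optimizers)]
print(roles)  # ['discriminator', 'discriminator', 'discriminator', 'generator']

# All k discriminator entries are the very same object.
print(all(o is opt_d for o in optimizers[:k]))  # True
```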

This trick just stops training, possibly for the same reason.

But such situations are common in GANs and other types of networks, and any workaround makes the code (just the training part 🙂) harder to write and read than plain PyTorch.

So is there a flag or something to turn off this behavior in Lightning, so that my code can look more ‘lightning’?

Not a flag, but you can just do:

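from pytorch_lightning import LightningModule

class LITModule(LightningModule):
    def toggle_optimizer(self, *args, **kwargs):
        # Do nothing: leave requires_grad of every parameter untouched
        pass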

to disable this behavior.

Or maybe try manual optimization (PyTorchLightning/pytorch-lightning on GitHub).

You can read more about it in the docs. ^^

Oh, that’s also pretty neat. I had expected the overridden `toggle_optimizer` to be more complex than just `pass`.

and yeah, I’m using manual optimization and it works well.

Thank you again for the timely and helpful reply. 🤗