When I was implementing a GAN with Lightning, I first trained the discriminator and wanted to reuse the generated fake images for training the generator. So I used `.detach()` in the first training step like this:
```python
class DCGAN(pl.LightningModule):
    ...
    def training_step(self, batch, batch_idx, optimizer_idx):
        imgs, _ = batch

        if optimizer_idx == 0:
            # Train discriminator
            z = torch.randn(imgs.shape[0], self.hparams.latent_dim, device=self.device)
            self.generated_imgs = self(z)  # generate fake images
            disc_real = self.disc(imgs)
            # detach here; this should leave the generator's graph intact for reuse
            disc_fake = self.disc(self.generated_imgs.detach())
            real_loss = self.criterion(disc_real, torch.ones_like(disc_real))
            fake_loss = self.criterion(disc_fake, torch.zeros_like(disc_fake))
            d_loss = (real_loss + fake_loss) / 2
            self.log('loss/disc', d_loss, on_epoch=True, prog_bar=True)
            return d_loss

        if optimizer_idx >= 1:
            # Train generator, reusing the fake images from the discriminator step
            disc_fake = self.disc(self.generated_imgs)
            g_loss = self.criterion(disc_fake, torch.ones_like(disc_fake))
            self.log('loss/gen', g_loss, on_epoch=True, prog_bar=True)
            return g_loss
```
But I get
```
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
I thought I had found the problem: the graph built by `generated_imgs = self(z)` was released after the first training step, even with `.detach()`. (Edit: according to @goku, `.detach()` is not the point here; rather, the generator's gradients are simply not computed when `optimizer_idx == 0`.)
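For what it's worth, the same error message can be reproduced in plain PyTorch if the generator's parameters happen to have `requires_grad=False` at the moment the fake images are created (my guess at what Lightning's optimizer handling does during the discriminator step); a toy linear layer stands in for the generator here:

```python
import torch

# Toy stand-in for the generator, with its parameters frozen,
# mimicking (I assume) what happens during the discriminator step.
gen = torch.nn.Linear(8, 8)
for p in gen.parameters():
    p.requires_grad_(False)

fake = gen(torch.randn(4, 8))
print(fake.requires_grad)  # False -> no grad_fn, no graph to backpropagate through

try:
    fake.sum().backward()
except RuntimeError as e:
    print(e)  # element 0 of tensors does not require grad and does not have a grad_fn
```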
However, if I instead use plain PyTorch or Lightning's manual optimization (Optimization — PyTorch Lightning 1.1.5 documentation), everything works fine.
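For comparison, this is roughly the plain-PyTorch version of the same two steps that works for me; the toy linear models, `latent_dim`, and the random `imgs` batch are just placeholders so the snippet runs on its own:

```python
import torch
import torch.nn as nn

# Placeholder models; the real ones come from the DCGAN module above.
latent_dim = 16
gen = nn.Linear(latent_dim, 32)
disc = nn.Linear(32, 1)
criterion = nn.BCEWithLogitsLoss()
opt_d = torch.optim.Adam(disc.parameters(), lr=2e-4)
opt_g = torch.optim.Adam(gen.parameters(), lr=2e-4)

imgs = torch.randn(8, 32)  # stand-in for a real batch

# Discriminator step: detach so d_loss never touches the generator
z = torch.randn(imgs.shape[0], latent_dim)
generated_imgs = gen(z)
disc_real = disc(imgs)
disc_fake = disc(generated_imgs.detach())
d_loss = (criterion(disc_real, torch.ones_like(disc_real)) +
          criterion(disc_fake, torch.zeros_like(disc_fake))) / 2
opt_d.zero_grad()
d_loss.backward()
opt_d.step()

# Generator step: reusing generated_imgs is fine here, because its graph
# back through gen was recorded with gradients enabled
disc_fake = disc(generated_imgs)
g_loss = criterion(disc_fake, torch.ones_like(disc_fake))
opt_g.zero_grad()
g_loss.backward()
opt_g.step()
```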
So what happened in the first case?