What's the best way to do two discriminator steps?

Dear PyTorch Lightning community,
I am using PyTorch Lightning to train a GAN, so each training_step consists of one generator_step and one discriminator_step (similar to lightning-bolts/basic_gan_module.py at f48357be353b7acdd882379ac3308fbec95dc40d · Lightning-AI/lightning-bolts · GitHub).

I need to perform two discriminator_steps, and therefore have to call backward twice for each batch. What's the best way to do this in PyTorch Lightning?

thank you

You mean for a single batch you want to do:

# Generator
gen_loss.backward()
gen_opt.step()
gen_opt.zero_grad()

# Discriminator
disc_loss.backward()
disc_opt.step()
disc_opt.step()
disc_opt.zero_grad()

or something else?

I would like to do this:

Generator

gen_loss.backward()
gen_opt.step()
gen_opt.zero_grad()

Discriminator

disc_loss.backward()
disc_opt.step()
disc_opt.zero_grad()
disc_loss2.backward()
disc_opt.step()
disc_opt.zero_grad()

See manual optimization. This should allow you to do exactly what you want!
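For reference, here is a minimal sketch of the two-discriminator-step loop, written in plain PyTorch with toy linear models so it is self-contained. Inside a LightningModule you would set `self.automatic_optimization = False`, fetch the optimizers with `self.optimizers()`, and replace each `loss.backward()` with `self.manual_backward(loss)`; the model sizes and batch here are placeholders:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-ins for the GAN's networks (placeholders, not a real architecture).
latent_dim, data_dim = 4, 8
gen = nn.Linear(latent_dim, data_dim)
disc = nn.Linear(data_dim, 1)
gen_opt = torch.optim.SGD(gen.parameters(), lr=0.01)
disc_opt = torch.optim.SGD(disc.parameters(), lr=0.01)
bce = nn.BCEWithLogitsLoss()

real = torch.randn(16, data_dim)      # stand-in for a real batch
z = torch.randn(16, latent_dim)       # latent noise

# --- Generator: one step ---
gen_opt.zero_grad()
gen_loss = bce(disc(gen(z)), torch.ones(16, 1))  # try to fool the discriminator
gen_loss.backward()                   # in Lightning: self.manual_backward(gen_loss)
gen_opt.step()

# --- Discriminator: two steps, each with a FRESH forward pass ---
for _ in range(2):
    disc_opt.zero_grad()
    fake = gen(z).detach()            # don't backprop into the generator here
    disc_loss = (bce(disc(real), torch.ones(16, 1))
                 + bce(disc(fake), torch.zeros(16, 1)))
    disc_loss.backward()              # in Lightning: self.manual_backward(disc_loss)
    disc_opt.step()                   # gradients match the *current* weights
```

Recomputing `disc_loss` inside the loop is what makes the second backward valid: the graph is rebuilt from the weights produced by the first `disc_opt.step()`.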


I think you might get an error here if you use the old weights to calculate disc_loss2: the first disc_opt.step() updates the weights, so disc_loss2.backward() would be computing gradients through a graph built from the old weights while the parameters have already changed.
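To make this concrete, here is a dependency-free sketch with a toy one-parameter quadratic loss standing in for disc_loss (the gradient is written analytically). It contrasts the correct pattern, where the gradient is recomputed from a fresh forward pass after the first step, with the stale pattern that reuses the pre-step gradient:

```python
# Toy stand-in for the discriminator: one parameter w, loss(w) = (w - 3)^2,
# with analytic gradient d loss / d w = 2 * (w - 3).
def grad(w):
    return 2.0 * (w - 3.0)

lr = 0.1

# Correct: recompute the gradient from the *updated* weight before step 2.
w = 0.0
g1 = grad(w)           # backward pass 1 (current weights)
w = w - lr * g1        # disc_opt.step() -- the weight changes here
g2 = grad(w)           # backward pass 2 on the NEW weight (fresh forward)
w = w - lr * g2        # second disc_opt.step()

# Stale: reusing the pre-step gradient just repeats the first step blindly,
# following a direction that no longer matches the current weights.
w_stale = 0.0
g = grad(w_stale)
w_stale = w_stale - lr * g
w_stale = w_stale - lr * g   # stale gradient reused

print(w, w_stale)
```

In autograd the stale variant is not merely inaccurate: calling backward on a graph whose parameters were modified in place by step() can raise a runtime error or silently produce wrong gradients, which is why the second loss must come from a fresh forward pass.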


@goku thank you, my example was indeed wrong, but @teddy's solution works! Thank you all.