Access module's optimizers in a callback

Context:
I have a use case where I need to access the LightningModule's optimizers in a callback. The motivation is to call a function in on_save_checkpoint() that updates the optimizer state dict before it is dumped to a checkpoint. I want to do this in a callback rather than in the LightningModule's own on_save_checkpoint(), since the functionality is tied specifically to the optimizer, which could be shared across a number of LightningModules.

Question:
Is trainer.train_loop.get_optimizers_iterable() the right API to use for this purpose? https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/training_loop.py#L548

In your use case you can use trainer.optimizers. trainer.train_loop.get_optimizers_iterable() is tied to optimizer frequencies: if you have multiple optimizers with frequencies defined, it returns only the optimizer that is active for the current batch, not the full list.
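
For reference, a minimal sketch of what such a callback could look like, assuming a Lightning version whose Callback.on_save_checkpoint hook receives the checkpoint dict (older versions omit that argument). The class name and the _transform_state helper are hypothetical placeholders for your actual update logic:

```python
from pytorch_lightning.callbacks import Callback


class OptimizerStateCheckpointing(Callback):  # hypothetical name
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # trainer.optimizers is the full list of optimizers, regardless
        # of any frequencies configured in configure_optimizers().
        for i, optimizer in enumerate(trainer.optimizers):
            # Lightning stores optimizer state dicts in the checkpoint
            # under the "optimizer_states" key, in optimizer order.
            checkpoint["optimizer_states"][i] = self._transform_state(
                optimizer.state_dict()
            )

    def _transform_state(self, state_dict):
        # Placeholder: apply whatever modification your use case needs.
        return state_dict
```

Since the callback only touches trainer.optimizers and the checkpoint dict, it stays independent of any particular LightningModule, which matches your goal of reusing it across modules.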


TIL about trainer.optimizers. That's even simpler.