Adopting exponential moving average (EMA) for PL pipeline


I'm wondering what the correct way is to implement the EMA step with Lightning.

Example attempt:

```python
from torch_ema import ExponentialMovingAverage

class SomeModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.criterion = SomeLoss()
        self.encoder = encoder()
        self.head = nn.Sequential(...)
        self.ema = ExponentialMovingAverage(self.encoder.parameters(), decay=0.995)

    def forward(self, x):
        logit = self.head(self.encoder(x))
        return logit

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.criterion(self(x), y)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=3e-3)
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=len(train_loader), T_mult=1,
                eta_min=0, last_epoch=-1, verbose=False),
            'interval': 'step',
        }
        return [optimizer], [scheduler]

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx,
                       closure, on_tpu=False, using_native_amp=False,
                       using_lbfgs=False):
        ...
```
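As a side note on the scheduler in the attempt above: within one restart cycle, `CosineAnnealingWarmRestarts` anneals the LR along a cosine from the base LR down to `eta_min`. A plain-Python sketch of that formula (the function name is mine, not PyTorch's):

```python
import math

def cosine_annealing_lr(base_lr, eta_min, t_cur, t_0):
    # LR at step t_cur inside a restart cycle of length t_0:
    # eta_min + (base_lr - eta_min) * (1 + cos(pi * t_cur / t_0)) / 2
    # (a sketch of the formula, not the PyTorch source)
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_0)) / 2

print(cosine_annealing_lr(3e-3, 0.0, 0, 100))    # cycle start: 0.003 (the base LR)
print(cosine_annealing_lr(3e-3, 0.0, 100, 100))  # cycle end: ~0.0 (eta_min)
```

With `T_0=len(train_loader)` and `'interval': 'step'`, one cycle spans exactly one epoch.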

  1. Does overriding optimizer_step interfere with configure_optimizers and the intended schedules?
  2. What is the best way to modify the optimizer step? Does my implementation make sense?

Thank you in advance.

I checked the repo; the EMA update should be done after optimizer.step().

So there are two ways to do this:

```python
def optimizer_step(self, *args, **kwargs):
    super().optimizer_step(*args, **kwargs)
    self.ema.update()
```

Or just override the on_before_zero_grad hook, no need to touch optimizer_step :slight_smile: :

```python
def on_before_zero_grad(self, *args, **kwargs):
    self.ema.update()
```
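For intuition, the running average that torch_ema maintains per parameter tensor is `shadow = decay * shadow + (1 - decay) * param`. A minimal plain-Python sketch of that update rule (the `SimpleEMA` class is illustrative, not torch_ema's implementation):

```python
class SimpleEMA:
    # Minimal sketch of an exponential moving average over a list of
    # scalar "parameters": shadow = decay * shadow + (1 - decay) * param.
    def __init__(self, params, decay=0.995):
        self.decay = decay
        self.shadow = list(params)

    def update(self, params):
        self.shadow = [self.decay * s + (1 - self.decay) * p
                       for s, p in zip(self.shadow, params)]

ema = SimpleEMA([0.0], decay=0.5)
ema.update([1.0])  # shadow: 0.5 * 0.0 + 0.5 * 1.0 = 0.5
ema.update([1.0])  # shadow: 0.5 * 0.5 + 0.5 * 1.0 = 0.75
print(ema.shadow)  # [0.75]
```

With decay=0.995 the shadow weights trail the live weights slowly, which is why the update must run once per optimizer step.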