Stochastic Weight Averaging

I’m trying to implement SWA from this guide:

I’ve broken the example up as follows:

This bit goes in the __init__ for the LightningModule:

# needs: from torch.optim.swa_utils import AveragedModel, SWALR
self.swa_model = AveragedModel(self.net)
self.swa_start = 5

This goes in configure_optimizers:

self.swa_scheduler = SWALR(optimizer, swa_lr=0.05)

And this bit goes in on_train_epoch_end:

if self.trainer.current_epoch > self.swa_start:
    self.swa_model.update_parameters(self.net)
    self.swa_scheduler.step()
    torch.optim.swa_utils.update_bn(self.train_dataloader(), self.swa_model)
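For reference, here is a standalone sketch of the plain-PyTorch SWA recipe these three snippets are adapted from (the tiny model, synthetic data, and hyperparameters are made up). Note that the reference recipe calls update_bn once after training finishes, not at the end of every epoch:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

torch.manual_seed(0)

# Made-up toy model and data, standing in for self.net / train_dataloader
net = torch.nn.Sequential(
    torch.nn.Linear(8, 8), torch.nn.BatchNorm1d(8), torch.nn.Linear(8, 1)
)
loader = DataLoader(
    TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16
)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

swa_model = AveragedModel(net)                # running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05)
swa_start = 5

for epoch in range(10):
    for x, y in loader:
        loss = torch.nn.functional.mse_loss(net(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if epoch > swa_start:                     # start averaging after swa_start
        swa_model.update_parameters(net)
        swa_scheduler.step()

# Recompute BatchNorm running stats for the averaged weights once,
# after training, rather than inside the epoch loop:
update_bn(loader, swa_model)
```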

I’m getting this error when torch.optim.swa_utils.update_bn is called:

RuntimeError: Expected tensor to have CPU Backend, but got tensor with CUDA Backend (while checking arguments for batch_norm_cpu)

I’m guessing I need to define self.swa_model in such a way it gets put onto the correct device

Is there an example somewhere of using SWA with PL? Thanks!

Isn't AveragedModel a subclass of nn.Module?
It should already be moved to the correct device automatically.
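A quick way to check this, as a minimal sketch with a made-up wrapper module standing in for the LightningModule (the sketch uses a dtype conversion, since .to() propagates to registered submodules the same way a device move does):

```python
import torch
from torch.optim.swa_utils import AveragedModel

class Demo(torch.nn.Module):  # hypothetical stand-in for the LightningModule
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(4, 2)
        self.swa_model = AveragedModel(self.net)

m = Demo()
# Because AveragedModel subclasses nn.Module, assigning it to an attribute
# registers it as a submodule, so .to()/.cuda() on the parent module
# (which Lightning calls for you) also converts/moves its parameters:
m.to(torch.float64)
print(next(m.swa_model.parameters()).dtype)  # torch.float64
```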

Thanks for the reply! Yes, you’re right, it is an nn.Module. I printed a parameter from swa_model and it is indeed on the GPU.

The issue was actually that the batches from train_dataloader were not on the GPU. There is a device argument in torch.optim.swa_utils.update_bn which fixed the issue.
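A minimal sketch of that fix (toy model and synthetic data are made up; the sketch runs on CPU, but inside a LightningModule you would pass device=self.device so the batches are moved to wherever the model lives):

```python
import torch
from torch.optim.swa_utils import AveragedModel, update_bn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.BatchNorm1d(4))
swa_model = AveragedModel(net)
loader = DataLoader(TensorDataset(torch.randn(32, 4)), batch_size=8)

# The device argument makes update_bn move each batch onto that device
# before the forward pass, so CPU batches no longer hit a GPU model.
# In a LightningModule you would pass device=self.device.
device = torch.device("cpu")  # would be "cuda" on a GPU machine
update_bn(loader, swa_model, device=device)
```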

Hey @Anjum_Sayed,

Would you like to create a PR for SWA? We can help you out!

Best regards,
T.C

Hi @tchaton, what kind of a PR were you thinking of? I think this is more user error than an issue with PL. If you have any specific ideas, I’d be happy to help out!

Hi, do you mean you need to do update_bn(..., device="cpu")?
It would be good if Lightning could handle the GPU automatically, as that's within its design principles.