Stochastic Weight Averaging

I’m trying to implement SWA from this guide:

I’ve broken the example up as follows:

This bit goes in the __init__ for the LightningModule:

self.swa_model = AveragedModel(
self.swa_start = 5

This goes in configure_optimizers:

self.swa_scheduler = SWALR(optimizer, swa_lr=0.05)

And this bit in training_epoch_end:

if self.trainer.current_epoch > self.swa_start:
    torch.optim.swa_utils.update_bn(self.train_dataloader(), self.swa_model)
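For comparison, here is a minimal sketch of the whole SWA recipe in plain PyTorch, roughly following the pattern in the PyTorch SWA docs. Once SWA starts, you fold the current weights into the average with update_parameters and step SWALR. At the end, you recompute the BatchNorm statistics for the averaged weights with update_bn. The toy model, random data, epoch count and learning rates are hypothetical placeholders.

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy model with a BatchNorm layer, so update_bn has work to do
model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 1))
# Hypothetical random regression data
loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
swa_model = AveragedModel(model)               # running average of model's weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05)
swa_start = 5
loss_fn = nn.MSELoss()

for epoch in range(10):
    model.train()
    for x, y in loader:
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()
    if epoch > swa_start:
        swa_model.update_parameters(model)     # fold current weights into the average
        swa_scheduler.step()
    else:
        scheduler.step()

# Recompute BatchNorm running stats for the averaged weights with one pass over the data;
# update_bn takes batch[0] when the loader yields (x, y) tuples
update_bn(loader, swa_model)
```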

I’m getting this error when torch.optim.swa_utils.update_bn is called:

RuntimeError: Expected tensor to have CPU Backend, but got tensor with CUDA Backend (while checking arguments for batch_norm_cpu)

I’m guessing I need to define self.swa_model in such a way that it gets put onto the correct device.

Is there an example somewhere to use SWA with PL? Thanks!

Isn't AveragedModel a subclass of nn.Module?
It should already be moved to the correct device automatically.
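A quick way to check this: AveragedModel subclasses nn.Module and keeps a deep copy of the wrapped network in its `module` attribute. Assigning it as an attribute of another module therefore registers it as a child, and moving the parent moves the averaged copy too. The `Parent` class below is a hypothetical stand-in for a LightningModule.

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel

class Parent(nn.Module):
    # Hypothetical stand-in for a LightningModule that owns both models
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 2)
        self.swa_model = AveragedModel(self.model)  # registered as a child module

parent = Parent()
print(isinstance(parent.swa_model, nn.Module))           # True
print("swa_model" in dict(parent.named_children()))      # True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
parent.to(device)
# The averaged copy's parameters follow the parent to the new device
print(next(parent.swa_model.parameters()).device)
```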

Thanks for the reply! Yes, you’re right, it is an nn.Module. I printed a parameter from swa_model and it is indeed on the GPU.

The issue was actually that the batches coming from train_dataloader were still on the CPU. Passing the device argument to torch.optim.swa_utils.update_bn fixed it.
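The fix in code: `update_bn(loader, model, device=None)` calls `.to(device)` on each input batch before the forward pass. That way, the CPU tensors from the dataloader end up on the same device as the averaged weights. Below is a sketch in plain PyTorch with a hypothetical toy model and loader. Inside the LightningModule, the equivalent call would presumably be `update_bn(self.train_dataloader(), self.swa_model, device=self.device)`.

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, update_bn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical toy model with BatchNorm, moved to the GPU when one is available
model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU()).to(device)
swa_model = AveragedModel(model)  # deep copy, so it lives on the same device as model

# The DataLoader yields CPU tensors; this mismatch caused the batch_norm_cpu error
loader = DataLoader(TensorDataset(torch.randn(32, 8)), batch_size=8)

# device=... moves every batch to `device` before it is fed to swa_model
update_bn(loader, swa_model, device=device)

bn = swa_model.module[1]
print(bn.running_mean.device)
```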

Hey @Anjum_Sayed,

Would you like to create a PR for SWA? We can help you out!

Best regards,

Hi @tchaton, what kind of PR were you thinking of? I think this is more user error than an issue with PL. If you have any specific ideas, I’d be happy to help out!