The closure passed to the optimizer is None when using fp16

Hi,

I’m trying to use a new optimizer called Sharpness-Aware Minimization (SAM) in Lightning. This optimizer performs two forward/backward passes for every optimizer step.
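
In plain PyTorch, the closure-based usage looks roughly like this (a minimal sketch of the pattern, with plain SGD standing in for the actual SAM class, since any closure-capable optimizer follows the same flow):

```python
import torch
from torch import nn

# Minimal sketch of the closure pattern a SAM-style optimizer relies on.
# Names are illustrative; SGD stands in for the real SAM optimizer here.
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs, targets = torch.randn(4, 10), torch.randn(4, 1)

def closure():
    # Re-runs forward/backward; SAM calls this a second time at the
    # perturbed weights to compute the sharpness-aware gradients.
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    return loss

closure()                # first forward/backward pass fills in the gradients
optimizer.step(closure)  # the optimizer invokes closure() again internally
```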

The SAM implementation I’m using can be found here. I don’t have any issues when running in fp32, but when using fp16 I’ve encountered the following error.

Traceback (most recent call last):
  File "pipe/train.py", line 253, in <module>                                                                                                                                                              
    train_score, valid_score = run(hparams=hparams)                                                                                                                                                        
  File "pipe/train.py", line 190, in run                                                                                                                                                                   
    trainer.fit(model, dm)                                                                                                                                                                                 
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 473, in fit                                                                          
    results = self.accelerator_backend.train()                                                                                                                                                             
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py", line 66, in train                                                            
    results = self.train_or_test()                                                                                                                                                                         
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 69, in train_or_test                                                        
    results = self.trainer.train()                                                                                                                                                                         
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 524, in train                                                                        
    self.train_loop.run_training_epoch()                                                                                                                                                                   
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 572, in run_training_epoch                                                     
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)                                                                                                                               
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 730, in run_training_batch                                                     
    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)                                                                                                                    
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 513, in optimizer_step                                                         
    using_lbfgs=is_lbfgs,                                                                                                                                                                                  
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/core/lightning.py", line 1261, in optimizer_step                                                               
    optimizer.step(closure=optimizer_closure)                                                                                                                                                              
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py", line 286, in step                                                                          
    self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)                                                                                                                   
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py", line 140, in __optimizer_step                                                              
    trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure)                                                                                                                        
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/plugins/native_amp.py", line 82, in optimizer_step                                                             
    trainer.scaler.step(optimizer)                                                                                                                                                                         
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/torch/cuda/amp/grad_scaler.py", line 321, in step                                                                                
    retval = optimizer.step(*args, **kwargs)                                                                                                                                                               
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/torch/optim/lr_scheduler.py", line 67, in wrapper                                                                                
    return wrapped(*args, **kwargs)
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/gianluca/git/kaggle/ranzcr/src/ml/optim.py", line 107, in step
    loss = closure().detach()
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
TypeError: 'NoneType' object is not callable

It seems to complain that closure is None. In your experience, what could cause that? And is there anything I should generally do to ensure that an optimizer that works in fp32 also works in fp16?

Many thanks!

Looks like the problem is in the implementation.

It expects that closure can’t be None, but in Lightning, when using native AMP, the closure is called before optimizer.step(), so step() receives closure=None. It should be handled along these lines:
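
(A rough sketch of the idea, assuming the step(closure) structure of the linked SAM implementation; this is not its exact code.)

```python
import torch

# Sketch only: step() tolerates closure=None, since with native AMP Lightning
# has already run the forward/backward by the time scaler.step(optimizer)
# reaches the optimizer.
@torch.no_grad()
def step(self, closure=None):
    if closure is not None:
        # Usual two-pass SAM update: perturb, re-evaluate, restore.
        self.first_step(zero_grad=True)
        with torch.enable_grad():
            closure()                 # second forward/backward pass
        self.second_step()
    else:
        # No closure: we can't do the second pass, so fall back to a plain
        # update of the base optimizer with the gradients we already have.
        self.base_optimizer.step()
```

Of course, without the closure the second pass can’t happen, so this only avoids the crash; the closure really has to reach step() for SAM to do anything useful.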

In the case of native AMP, the closure is called outside of optimizer.step() in Lightning. That might be the reason for the failure here.

Looks like closures aren’t supported with 16-bit precision training:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.step
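
For reference, the standard native-AMP loop looks roughly like this (a sketch, assuming a CUDA device): the forward/backward, i.e. everything Lightning wraps in the closure, happens before scaler.step(optimizer), so the optimizer’s own step() never receives a closure.

```python
import torch
from torch import nn

# Rough sketch of a standard native-AMP training step (requires CUDA).
model = nn.Linear(10, 1).cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
inputs = torch.randn(4, 10, device="cuda")
targets = torch.randn(4, 1, device="cuda")

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = criterion(model(inputs), targets)  # forward in mixed precision
scaler.scale(loss).backward()                 # backward on the scaled loss
scaler.step(optimizer)                        # calls optimizer.step() with no closure
scaler.update()
```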

I should have updated this thread! You are correct. We also need to ensure the closure is not None here.

Mmm… that’s interesting! I need to run some tests, but I suspect that might be the reason why the optimizer is failing when using precision=16. When using precision=32, the closure is never None.

To add to the confusion, it seems that when passing precision=16 to the Trainer, Lightning uses PyTorch autocast (link) rather than actually setting the model to 16-bit precision.
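
A quick way to see this (a small check I put together, assuming a CUDA device): under autocast the parameters stay in fp32 and only the op outputs come back as fp16.

```python
import torch
from torch import nn

# Small check: autocast does not convert the model's weights to half precision.
model = nn.Linear(10, 10).cuda()
x = torch.randn(4, 10, device="cuda")

with torch.cuda.amp.autocast():
    y = model(x)

print(next(model.parameters()).dtype)  # torch.float32 -- weights stay in fp32
print(y.dtype)                         # torch.float16 -- the matmul ran in fp16
```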