How to use closure functions for optimization

Hi,

I’m trying to use an optimizer that requires a closure for its optimization step, as explained in the documentation. Unfortunately, I’m having some trouble with the suggested implementation.

The optimizer implementation can be found here.

The recommended way to use this optimizer in plain PyTorch is:

import torch
from sam import SAM
...

model = YourModel()
base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)
...

for input, output in data:
  # the closure re-runs a full forward-backward pass; SAM calls it
  # after perturbing the weights in its first step
  def closure():
    loss = loss_function(output, model(input))
    loss.backward()
    return loss

  # the first forward-backward pass populates the gradients that SAM
  # uses to compute the weight perturbation
  loss = loss_function(output, model(input))
  loss.backward()
  optimizer.step(closure)
  optimizer.zero_grad()
...
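If I read the linked implementation correctly, SAM’s step() looks roughly like this (my paraphrase, so details may differ), which seems to be why the gradients have to be populated before step() is called:

@torch.no_grad()
def step(self, closure=None):
    assert closure is not None, "SAM requires a closure"
    closure = torch.enable_grad()(closure)
    self.first_step(zero_grad=True)  # perturb the weights using the grads that already exist
    closure()                        # second forward-backward pass at the perturbed weights
    self.second_step()               # restore the weights and apply the base optimizer update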

This is what my LightningModule looks like:

class ImageClassifier(pl.LightningModule):

    ...

    def forward(self, x):
        x = self.model(torch.as_tensor(data=x))
        return x

    def configure_optimizers(self):
        config = dict()
        optimizer = optimizer_factory(
            params=self.parameters(), hparams=self.hparams
        )
        config["optimizer"] = optimizer

        if True:  # always true here; an LR scheduler is always attached in this snippet
            config["lr_scheduler"] = lr_scheduler_factory(
                optimizer=optimizer,
                hparams=self.hparams,
                data_loader=self.train_dataloader(),
            )
            config["monitor"] = "valid_metric"
        return config

    def optimizer_step(
        self,
        current_epoch,
        batch_nb,
        optimizer,
        optimizer_idx,
        second_order_closure,
        on_tpu=False,
        using_native_amp=False,
        using_lbfgs=False,
    ):
        def second_order_closure(
            pl_module, split_batch, batch_idx, opt_idx, optimizer, hidden
        ):
            # Model training step on a given batch
            result = pl_module.training_step(
                split_batch, batch_idx, opt_idx, hidden
            )

            # Model backward pass
            pl_module.backward(result, optimizer, opt_idx)

            # on_after_backward callback
            pl_module.on_after_backward(
                result.training_step_output, batch_idx, result.loss
            )

            return result

        # update params
        optimizer.step(closure=second_order_closure)

Unfortunately, I get this error message:

Traceback (most recent call last):
  File "pipe/train.py", line 252, in <module>
    train_score, valid_score = run(hparams=hparams)
  File "pipe/train.py", line 189, in run
    trainer.fit(model, dm)
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 473, in fit
    results = self.accelerator_backend.train()
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py", line 66, in train
    results = self.train_or_test()
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 69, in train_or_test
    results = self.trainer.train()
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 524, in train
    self.train_loop.run_training_epoch()
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 572, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 730, in run_training_batch
    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 513, in optimizer_step
    using_lbfgs=is_lbfgs,
  File "/home/gianluca/git/kaggle/ranzcr/src/ml/classification.py", line 91, in optimizer_step
    optimizer.step(closure=second_order_closure)
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py", line 286, in step
    self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py", line 144, in __optimizer_step
    optimizer.step(closure=closure, *args, **kwargs)
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/torch/optim/lr_scheduler.py", line 67, in wrapper
    return wrapped(*args, **kwargs)
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/gianluca/git/kaggle/ranzcr/src/ml/optim.py", line 114, in step
    self.first_step(zero_grad=True)
  File "/home/gianluca/miniconda3/envs/ranzcr/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/gianluca/git/kaggle/ranzcr/src/ml/optim.py", line 78, in first_step
    grad_norm = self._grad_norm()
  File "/home/gianluca/git/kaggle/ranzcr/src/ml/optim.py", line 126, in _grad_norm
    for group in self.param_groups
RuntimeError: stack expects a non-empty TensorList

Have I misunderstood how to use the second-order closure example?
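
Reading the traceback, _grad_norm seems to stack the per-parameter gradient norms, so the error suggests that no parameter had a .grad when first_step() ran. I suspect my locally defined second_order_closure shadows the closure that Lightning passes into the hook, so Lightning’s own forward-backward pass never executes. Is something like the sketch below closer to the intended usage? This is only my guess: I’m assuming the closure Lightning hands to optimizer_step is the train_step_and_backward_closure from the traceback, and I’m not sure whether calling it twice per batch has side effects.

def optimizer_step(
    self,
    current_epoch,
    batch_nb,
    optimizer,
    optimizer_idx,
    optimizer_closure,
    on_tpu=False,
    using_native_amp=False,
    using_lbfgs=False,
):
    # run Lightning's closure once so that every parameter has a .grad
    # before SAM's first_step() computes the gradient norm
    optimizer_closure()
    # SAM then perturbs the weights and invokes the same closure again
    # for the second forward-backward pass
    optimizer.step(closure=optimizer_closure)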

Hello, my apologies for the late reply. We are gradually deprecating this forum in favor of the built-in GitHub version… Could we kindly ask you to recreate your question there: Lightning Discussions