How to switch optimizers during training

Original question from

Is it possible to show how we should write the “configure_optimizers” and “training_step” functions for the following code?
The purpose of the code is to switch the optimizer from LBFGS to Adam once loss_SUM < 0.3.

optimizer = optim.LBFGS(model.parameters(), lr=0.003)
Use_Adam_optim_FirstTime = True
Use_LBFGS_optim = True

for epoch in range(30000):
    loss_SUM = 0
    for i, (x, t) in enumerate(GridLoader):
        x = x.to(device)
        t = t.to(device)
        if Use_LBFGS_optim:
            # LBFGS may re-evaluate the loss several times per step, so it needs a closure
            def closure():
                optimizer.zero_grad()
                lg, lb, li = problem_formulation(x, t, x_Array, t_Array, bndry, pi)
                loss_total = lg + lb + li
                loss_total.backward(retain_graph=True)
                return loss_total
            loss_out = optimizer.step(closure)
            loss_SUM += loss_out.item()
        else:
            if Use_Adam_optim_FirstTime:
                # first Adam step: create Adam and restore the weights saved at the switch
                Use_Adam_optim_FirstTime = False
                optimizerAdam = optim.Adam(model.parameters(), lr=0.0003)
                model.load_state_dict(checkpoint['model'])
            optimizerAdam.zero_grad()
            lg, lb, li = problem_formulation(x, t, x_Array, t_Array, bndry, pi)
            loss_total = lg + lb + li
            loss_total.backward()  # one backward on the sum gives the same gradients as three
            optimizerAdam.step()
            loss_SUM += loss_total.item()
    # switch to Adam once the summed epoch loss drops below 0.3
    if loss_SUM < 0.3 and Use_LBFGS_optim:
        Use_LBFGS_optim = False
        checkpoint = {'model': model.state_dict(),
                      'optimizer': optimizer.state_dict()}

Hi,
in your LightningModule, you could do this:

def on_epoch_start(self):
    # self.loss_SUM is assumed to be tracked by the module itself (e.g. summed in training_step)
    if self.loss_SUM < 0.3:
        self.trainer.optimizers[0] = Adam(...)

and start with LBFGS as the default optimizer, returned from configure_optimizers.
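To see the switching logic in isolation, here is a minimal plain-PyTorch sketch (no Lightning) of the same idea: the training loop always reads the optimizer from a one-element list, and an epoch-start hook replaces that entry with a fresh Adam once the previous epoch's summed loss falls below a threshold. The toy linear-regression data, the `on_epoch_start` function, the `optimizers` list and the threshold value are all hypothetical stand-ins for the PDE problem and for Lightning's `self.trainer.optimizers`, not Lightning's actual API.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical toy regression problem standing in for the PDE loss.
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 3 * x + 0.5
batches = [(x[:32], y[:32]), (x[32:], y[32:])]
model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()

THRESHOLD = 1e-3  # hypothetical switch threshold (0.3 in the question)
# The loop always reads optimizers[0], mirroring self.trainer.optimizers[0].
optimizers = [optim.LBFGS(model.parameters(), lr=1, line_search_fn="strong_wolfe")]


def on_epoch_start(loss_sum):
    # Replace LBFGS with a fresh Adam once the previous epoch's summed loss is small enough.
    if loss_sum is not None and loss_sum < THRESHOLD and isinstance(optimizers[0], optim.LBFGS):
        optimizers[0] = optim.Adam(model.parameters(), lr=3e-4)


history = []  # (epoch, optimizer name, summed loss)
loss_sum = None
for epoch in range(5):
    on_epoch_start(loss_sum)
    opt = optimizers[0]
    loss_sum = 0.0
    for xb, yb in batches:
        if isinstance(opt, optim.LBFGS):
            def closure():
                opt.zero_grad()
                loss = loss_fn(model(xb), yb)
                loss.backward()
                return loss
            loss = opt.step(closure)  # loss from the first closure call of this step
        else:
            opt.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            opt.step()
        loss_sum += loss.item()
    history.append((epoch, type(opt).__name__, loss_sum))

for row in history:
    print(row)
```

Keeping the optimizer behind a single reference that the loop re-reads every epoch is what makes the in-place swap visible to the training code; in this toy run LBFGS handles the first epochs and Adam takes over after the summed loss drops below the threshold.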

Original answer from @goku

I think this logic is now better placed in configure_optimizers itself, in case someone also has complex schedulers or a schedulers dict, and then calling:

def on_epoch_start(self):
    if condition:
        self.trainer.accelerated_backend.setup_optimizers(self)

def configure_optimizers(self):
    if condition:
        return Adam(...)
    else:
        return LBFGS(...)

I am getting the following error:
self.trainer.accelerated_backend.setup_optimizers(self)
AttributeError: 'Trainer' object has no attribute 'accelerated_backend'

My bad, it should be accelerator_backend.

I am using the following script to switch between optimizers, but I get this error:
self.trainer.accelerator_backend.setup_optimizers(self)
AttributeError: 'CPUBackend' object has no attribute 'setup_optimizers'

def configure_optimizers(self):
    if self.current_epoch % 2 == 0:
        return optim.Adam(self.parameters(), lr=self.hparams.lr)
    else:
        return optim.SGD(self.parameters(), lr=self.hparams.lr)

def on_epoch_start(self):
    self.trainer.accelerator_backend.setup_optimizers(self)
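The `current_epoch % 2` pattern above rebuilds the optimizer at every epoch start. A plain-PyTorch sketch of that behaviour (the `make_optimizer` helper is hypothetical, standing in for `configure_optimizers`, and the random data is a toy) shows one consequence worth keeping in mind: every rebuild creates a brand-new optimizer object with empty per-parameter state, so Adam's running moment estimates start from scratch each time it is recreated.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(4, 1)
x, y = torch.randn(16, 4), torch.randn(16, 1)
lr = 1e-3


def make_optimizer(current_epoch):
    # Hypothetical stand-in for configure_optimizers(): Adam on even epochs, SGD on odd ones.
    if current_epoch % 2 == 0:
        return optim.Adam(model.parameters(), lr=lr)
    return optim.SGD(model.parameters(), lr=lr)


used = []
state_sizes_at_start = []
for epoch in range(4):
    opt = make_optimizer(epoch)                   # rebuilt at every "epoch start"
    state_sizes_at_start.append(len(opt.state))   # a fresh optimizer holds no per-parameter state
    for _ in range(3):
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()
    used.append(type(opt).__name__)

print(used)
print(state_sizes_at_start)
```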

Try the latest release.

@Peyman_Poozesh, did it work?

Yes, after updating pytorch-lightning. Thanks!

Do you know how to use the LBFGS optimizer in pytorch-lightning (issue #3672)? I am trying to switch from Adam to LBFGS using:

def on_epoch_start(self):
    if condition:
        self.trainer.accelerator_backend.setup_optimizers(self)

def configure_optimizers(self):
    if condition:
        return Adam(...)
    else:
        return LBFGS(...)

Are you getting an error or anything else when you use it?

With the script below, the LBFGS optimizer does not seem to update the weights; in fact, the loss does not change. I have no problem using the Adam and SGD optimizers in PyTorch Lightning, but I do not know how to use LBFGS.

def configure_optimizers(self):
    optimizer = optim.LBFGS(self.parameters(), lr=0.01)
    return optimizer

def training_step(self, train_batch, batch_idx):
    x, t = train_batch
    lg, lb, li = self.problem_formulation(x, t, self.bndry)
    loss = lg + lb + li
    return {'loss': loss}

def backward(self, trainer, loss, optimizer, optimizer_idx):
    loss.backward(retain_graph=True)

def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure, 
                 on_tpu=False, using_native_amp=False, using_lbfgs=True):
    optimizer.step(second_order_closure)
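LBFGS needs `second_order_closure` here (and the plain `closure` in the original question) because `torch.optim.LBFGS.step` may evaluate the loss many times in a single call: up to `max_iter` inner iterations (20 by default), plus extra evaluations when a line search is enabled. As far as we understand, Lightning's `second_order_closure` wraps `training_step` plus the backward pass for this purpose. The plain-PyTorch sketch below, on a hypothetical least-squares toy objective, counts closure calls to make this visible, and also shows that `step` returns the loss from the first closure evaluation, i.e. before the update.

```python
import torch
import torch.optim as optim

torch.manual_seed(0)
# Hypothetical toy objective: f(w) = ||A w - b||^2 for fixed random A and b.
A = torch.randn(10, 3)
b = torch.randn(10)
w = torch.zeros(3, requires_grad=True)

opt = optim.LBFGS([w], lr=1, max_iter=20, line_search_fn="strong_wolfe")
n_calls = {"closure": 0}


def closure():
    # LBFGS may call this many times inside a single opt.step(closure).
    n_calls["closure"] += 1
    opt.zero_grad()
    loss = ((A @ w - b) ** 2).sum()
    loss.backward()
    return loss


with torch.no_grad():
    initial = ((A @ w - b) ** 2).sum().item()
returned = opt.step(closure).item()  # loss of the FIRST closure call, i.e. before the update
with torch.no_grad():
    final = ((A @ w - b) ** 2).sum().item()

print(f"closure calls: {n_calls['closure']}")
print(f"initial {initial:.4f}, returned {returned:.4f}, final {final:.6f}")
```

Because the returned value is the pre-update loss, logging it (as `loss_out.item()` does in the question) lags one step behind the actual state of the weights.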

I tried your exact code with a basic MNIST example and the weights are updating.

Can you check whether this function is correct? Then again, as you said, it works with Adam, so I am not sure what might be wrong here.

This is the code; it runs with Adam but not with LBFGS. We are trying to solve a PDE using deep learning. The problem_formulation part of the code contains the loss functions.

from IPython.display import clear_output
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torch.autograd import grad
import pytorch_lightning as pl
from argparse import Namespace


# u_t=u_xx
# u(t,0)=0
# u(t,1)=1
# u(0,x)=(2*x)/(1+x^2)

class MyData(Dataset):
    def __init__(self, startX, stopX, startT, stopT, NumberOfSteps):
        super(MyData, self).__init__()
        x = torch.linspace(startX, stopX, NumberOfSteps, dtype=torch.float32, requires_grad=True)
        t = torch.linspace(startT, stopT, NumberOfSteps, dtype=torch.float32, requires_grad=True)
        gx, gt = torch.meshgrid(x, t)
        self.x = gx.contiguous().view(-1, 1)
        self.t = gt.contiguous().view(-1, 1)

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, index):
        return self.x[index], self.t[index]


class MyModel(pl.LightningModule):
    def __init__(self, hparams):
        super(MyModel, self).__init__()
        self.hparams = hparams
        self.Grid = None
        self.loss_fn = nn.MSELoss()
        self.bndry = torch.tensor([0., 1.], dtype=torch.float32, requires_grad=True).view(-1, 1).cuda()
        self.fc1 = nn.Linear(2, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 20)
        self.fc4 = nn.Linear(20, 20)
        self.fc5 = nn.Linear(20, 20)
        self.fc6 = nn.Linear(20, 20)
        self.fc7 = nn.Linear(20, 1)
        self.T = nn.Tanh()

    def forward(self, X, T):
        x = torch.cat((X, T), dim=1)
        x = self.T(self.fc1(x))
        x = self.T(self.fc2(x))
        x = self.T(self.fc3(x))
        x = self.T(self.fc4(x))
        x = self.T(self.fc5(x))
        x = self.T(self.fc6(x))
        x = self.fc7(x)
        return x

    def prepare_data(self):
        self.Grid = MyData(0, 1, 0, 2, 200)

    def train_dataloader(self):
        return DataLoader(dataset=self.Grid, batch_size=self.hparams.batch_size, shuffle=True)


    def configure_optimizers(self):
        optimizer = optim.LBFGS(self.parameters(), lr=0.01)
        return optimizer

    
    def training_step(self, train_batch, batch_idx):
        x, t = train_batch
        lg, lb, li = self.problem_formulation(x, t, self.bndry)
        loss = lg + lb + li
        return {'loss': loss}

    def backward(self, trainer, loss, optimizer, optimizer_idx):
        loss.backward(retain_graph=True)


    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure, on_tpu=False, using_native_amp=False, using_lbfgs=True):
        # update params
        optimizer.step(second_order_closure)


    def training_epoch_end(self, outputs):
        clear_output(wait=True)
        sum_total_loss = torch.stack([x['loss'] for x in outputs]).sum()
        print('Epoch={}, total_loss={:.3f}'.format(self.current_epoch, sum_total_loss.item()))
        # return {'sum_total_loss': sum_total_loss}

    def on_train_epoch_end(self):
        fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))
        self.plot_pcolor(0, 1, 0, 2, 100, self.device, fig, ax)
        self.Plot_InitialCondition(0, 1, 100, self.device, ax)

    def Plot_InitialCondition(self, xStart, xStop, NUM, device, ax):
        x = torch.linspace(xStart, xStop, NUM, dtype=torch.float32, requires_grad=True, device=self.device).view(-1, 1)
        t = torch.zeros_like(x)
        out = self.forward(x, t)
        x = x.detach().cpu()
        out = out.detach().cpu()
        ax[1].plot(x, out)
        plt.ylim(0, 1.2)
        ax[1].axvline(x=1, color='r')
        ax[1].axhline(y=1, color='r')
        plt.show()

    def plot_pcolor(self, xStart, xStop, tStart, tStop, NUM, device, fig, ax):
        x = torch.linspace(xStart, xStop, NUM, dtype=torch.float32, requires_grad=True, device=self.device)
        t = torch.linspace(tStart, tStop, NUM, dtype=torch.float32, requires_grad=True, device=self.device)
        xg, tg = torch.meshgrid(x, t)
        X = xg.clone()
        T = tg.clone()
        X = X.view(-1, 1)
        T = T.view(-1, 1)
        z = self.forward(X, T)
        z = z.view(NUM, -1)
        xg = xg.detach().cpu()
        tg = tg.detach().cpu()
        z = z.detach().cpu()
        j = ax[0].pcolormesh(xg, tg, z, cmap='jet')
        fig.colorbar(j, ax=ax[0])

    def problem_formulation(self, x, t, bndry):
        u = self.forward(x, t)
        u_x = grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True, allow_unused=True)[0]
        u_xx = grad(u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True, allow_unused=True)[0]
        u_t = grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True, allow_unused=True)[0]
        lossGrid = self.loss_fn(u_xx, u_t)

        bndryExpand = torch.ones_like(t) * bndry[0]
        u0 = self.forward(bndryExpand, t)
        lossB0 = self.loss_fn(u0, torch.zeros_like(u0, dtype=torch.float32))
        bndryExpand = torch.ones_like(t) * bndry[1]
        u1 = self.forward(bndryExpand, t)
        lossB1 = self.loss_fn(u1, torch.ones_like(u1, dtype=torch.float32))

        bndryExpand = torch.ones_like(x) * bndry[0]
        uInit = self.forward(x, bndryExpand)
        lossIni = self.loss_fn(uInit, (2 * x) / (1 + (x ** 2)))

        return lossGrid, lossB0 + lossB1, lossIni


def main(hparams):
    model = MyModel(hparams)
    # tb_logger = pl.loggers.TensorBoardLogger(hparams.TensorBoard_path,name=hparams.TensorBoard_FileName)
    trainer = pl.Trainer(fast_dev_run=hparams.fast_dev_run,
                         max_epochs=hparams.max_epochs,
                         gpus=hparams.gpus,
                         accumulate_grad_batches=hparams.accumulate_grad_batches,
                         # limit_train_batches=hparams.limit_train_batches,
                         progress_bar_refresh_rate=hparams.progress_bar_refresh_rate,
                         #                     logger=tb_logger,
                         #                     logger=False,
                         checkpoint_callback=hparams.checkpoint_callback)
    trainer.fit(model)


if __name__ == '__main__':
    args = {'root': '/content/drive/Shared drives/Poozesh/facades/train',
            'fast_dev_run': False,
            'max_epochs': 50,
            'gpus': 1,
            # 'limit_train_batches':4,
            'accumulate_grad_batches': 5,
            'progress_bar_refresh_rate': 0,
            'checkpoint_callback': False,
            'lr': 0.001,
            'batch_size': 1024,
            'num_workers': 1,
            }
    hparams = Namespace(**args)
    main(hparams)
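Since `problem_formulation` relies on `torch.autograd.grad` with `create_graph=True` to build u_x, u_xx and u_t for the heat equation u_t = u_xx, one quick check that is independent of the optimizer is to run the same calls on a function whose derivatives are known. The sketch below uses the hypothetical test function u(x, t) = x³·t, for which u_xx = 6·x·t and u_t = x³; it is a sanity check of the derivative pattern, not part of the original script.

```python
import torch
from torch.autograd import grad

# Leaf tensors, shaped (N, 1) like the batches in the script above.
x = torch.linspace(0, 1, 5).view(-1, 1).requires_grad_()
t = torch.linspace(0, 2, 5).view(-1, 1).requires_grad_()

u = x ** 3 * t  # hypothetical test function with known derivatives

# Same pattern as problem_formulation: keep the graph so u_x can be differentiated again.
u_x = grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
u_xx = grad(u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True)[0]
u_t = grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)[0]

print(torch.allclose(u_xx, 6 * x * t), torch.allclose(u_t, x ** 3))
```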

I tried your code with learning_rate=1e-5, and it seems to update both the loss and the model weights.

I tried the following code on the MNIST dataset. The loss updates for a couple of steps and then stops changing. Could you check this code and, if possible, share the code you tried on MNIST with the LBFGS optimizer?

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms,datasets
from torch.utils.data import DataLoader,random_split
import pytorch_lightning as pl 
from IPython.display import clear_output

class LightningMNISTClassifier(pl.LightningModule):
  def __init__(self):
    super(LightningMNISTClassifier,self).__init__()
    self.layer_1 = nn.Linear(28 * 28, 128)
    self.layer_2 = nn.Linear(128, 256)
    self.layer_3 = nn.Linear(256, 10)
    
  def forward(self, x):
    batch_size, channels, width, height = x.size()
    x=x.view(batch_size,-1)
    # layer 1
    x = self.layer_1(x)
    x = torch.relu(x)
    # layer 2
    x = self.layer_2(x)
    x = torch.relu(x) 
    # layer 3
    x = self.layer_3(x)
    # probability distribution over labels
    x = torch.log_softmax(x, dim=1)  
    return x 
  def prepare_data(self):
    transform=transforms.Compose([transforms.ToTensor(), 
                                  transforms.Normalize((0.1307,), (0.3081,))])
    # prepare transforms standard to MNIST
    mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
    mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)  
    self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])

  def train_dataloader(self):
    return DataLoader(self.mnist_train,batch_size=1024)
 
  # def val_dataloader(self):
  #   return DataLoader(self.mnist_val,batch_size=1024)
  # def test_dataloader(self):
  #   return DataLoader(self.mnist_test,batch_size=1024)


  def configure_optimizers(self):
    # optimizer=optim.Adam(self.parameters(),lr=1e-3)
    optimizer = optim.LBFGS(self.parameters(), lr=1e-4)
    return optimizer

  # def backward(self, trainer, loss, optimizer):
  #   loss.backward(retain_graph=True)


  def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx,
                     second_order_closure, on_tpu=False, using_native_amp=False,
                     using_lbfgs=False):
    # update params
    optimizer.step(second_order_closure)

  def cross_entropy_loss(self,logits,labels):
    return F.nll_loss(logits,labels)

  def training_step(self,train_batch,batch_idx):
    x,y=train_batch
    logits=self.forward(x)
    loss=self.cross_entropy_loss(logits,y)
    return  {'loss':loss}

  def training_epoch_end(self,outputs):
    avg_loss=torch.stack([x['loss'] for x in outputs]).mean()
    print('epoch={}, avg_Train_loss={:.2f}'.format(self.current_epoch,avg_loss.item()))
    # return {'avg_train_loss':avg_loss}

  # def validation_step(self,val_batch,batch_idx):
  #   x,y=val_batch
  #   logits=self.forward(x)
  #   loss=self.cross_entropy_loss(logits,y)
  #   return {'val_loss':loss}
  # def validation_epoch_end(self,outputs):
  #   avg_loss=torch.stack([x['val_loss'] for x in outputs]).mean()
  #   print('epoch={}, avg_Test_loss={:.2f}'.format(self.current_epoch,avg_loss.item()))
  #   return {'avg_val_loss':avg_loss}

model=LightningMNISTClassifier()
#from pytorch_lightning.callbacks import EarlyStopping
trainer=pl.Trainer(max_epochs=400,gpus=1,
                  #  check_val_every_n_epoch=2,
                  #  accumulate_grad_batches=5,
#                   early_stop_callback=early_stop,
                  #  limit_train_batches=50,
#                   val_check_interval=0.25,
                   progress_bar_refresh_rate=0,
#                   num_sanity_val_steps=0,
                   weights_summary=None)
clear_output(wait=True)
trainer.fit(model)

We tried a range of learning-rate values; however, neither model (MNIST or PDE) converges, because the weights are not updating correctly. If there were an example using the LBFGS optimizer, we could debug our code more easily.
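As a reference point, here is a minimal native-PyTorch LBFGS example on a hypothetical toy problem (fitting sin(πx) with a small tanh MLP, similar in shape to the PDE network). One deliberate difference from both Lightning scripts above is that every step uses the whole dataset (full batch): LBFGS builds a curvature estimate from successive gradients, which generally assumes the objective does not change between steps, whereas shuffled mini-batches change it every step. That is a general property of quasi-Newton methods, not something verified for these particular scripts; `line_search_fn="strong_wolfe"` is likewise a choice made for this sketch.

```python
import math

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
# Hypothetical full-batch toy problem: fit y = sin(pi * x) on [-1, 1].
x = torch.linspace(-1, 1, 200).unsqueeze(1)
y = torch.sin(math.pi * x)

model = nn.Sequential(nn.Linear(1, 20), nn.Tanh(),
                      nn.Linear(20, 20), nn.Tanh(),
                      nn.Linear(20, 1))
opt = optim.LBFGS(model.parameters(), lr=1, max_iter=20, line_search_fn="strong_wolfe")
loss_fn = nn.MSELoss()


def closure():
    opt.zero_grad()
    loss = loss_fn(model(x), y)  # always the full dataset
    loss.backward()
    return loss


losses = []
for step in range(30):
    losses.append(opt.step(closure).item())  # loss at the start of each outer step

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.6f}")
```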

The way you are using LBFGS is correct; I don't think there is a problem there. Also, in the MNIST case I only checked whether the weights were updating, although it isn't converging as well as with Adam. Maybe you can try it with native PyTorch code and see whether it converges there. If it does, there is probably a bug here that needs to be fixed.
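One way to run the "are the weights updating?" check mentioned above is to snapshot the parameters before a step and compare them afterwards. The `params_changed` helper below is a hypothetical plain-PyTorch utility (not part of Lightning); the usage example applies it to a single LBFGS step with the small learning rate of 1e-5 tried earlier in the thread, on toy random data.

```python
import torch
import torch.nn as nn
import torch.optim as optim


def params_changed(model, step_fn):
    """Run step_fn() once and report, per parameter name, whether any value changed (hypothetical helper)."""
    before = {name: p.detach().clone() for name, p in model.named_parameters()}
    step_fn()
    return {name: not torch.equal(before[name], p.detach())
            for name, p in model.named_parameters()}


# Usage on toy data: one LBFGS step with a deliberately small learning rate.
torch.manual_seed(0)
model = nn.Linear(3, 1)
x, y = torch.randn(8, 3), torch.randn(8, 1)
opt = optim.LBFGS(model.parameters(), lr=1e-5)


def closure():
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss


changed = params_changed(model, lambda: opt.step(closure))
print(changed)
```

An exact comparison (`torch.equal`) is used on purpose: with very small learning rates the updates can be tiny, and a tolerance-based check could hide them.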

Trying LBFGS with native PyTorch code, it works fine: the weights are updating and the model converges. As you mentioned, there might be a bug that needs to be fixed.
See issue #4083: "using LBFGS optimizer in pytorch lightening the model is not converging as compared to native pytoch + LBFGS"