How to tune hyperparameters using the Ray Tune library with PyTorch Lightning

I want to tune my hyperparameters using Ray Tune.

I am solving a multi-label classification problem with a BERT model.

What changes do I need to make to my code so that I can tune parameters such as the batch size and the learning rate? I have added a rough sketch of the kind of integration I am aiming for at the end of the post.

Here is my code, step by step.

class SRDataset(Dataset):

  def __init__(
    self, 
    data: pd.DataFrame,
    labels: pd.DataFrame, 
    tokenizer: BertTokenizer, 
    max_token_len: int = 512
  ):
    self.tokenizer = tokenizer
    self.data = data
    self.labels = labels
    self.max_token_len = max_token_len
    
  def __len__(self):
    return len(self.data)

  def __getitem__(self, index: int):
    # .iloc gives positional access, which keeps working after train_test_split
    # has shuffled the pandas index.
    text_data = self.data.iloc[index]
    labels_text = self.labels.iloc[index]

    encoding = self.tokenizer.encode_plus(
      text_data,
      add_special_tokens=True,
      max_length=self.max_token_len,
      return_token_type_ids=False,
      padding="max_length",
      truncation=True,
      return_attention_mask=True,
      return_tensors='pt',
    )

    return dict(
      text_data=text_data,
      input_ids=encoding["input_ids"].flatten(),
      attention_mask=encoding["attention_mask"].flatten(),
      labels=torch.FloatTensor(labels_text.values)  # .values: pandas row -> ndarray
    )
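
To make sure the dataset behaves the way I expect, I sanity-check a single item like this (a quick check I wrote for this post; it assumes tokenizer, X_train and y_train are already loaded):

sample_dataset = SRDataset(X_train, y_train, tokenizer, max_token_len=512)
sample_item = sample_dataset[0]
print(sample_item["input_ids"].shape)       # torch.Size([512]) after max_length padding
print(sample_item["attention_mask"].shape)  # torch.Size([512])
print(sample_item["labels"].shape)          # one entry per label column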

class SRDataModule(pl.LightningDataModule):

    def __init__(self, X_train, y_train, X_test, y_test, tokenizer, batch_size=10, max_token_len=512):
        super().__init__()
        self.batch_size = batch_size
        self.train_df = X_train
        self.test_df = X_test
        self.train_lab = y_train
        self.test_lab = y_test
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len

    def setup(self, stage=None):
        self.train_dataset = SRDataset(
          self.train_df,
          self.train_lab,
          self.tokenizer,
          self.max_token_len
        )

        self.test_dataset = SRDataset(
          self.test_df,
          self.test_lab,
          self.tokenizer,
          self.max_token_len
        )

    def train_dataloader(self):
        return DataLoader(
          self.train_dataset,
          batch_size=self.batch_size,
          shuffle=True,
          num_workers=50
        )

    def val_dataloader(self):
        return DataLoader(
          self.test_dataset,
          batch_size=self.batch_size,
          num_workers=50
        )

    def test_dataloader(self):
        return DataLoader(
          self.test_dataset,
          batch_size=self.batch_size,
          num_workers=50
        )
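
As I understand it, the batch_size argument here is exactly the knob Ray Tune should control, so inside a trial I would expect to build the DataModule from the sampled config, along these lines (hypothetical config with hand-picked values, just to show what I am after):

config = {"batch_size": 16, "lr": 2e-5}

trial_data_module = SRDataModule(
  X_train,
  y_train,
  X_test,
  y_test,
  tokenizer,
  batch_size=config["batch_size"],
  max_token_len=MAX_TOKEN_COUNT
)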

Model

N_EPOCHS = 10
BATCH_SIZE = 10

data_module = SRDataModule(
  X_train,
  y_train,
  X_test,
  y_test,
  tokenizer,
  batch_size=BATCH_SIZE,
  max_token_len=MAX_TOKEN_COUNT
)
class SRTagger(pl.LightningModule):

  def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None, lr: float = 2e-5):
    super().__init__()
    self.bert = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
    self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
    self.n_training_steps = n_training_steps
    self.n_warmup_steps = n_warmup_steps
    self.criterion = nn.BCELoss()
    self.lr=lr

  def forward(self, input_ids, attention_mask, labels=None):
    output = self.bert(input_ids, attention_mask=attention_mask)
    output = self.classifier(output.pooler_output)
    output = torch.sigmoid(output)    
    loss = 0
    if labels is not None:
        loss = self.criterion(output, labels)
    return loss, output

  def training_step(self, batch, batch_idx):
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("train_loss", loss, prog_bar=True, logger=True)
    return {"loss": loss, "predictions": outputs, "labels": labels}

  def validation_step(self, batch, batch_idx):
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("val_loss", loss, prog_bar=True, logger=True)
    return loss

  def test_step(self, batch, batch_idx):
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("test_loss", loss, prog_bar=True, logger=True)
    return loss

  def training_epoch_end(self, outputs):
    
    labels = []
    predictions = []
    for output in outputs:
      for out_labels in output["labels"].detach().cpu():
        labels.append(out_labels)
      for out_predictions in output["predictions"].detach().cpu():
        predictions.append(out_predictions)

    labels = torch.stack(labels).int()
    predictions = torch.stack(predictions)

    for i, name in enumerate(LABEL_COLUMNS):
      class_roc_auc = auroc(predictions[:, i], labels[:, i])
      self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)


  def configure_optimizers(self):

    # Use the learning rate passed to the constructor so it can be tuned.
    optimizer = optim.RAdam(self.parameters(), lr=self.lr)

    scheduler = get_linear_schedule_with_warmup(
      optimizer,
      num_warmup_steps=self.n_warmup_steps,
      num_training_steps=self.n_training_steps
    )

    return dict(
      optimizer=optimizer,
      lr_scheduler=dict(
        scheduler=scheduler,
        interval='step'
      )
    )

Running the model

steps_per_epoch = len(X_train) // BATCH_SIZE
total_training_steps = steps_per_epoch * N_EPOCHS

# We'll use a fifth of the training steps for a warm-up:
warmup_steps = total_training_steps // 5
warmup_steps, total_training_steps
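
One thing I noticed while writing this up: the warm-up and total step counts above depend on the batch size, so if the batch size is being tuned they presumably have to be recomputed inside every trial. A small helper I am considering (my own, not from Ray):

def scheduler_steps(n_train_examples, batch_size, n_epochs):
  # Recompute the linear warm-up scheduler's step counts for a given batch size.
  steps_per_epoch = n_train_examples // batch_size
  total_training_steps = steps_per_epoch * n_epochs
  warmup_steps = total_training_steps // 5
  return warmup_steps, total_training_steps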


# We can now create an instance of our model:
model = SRTagger(
  n_classes=100,
  n_warmup_steps=warmup_steps,
  n_training_steps=total_training_steps,
  lr=2e-5
)

checkpoint_callback = ModelCheckpoint(
  dirpath="checkpoints_sample_exp",
  filename="best-checkpoint_exp",
  save_top_k=1,
  verbose=True,
  monitor="val_loss",
  mode="min"
)

logger = TensorBoardLogger("lightning_logs", name="SR")
early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)
trainer = pl.Trainer(
  logger=logger,
  callbacks=[early_stopping_callback, checkpoint_callback],
  max_epochs=N_EPOCHS,
  gpus=1,
  progress_bar_refresh_rate=200,
  amp_level='O3'
)

trainer.fit(model, data_module)
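
Finally, here is roughly the wiring I think I need, pieced together from Ray Tune's PyTorch Lightning examples. This is only a minimal sketch: it assumes a Ray version that provides TuneReportCallback under ray.tune.integration.pytorch_lightning together with tune.with_parameters, the search space and resources_per_trial values are placeholders I made up, and train_sr_tune is my own function name. Is this the right way to get the batch size and learning rate tuned, or do I need further changes in the classes above?

from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback


def train_sr_tune(config, num_epochs=N_EPOCHS, num_gpus=1):
  # The step counts depend on the sampled batch size, so recompute them per trial.
  steps_per_epoch = len(X_train) // config["batch_size"]
  total_training_steps = steps_per_epoch * num_epochs
  warmup_steps = total_training_steps // 5

  data_module = SRDataModule(
    X_train, y_train, X_test, y_test,
    tokenizer,
    batch_size=config["batch_size"],
    max_token_len=MAX_TOKEN_COUNT
  )

  model = SRTagger(
    n_classes=100,
    n_warmup_steps=warmup_steps,
    n_training_steps=total_training_steps,
    lr=config["lr"]
  )

  trainer = pl.Trainer(
    max_epochs=num_epochs,
    gpus=num_gpus,
    logger=TensorBoardLogger("lightning_logs", name="SR_tune"),
    # Reports the logged val_loss back to Tune after every validation run.
    callbacks=[TuneReportCallback({"val_loss": "val_loss"}, on="validation_end")],
    progress_bar_refresh_rate=0
  )
  trainer.fit(model, data_module)


config = {
  "lr": tune.loguniform(1e-5, 1e-4),      # placeholder search range
  "batch_size": tune.choice([8, 16, 32])  # placeholder choices
}

analysis = tune.run(
  tune.with_parameters(train_sr_tune, num_epochs=N_EPOCHS, num_gpus=1),
  resources_per_trial={"cpu": 4, "gpu": 1},
  config=config,
  metric="val_loss",
  mode="min",
  num_samples=10,
  name="tune_sr_bert"
)

print("Best hyperparameters found:", analysis.best_config)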