BERT model throws error when used in PyTorch Lightning

Hi,

I’m working on a text classification problem using Hugging Face’s transformers library. It’s an author prediction problem, and I’m using Hugging Face’s datasets library to get my data ready for processing. After creating my dataloader, I can pull a single batch out and run it through the model to get the loss. However, when I do the same thing with a Lightning module and call trainer.fit, I get an error.

Here is the code snippet to run without Lightning:

# dss, tokenizer, decode_ex and label2auth come from earlier cells of the notebook.
from argparse import Namespace

import numpy as np
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification

args = Namespace(
  model_name='bert-base-uncased',
  max_seq_len=128,
  batch_size=2,
  n_labels=10,
  learning_rate=1e-5,
)
train_ds, test_ds = dss['train'], dss['test']
train_ds.set_format(type='pt', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
train_dl = DataLoader(train_ds, batch_size=args.batch_size)
itr = iter(train_dl)

idx = np.random.randint(len(train_ds))
input_ids = train_ds[idx]['input_ids']
labels = train_ds[idx]['labels'].item()
print(decode_ex(tokenizer, label2auth, input_ids, labels))

Decoded Text:

friend. what if the lad needed him? or what if he didn ’ t? making up his mind, zayn was heading to the door. before he could make it, a certain curly - haired lad appeared. “ harry! ” zayn cried out in relief but it was short - lived because he noticed the slight sway in the way that the lanky lad stood. oh, zayn was going to murder that tomlinson kid after he was done with harry. this could be his take on a modern romeo and juliet. drunk harry was a giggly and naive harry. anyone could have taken advantage of him. za

Author: author5

model = AutoModelForSequenceClassification.from_pretrained(args.model_name, num_labels=args.n_labels)
batch = next(itr)
model(**batch)
(tensor(2.2697, grad_fn=<NllLossBackward>), tensor([[-0.2113,  0.1161, -0.4819, -0.2238,  0.1152,  0.1523, -0.1757,  0.7545,
         -0.2766, -0.2114],
        [-0.1257, -0.0635, -0.5627, -0.1476,  0.0796,  0.2059,  0.0010,  0.8079,
         -0.5680, -0.3751]], grad_fn=<AddmmBackward>))

Here is my Lightning module:

import pytorch_lightning as pl
from torch import optim
from transformers import AutoModelForSequenceClassification

class PANAuthorAttrib(pl.LightningModule):
  def __init__(self, hparams):
    super().__init__()
    self.hparams = hparams
    self.model = AutoModelForSequenceClassification.from_pretrained(self.hparams.model_name, num_labels=self.hparams.n_labels)
    
  def forward(self, input_ids, attention_mask, token_type_ids, labels):
    return self.model(input_ids, attention_mask, token_type_ids, labels)
    
  def training_step(self, batch, batch_idx):    
    outputs = self(**batch)
    loss = outputs[0]
    return loss
  
  def configure_optimizers(self):
    return optim.AdamW(params=self.parameters(), lr=self.hparams.learning_rate)

model = PANAuthorAttrib(args)
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(model, train_dl)

This gives the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-15-e2c2f2d0bf62> in <module>
----> 1 trainer.fit(model, train_dl)

/net/vaosl01/opt/NFS/su0/miniconda3/envs/auth/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    438         self.call_hook('on_fit_start')
    439 
--> 440         results = self.accelerator_backend.train()
    441         self.accelerator_backend.teardown()
    442 

/net/vaosl01/opt/NFS/su0/miniconda3/envs/auth/lib/python3.7/site-packages/pytorch_lightning/accelerators/cpu_accelerator.py in train(self)
     46 
     47         # train or test
---> 48         results = self.train_or_test()
     49         return results
     50 

/net/vaosl01/opt/NFS/su0/miniconda3/envs/auth/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in train_or_test(self)
     64             results = self.trainer.run_test()
     65         else:
---> 66             results = self.trainer.train()
     67         return results
     68 

/net/vaosl01/opt/NFS/su0/miniconda3/envs/auth/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in train(self)
    481 
    482                 # run train epoch
--> 483                 self.train_loop.run_training_epoch()
    484 
    485                 if self.max_steps and self.max_steps <= self.global_step:

/net/vaosl01/opt/NFS/su0/miniconda3/envs/auth/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_epoch(self)
    532             # TRAINING_STEP + TRAINING_STEP_END
    533             # ------------------------------------
--> 534             batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
    535 
    536             # when returning -1 from train_step, we end epoch early

/net/vaosl01/opt/NFS/su0/miniconda3/envs/auth/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_batch(self, batch, batch_idx, dataloader_idx)
    669                     opt_idx,
    670                     optimizer,
--> 671                     self.trainer.hiddens
    672                 )
    673 

/net/vaosl01/opt/NFS/su0/miniconda3/envs/auth/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
    751         """
    752         # lightning module hook
--> 753         result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
    754 
    755         if result is None:

/net/vaosl01/opt/NFS/su0/miniconda3/envs/auth/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in training_step(self, split_batch, batch_idx, opt_idx, hiddens)
    302         with self.trainer.profiler.profile('model_forward'):
    303             args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
--> 304             training_step_output = self.trainer.accelerator_backend.training_step(args)
    305             training_step_output = self.trainer.call_hook('training_step_end', training_step_output)
    306 

/net/vaosl01/opt/NFS/su0/miniconda3/envs/auth/lib/python3.7/site-packages/pytorch_lightning/accelerators/cpu_accelerator.py in training_step(self, args)
     54                 output = self.trainer.model.training_step(*args)
     55         else:
---> 56             output = self.trainer.model.training_step(*args)
     57         return output
     58 

<ipython-input-13-c94823a8e160> in training_step(self, batch, batch_idx)
      9 
     10   def training_step(self, batch, batch_idx):
---> 11     outputs = self(**batch)
     12     loss = outputs[0]
     13     return loss

/net/vaosl01/opt/NFS/su0/miniconda3/envs/auth/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

<ipython-input-13-c94823a8e160> in forward(self, input_ids, attention_mask, token_type_ids, labels)
      6 
      7   def forward(self, input_ids, attention_mask, token_type_ids, labels):
----> 8     return self.model(input_ids, attention_mask, token_type_ids, labels)
      9 
     10   def training_step(self, batch, batch_idx):

/net/vaosl01/opt/NFS/su0/miniconda3/envs/auth/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/net/vaosl01/opt/NFS/su0/miniconda3/envs/auth/lib/python3.7/site-packages/transformers/modeling_bert.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)
   1342             output_attentions=output_attentions,
   1343             output_hidden_states=output_hidden_states,
-> 1344             return_dict=return_dict,
   1345         )
   1346 

/net/vaosl01/opt/NFS/su0/miniconda3/envs/auth/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/net/vaosl01/opt/NFS/su0/miniconda3/envs/auth/lib/python3.7/site-packages/transformers/modeling_bert.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, output_attentions, output_hidden_states, return_dict)
    829 
    830         embedding_output = self.embeddings(
--> 831             input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
    832         )
    833         encoder_outputs = self.encoder(

/net/vaosl01/opt/NFS/su0/miniconda3/envs/auth/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/net/vaosl01/opt/NFS/su0/miniconda3/envs/auth/lib/python3.7/site-packages/transformers/modeling_bert.py in forward(self, input_ids, token_type_ids, position_ids, inputs_embeds)
    199         token_type_embeddings = self.token_type_embeddings(token_type_ids)
    200 
--> 201         embeddings = inputs_embeds + position_embeddings + token_type_embeddings
    202         embeddings = self.LayerNorm(embeddings)
    203         embeddings = self.dropout(embeddings)

RuntimeError: The size of tensor a (128) must match the size of tensor b (2) at non-singleton dimension 1

I’m not sure why this is happening. I’m running the same batch through both (the dataloader is not shuffled). The program crashes when attempting to compute the input embeddings:

ipdb> inputs_embeds.shape
torch.Size([2, 128, 768])
ipdb> position_embeddings.shape
torch.Size([2, 768])
ipdb> token_type_embeddings.shape
torch.Size([2, 128, 768])

position_embeddings has shape torch.Size([2, 768]) and is being added to tensors of shape torch.Size([2, 128, 768]), which looks like a broadcasting error. But I don’t understand why this happens only in Lightning and not when I run the batch through the model outside of Lightning. Could use some help on this.
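
To double-check that it is just this broadcast failing, here is a minimal standalone reproduction with random tensors of the same shapes as in the ipdb session above (not my actual data):

import torch

inputs_embeds = torch.randn(2, 128, 768)         # [batch, seq_len, hidden]
token_type_embeddings = torch.randn(2, 128, 768)
position_embeddings = torch.randn(2, 768)         # seq_len dimension is missing

# RuntimeError: The size of tensor a (128) must match the size of tensor b (2)
# at non-singleton dimension 1
embeddings = inputs_embeds + position_embeddings + token_type_embeddings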

Thanks.

I think there is a mismatch with positional arguments: in your forward you call self.model(input_ids, attention_mask, token_type_ids, labels), so labels is passed positionally into BERT’s fourth parameter, position_ids. A labels tensor of shape [2] used as position_ids is exactly what gives position_embeddings its [2, 768] shape.

Try keeping your signature but passing keyword arguments, something like:
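
  def forward(self, input_ids, attention_mask, token_type_ids, labels):
    # Keyword arguments make sure each tensor lands in the matching BERT
    # parameter instead of labels being consumed as position_ids.
    return self.model(
      input_ids=input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids,
      labels=labels,
    )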

or maybe just forward the whole batch as keyword arguments, roughly:
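
  def forward(self, **inputs):
    # Pass the batch dict straight through; its keys already match BERT's
    # parameter names (input_ids, attention_mask, token_type_ids, labels).
    return self.model(**inputs)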

OMG…thank you. That was so dumb of me.