Hi, I’m trying to refactor the official PyTorch NLP (sentiment analysis) tutorial using Lightning, in order to take advantage of things like early stopping. I’m taking my first steps, and the main hurdle is creating a LightningModule, in particular coding the training_step.
What I came up with so far is:
```python
class LitTextClassifier(pl.LightningModule):
    def __init__(self, num_class, criterion=CrossEntropyLoss):
        super().__init__()
        self.embedding = nn.EmbeddingBag(VOCAB_SIZE, EMBED_DIM, sparse=False)
        self.fc = nn.Linear(EMBED_DIM, num_class)
        self.init_weights()
        self.criterion = criterion

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=4.0)
        return optimizer

    def training_step(self, batch, batch_idx):
        # I am messing up things here
        text, offsets, cls = batch
        output = self.forward(text, offsets)
        loss = self.criterion(output, cls)
        return loss
```
But I am obviously getting the training_step wrong. Can someone provide guidance here?
A full gist reproducing the code and the errors I get is here: Text classification in PyTorch to refactor with PyTorch lightning.ipynb · GitHub