Computing validation accuracy at the end of each epoch


I’m currently working on a computer vision project on CIFAR10, using Lightning (and Bolts) to do most of the heavy lifting. I was wondering how you should go about computing the validation accuracy after each epoch, to observe how it changes as the training progresses. Are there any specific things to use or watch out for?


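Hi Gilles, great question! One way to compute epoch-level metrics such as accuracy is to return the targets and predictions from validation_step and then compute the metric over all of them in validation_epoch_end: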

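import pytorch_lightning.metrics.functional as FM
from pytorch_lightning import EvalResult  # 0.9-era Result API

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)  # logits, shape (batch_size, num_classes)
    # Stash targets and logits; Lightning collects these across batches
    # before calling validation_epoch_end.
    result = EvalResult()
    result.y = y
    result.y_hat = y_hat
    return result

def validation_epoch_end(self, out):
    # `out` holds the collected y and y_hat from every validation batch.
    result = EvalResult()
    # FM.accuracy expects predicted class labels, not logits, so take the argmax.
    preds = out.y_hat.argmax(dim=-1)
    accuracy = FM.accuracy(preds, out.y)
    result.log_dict({"valid_acc": accuracy})
    return result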

Hope this helps!