How to obtain per-class accuracy at the end of each epoch?

I know how to accumulate a confusion matrix and then compute per-class and overall accuracy in plain PyTorch. However, I have not found an easy way to do the same in PyTorch Lightning.
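For reference, the plain-PyTorch approach mentioned above can be sketched roughly as follows. This is a minimal example, not the poster's code. The helper names `update_confusion_matrix` and `accuracies` are hypothetical, rows are true classes and columns are predicted classes, and "per-class accuracy" is taken to mean correct predictions divided by the number of samples of that class (i.e. per-class recall).

```python
import torch


def update_confusion_matrix(confmat: torch.Tensor, preds: torch.Tensor, target: torch.Tensor) -> None:
    """Add one batch of integer predictions/targets to a running (C, C) confusion matrix.

    Rows are true classes, columns are predicted classes. Classes that are
    absent from this batch simply contribute zero counts.
    """
    num_classes = confmat.shape[0]
    idx = target * num_classes + preds  # flatten each (true, pred) pair into one index
    confmat += torch.bincount(idx, minlength=num_classes ** 2).view(num_classes, num_classes)


def accuracies(confmat: torch.Tensor):
    """Return (per-class accuracy, overall accuracy) from an accumulated confusion matrix.

    A class with no samples at all yields NaN rather than a misleading 0.
    """
    correct = confmat.diag().float()
    per_class = correct / confmat.sum(dim=1).float()
    overall = correct.sum() / confmat.sum().float()
    return per_class, overall


# Example: two batches, three classes; class 2 only appears in the second batch.
confmat = torch.zeros(3, 3, dtype=torch.long)
update_confusion_matrix(confmat, preds=torch.tensor([0, 1, 1]), target=torch.tensor([0, 1, 0]))
update_confusion_matrix(confmat, preds=torch.tensor([2, 2]), target=torch.tensor([2, 1]))
per_class, overall = accuracies(confmat)
print(per_class)  # tensor([0.5000, 0.5000, 1.0000])
print(overall)    # tensor(0.6000)
```

Because only counts are accumulated, the final numbers do not depend on how the samples were split into batches.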

Is there a way in PyTorch Lightning to compute or log per-class accuracy over the entire validation dataset?

I found an example online that logs accuracy like this:

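import torchmetrics
from pytorch_lightning import LightningModule


class MyModel(LightningModule):

    def __init__(self, num_classes):
        super().__init__()
        ...
        # recent torchmetrics versions require `task` (and `num_classes` for multiclass)
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        # log step metric
        self.accuracy(preds, y)
        self.log('train_acc_step', self.accuracy)
        ...

    def training_epoch_end(self, outs):
        # log epoch metric
        # (Lightning >= 2.0 removed this hook; use on_train_epoch_end(self) instead)
        self.log('train_acc_epoch', self.accuracy)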

But I don't know how to obtain per-class accuracy here. For example, if I compute a confusion matrix in training_step(), it covers only a single batch, and some classes may not appear in that batch at all.

def training_step(self, batch, batch_idx):
    ...  # after computing accuracy per class
    for i, acc in enumerate(accs):  # accs: accuracy per class
        self.log(f'train_acc_step_class_{i}', acc)

Not sure if it's exactly what you want, but I hope it helps.