PL Bolts has a nice callback called SSLOnlineEvaluator that evaluates your model by stacking an MLP on top of the backbone features and training it to assess whether the features are meaningful, as done in SimCLR: 266-270.
My question is whether it is possible to train the MLP for more than one epoch.
An unclean solution could be something like:

```python
def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
    if batch_idx == 0:
        train_loader = trainer.train_dataloader
        epochs = 5
        for epoch in range(epochs):
            for batch in train_loader:
                # forward + backward of mlp
                ...
```
But I think that this is not the way the callback is supposed to be used.
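To make the intent concrete, here is a standalone sketch (outside any Lightning callback, with hypothetical names like `train_mlp_head`) of the multi-epoch inner loop I have in mind: train a small MLP head for several epochs on frozen, detached backbone features.

```python
import torch
from torch import nn

def train_mlp_head(features, labels, num_classes, epochs=5, lr=1e-3):
    """Train a small MLP probe on precomputed (detached) features.

    `features`: tensor of shape (N, D) from the frozen backbone.
    `labels`:   tensor of shape (N,) with class indices.
    """
    mlp = nn.Sequential(
        nn.Linear(features.shape[1], 64),
        nn.ReLU(),
        nn.Linear(64, num_classes),
    )
    opt = torch.optim.Adam(mlp.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):  # the multi-epoch loop in question
        opt.zero_grad()
        # detach() keeps gradients out of the backbone, as the evaluator does
        loss = loss_fn(mlp(features.detach()), labels)
        loss.backward()
        opt.step()
    return mlp, loss.item()
```

Inside the callback this loop would run on `trainer`'s dataloader batches instead of a single feature tensor, but the gradient isolation and epoch loop are the same idea.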