Use SSLOnlineEvaluator for more than one epoch

Hey there,

PL Bolts has a nice callback, SSLOnlineEvaluator, that evaluates your model by stacking a small model (an MLP) on top of the learned features and training it to check whether the features are meaningful, as done in SimCLR: 266-270.
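To make the mechanism concrete, here is a dependency-free sketch of what an online evaluator does. It assumes a binary task and a linear (logistic-regression) probe instead of an MLP. The encoder output is treated as a constant input (the role `.detach()` plays in PyTorch), and the probe takes exactly one SGD step per training batch, so it sees each batch only once per epoch. `fake_encoder`, `probe_step`, `probe_accuracy` and the toy data are hypothetical and not part of PL Bolts.

```python
import math
import random


def sigmoid(z):
    # Numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def fake_encoder(x):
    # Hypothetical frozen "encoder": a fixed feature map standing in for the SSL backbone
    return [x[0] + x[1], x[0] - x[1]]


def probe_step(w, b, feats, labels, lr=0.5):
    """One SGD step of a logistic-regression probe; feats are treated as constants."""
    gw, gb = [0.0] * len(w), 0.0
    for f, y in zip(feats, labels):
        # derivative of binary cross-entropy w.r.t. the logit
        err = sigmoid(sum(wi * fi for wi, fi in zip(w, f)) + b) - y
        gw = [g + err * fi for g, fi in zip(gw, f)]
        gb += err
    n = len(feats)
    return [wi - lr * g / n for wi, g in zip(w, gw)], b - lr * gb / n


def probe_accuracy(w, b, feats, labels):
    preds = [sigmoid(sum(wi * fi for wi, fi in zip(w, f)) + b) >= 0.5 for f in feats]
    return sum(p == bool(y) for p, y in zip(preds, labels)) / len(labels)


random.seed(0)
# Toy labelled data: label is 1 when x0 + x1 > 0
data = [[random.uniform(-1, 1), random.uniform(-1, 1)] for _ in range(256)]
labels = [1 if x[0] + x[1] > 0 else 0 for x in data]
batches = [(data[i:i + 32], labels[i:i + 32]) for i in range(0, 256, 32)]

w, b = [0.0, 0.0], 0.0
# Online evaluation: exactly one probe update per training batch (one pass over the data)
for xb, yb in batches:
    feats = [fake_encoder(x) for x in xb]  # produced by the encoder, which is never updated here
    w, b = probe_step(w, b, feats, yb)

online_acc = probe_accuracy(w, b, [fake_encoder(x) for x in data], labels)
print(f"probe accuracy after one epoch: {online_acc:.2f}")
```

Because the probe only gets one update per batch, its accuracy after an epoch can understate how linearly separable the features really are, which is the motivation for wanting more probe epochs.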

My question is whether it is possible to train the MLP for more than one epoch.
A hacky solution could look something like this:

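def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
    # Run the extra probe training only once per epoch, at the first batch
    if batch_idx == 0:
        train_loader = trainer.train_dataloader
        epochs = 5
        for epoch in range(epochs):
            # use a new name so the hook's own `batch` argument is not shadowed
            for probe_batch in train_loader:
                # forward + backward of the MLP on probe_batch goes here
                ...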

But I suspect this is not how the callback is meant to be used.



This seems like the best solution at the moment, since callbacks were not really designed for training separate models in the first place. If this becomes a common use case we may be able to make improvements, but in the meantime I think this is the best you can do.
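A variant of the same loop that avoids re-running the encoder on every probe epoch is to compute the features once (the backbone is not updated while the probe trains anyway) and then run several epochs over that cache. The sketch below shows the pattern in plain Python, assuming a binary task, a linear probe instead of an MLP, and a hypothetical fixed feature map standing in for the encoder. `train_probe` and `bce_loss` are illustrative helpers, not PL Bolts functions; in a real callback the caching step would live in one of the Lightning hooks.

```python
import math
import random


def sigmoid(z):
    # Numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce_loss(w, b, feats, labels):
    # Mean binary cross-entropy of a linear probe (probabilities clamped to avoid log(0))
    total = 0.0
    for f, y in zip(feats, labels):
        p = sigmoid(sum(wi * fi for wi, fi in zip(w, f)) + b)
        p = min(max(p, 1e-12), 1 - 1e-12)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(labels)


def train_probe(feats, labels, epochs, batch_size=32, lr=0.5, seed=0):
    """Train a logistic-regression probe for several epochs on cached, frozen features."""
    rng = random.Random(seed)
    w, b = [0.0] * len(feats[0]), 0.0
    idx = list(range(len(feats)))
    losses = []
    for _ in range(epochs):
        rng.shuffle(idx)  # new batch order for each probe epoch
        for start in range(0, len(idx), batch_size):
            batch = idx[start:start + batch_size]
            gw, gb = [0.0] * len(w), 0.0
            for i in batch:
                err = sigmoid(sum(wi * fi for wi, fi in zip(w, feats[i])) + b) - labels[i]
                gw = [g + err * fi for g, fi in zip(gw, feats[i])]
                gb += err
            w = [wi - lr * g / len(batch) for wi, g in zip(w, gw)]
            b -= lr * gb / len(batch)
        losses.append(bce_loss(w, b, feats, labels))  # full-data loss after each probe epoch
    return w, b, losses


random.seed(0)
data = [[random.uniform(-1, 1), random.uniform(-1, 1)] for _ in range(256)]
labels = [1 if x[0] + x[1] > 0 else 0 for x in data]

# Hypothetical frozen encoder: features are computed ONCE and reused for every probe epoch
cached_feats = [[x[0] + x[1], x[0] - x[1]] for x in data]
w, b, losses = train_probe(cached_feats, labels, epochs=5)
print([round(l, 3) for l in losses])
```

The trade-off of caching is that the probe sees one fixed view of each sample, so any per-batch data augmentation is lost; the uncached loop in the question keeps augmentation but pays for a full encoder forward pass on every probe epoch.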