Changing the default order of training/inference loops in PyTorch Lightning

I’m trying to figure out the best way to use PL for the following use case. Normally, when trainer.fit is called, the order within each epoch is:

  1. Training loop
  2. Validation loop

However, for my use case, I need to run an inference loop before each training loop and before each validation loop, because its results are needed for the updates made during training and for computing the relevant metrics in the training/validation loops. So the ordering I would like is something like:

  1. Inference loop
  2. Training loop
  3. Inference loop
  4. Validation loop

Is this currently feasible with the Trainer class? I’d like to do this without writing my own fit-like function for this use case, but so far I haven’t found a simple way to do that. Thanks in advance!

I think what you’re looking for is a callback. A callback’s hooks are called at fixed points of the loop (for example, at the start of every training or validation epoch), and you can run whatever inference you need inside them.

Can you post some pseudocode of what you are trying to do? Just so we can understand the high-level idea.

Sure, here’s some pseudocode for what I’m trying to do:

# Given: model, train_dataloader, val_dataloader

for _ in range(max_epochs):
    pre_training_metrics = inference(model, train_dataloader)
    train_metrics = train(model, train_dataloader, pre_training_metrics)
    pre_validation_metrics = inference(model, val_dataloader)
    val_metrics = validate(model, val_dataloader, pre_validation_metrics)
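To make the pseudocode above concrete, here is a dependency-free toy version that runs as-is. Everything in it is a hypothetical stand-in: the "model" is a single weight w in y_hat = w * x, the dataloaders are lists of (x, y) pairs on y = 3x, the training step scales each example's gradient by its pre-training loss relative to the mean, and the validation step uses the pre-validation losses to decide which examples count as "hard" when aggregating a metric.

```python
def inference(model, data):
    """Forward pass only: per-example squared error, no parameter updates."""
    return [(model["w"] * x - y) ** 2 for x, y in data]


def train(model, data, pre_training_metrics, lr=0.05):
    """One SGD epoch. Each example's step is scaled by its pre-training loss
    relative to the mean (hypothetical 'weight hard examples more' rule)."""
    mean_loss = sum(pre_training_metrics) / len(pre_training_metrics)
    for (x, y), loss in zip(data, pre_training_metrics):
        weight = loss / mean_loss if mean_loss > 0 else 1.0
        grad = 2 * (model["w"] * x - y) * x  # d/dw of (w*x - y)**2
        model["w"] -= lr * weight * grad
    return {"w": model["w"]}


def validate(model, data, pre_validation_metrics):
    """Evaluation only. Examples whose pre-pass loss is at or above the mean
    count as 'hard', and their error is aggregated separately."""
    threshold = sum(pre_validation_metrics) / len(pre_validation_metrics)
    errors = [abs(model["w"] * x - y) for x, y in data]
    # Fall back to all examples if rounding leaves the hard subset empty.
    hard = [e for e, pre in zip(errors, pre_validation_metrics) if pre >= threshold] or errors
    return {"val_mae": sum(errors) / len(errors), "val_mae_hard": sum(hard) / len(hard)}


# Hypothetical stand-ins for the real dataloaders: points on y = 3x.
train_dataloader = [(i / 10, 3 * i / 10) for i in range(1, 11)]
val_dataloader = [(i / 10 + 0.05, 3 * (i / 10 + 0.05)) for i in range(1, 6)]
model = {"w": 0.0}
max_epochs = 20
history = []

for _ in range(max_epochs):
    pre_training_metrics = inference(model, train_dataloader)
    train_metrics = train(model, train_dataloader, pre_training_metrics)
    pre_validation_metrics = inference(model, val_dataloader)
    val_metrics = validate(model, val_dataloader, pre_validation_metrics)
    history.append(val_metrics)
```

The point is only the control flow: each inference call is a plain forward pass with no updates, and its output becomes an input to the step that follows it.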

If I understood you correctly, you’re suggesting that I could handle the first and third lines of the loop body with callbacks. My concern is that I would then have to write the inference loop myself to collect those pre_training/pre_validation metrics (they are used, e.g., to decide which examples in the dataset to weight more heavily in the gradient updates, or to aggregate metrics).

Let me know if that makes sense. Thanks!