Read input batch in callback

Hi all,
I’m working on an EMA (exponential moving average) callback that maintains an exponential moving average of the model parameters and runs the evaluation step with the EMA model side by side with the non-EMA model. I’m currently considering this approach:

  1. make a copy of the Lightning module in the callback's init
  2. update the weights of the EMA model in each on_train_batch_end call
  3. in the on_validation_end hook, feed the same input batch to the EMA model and log its outputs
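Step 2 is usually the standard EMA update, applied to every parameter after each optimizer step. Here β is a decay hyperparameter in (0, 1); the post does not give a value, so the numbers below are only an illustration:

```latex
\theta_{\mathrm{EMA}} \leftarrow \beta\,\theta_{\mathrm{EMA}} + (1 - \beta)\,\theta,
\qquad \text{e.g. } \beta = 0.9,\ \theta_{\mathrm{EMA}} = 1.0,\ \theta = 2.0
\;\Rightarrow\; \theta_{\mathrm{EMA}} = 0.9 \cdot 1.0 + 0.1 \cdot 2.0 = 1.1
```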

Is there a way to read the input batch in a callback hook?

Hello, my apologies for the late reply. We are slowly moving to deprecate this forum in favor of GitHub's built-in version… Could we kindly ask you to recreate your question there: Lightning Discussions