Hi all,
I’m working on an EMA (exponential moving average) callback that maintains an exponential moving average of the model parameters and runs the eval step with the EMA model side by side with the non-EMA model. I’m currently considering this approach:
- make a copy of the Lightning module in the callback's init
- update the weights of the EMA model in each `on_train_batch_end`
- in the `on_validation_end` hook, feed the same input batch to the EMA model and log its outputs
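Here is a minimal, framework-agnostic sketch of the three steps above. It assumes parameters are plain lists of floats keyed by name; in a real Lightning callback they would be the tensors from the module's `state_dict()`. The class name `EMATracker`, the `decay` value, and the methods `update` and `evaluate_side_by_side` are hypothetical. `update` is what the `on_train_batch_end` hook would call.

```python
import copy


class EMATracker:
    """Hypothetical sketch: parameters are modeled as dict[str, list[float]]."""

    def __init__(self, params, decay=0.999):
        # Step 1: keep a private deep copy of the weights as the EMA "model",
        # so later in-place changes to the live weights do not leak into it.
        self.decay = decay
        self.ema_params = copy.deepcopy(params)

    def update(self, params):
        # Step 2 (would be called from on_train_batch_end):
        # ema <- decay * ema + (1 - decay) * current, element by element.
        for name, values in params.items():
            ema = self.ema_params[name]
            for i, v in enumerate(values):
                ema[i] = self.decay * ema[i] + (1.0 - self.decay) * v

    def evaluate_side_by_side(self, forward, params, batch):
        # Step 3: run the same batch through the live and the EMA weights
        # and return both outputs so they can be logged together.
        return forward(params, batch), forward(self.ema_params, batch)
```

The deep copy at construction matters: without it, the EMA weights would alias the live ones and simply track them instead of averaging them.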
Is there a way to access the input batch inside a callback hook?