StochasticWeightAveraging validation logging and checkpoints


I’m trying to understand how to use StochasticWeightAveraging properly. Having a first look at the code, it looks to me as if the averaged weights are only copied back into the actual model once training finishes at max_epochs. I want to check the implications of this:

(a) Does this mean any metrics calculated during the validation steps will use the raw model being trained, not the averaged model?
(b) And if I load a model via the usual checkpoint mechanics from any checkpoint saved before max_epochs, will that be the raw training model rather than the averaged model?
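To make sure I understand the averaging itself, here’s a toy sketch of what I think is happening (plain Python, not the actual callback code): SWA keeps a running mean of weight snapshots in a separate buffer, so until that mean is copied back, anything evaluating the “model” sees the raw, most recent weights.

```python
def swa_update(avg_weights, new_weights, n_averaged):
    """Running-mean update: avg <- avg + (new - avg) / (n + 1)."""
    updated = [a + (w - a) / (n_averaged + 1) for a, w in zip(avg_weights, new_weights)]
    return updated, n_averaged + 1

# Toy "training": the raw weights drift each epoch.
raw = [0.0, 0.0]
avg, n = list(raw), 0
for epoch in range(3):
    raw = [w + 1.0 for w in raw]   # pretend an optimiser step moved the weights
    avg, n = swa_update(avg, raw, n)

print(raw)  # [3.0, 3.0] -- what validation would see mid-training
print(avg)  # [2.0, 2.0] -- the SWA running mean of [1, 2, 3]
```

So mid-training the two sets of weights genuinely differ, which is what makes me wonder about (a) and (b).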

Thanks for any help on this!

EDIT: just in case it’s useful to understand why I’m wondering this: I don’t tend to set a max_epochs, and I usually stop training manually when I see overfitting or convergence. For the validation part of the question, it would be nice to monitor the performance of SWA in my dashboard (e.g. wandb).
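For context, what I’d like to achieve is roughly what plain PyTorch’s `torch.optim.swa_utils` lets you do by hand: keep a separate averaged copy and evaluate that copy each epoch. A minimal sketch, assuming plain PyTorch rather than the Lightning callback (the linear layer and input sizes are just placeholders):

```python
import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(4, 2)          # stand-in for the model being trained
swa_model = AveragedModel(model)       # keeps the running average separately

# after each epoch's optimiser steps, fold the current weights into the average:
swa_model.update_parameters(model)

# validation metrics could then be computed on swa_model rather than model:
with torch.no_grad():
    out = swa_model(torch.randn(1, 4))
print(out.shape)  # torch.Size([1, 2])
```

If the Lightning callback only swaps the averaged weights in at the very end, then presumably something like this would be needed to log SWA performance to wandb during training.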