Calculating epoch-level metrics for checkpointing

Hello,

I am trying to calculate a metric on the entire validation set. The metric cannot be computed on batches, nor can it be approximated from per-batch results. Ideally, I would like to save checkpoints based on this epoch-level metric.
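To see why such a metric cannot be rebuilt from per-batch values, take the median as an illustrative stand-in (it is not necessarily the metric meant here). Averaging the per-batch medians generally gives a different number than the median of the whole set:

```python
from statistics import mean, median

# Two hypothetical validation batches of per-sample errors.
batches = [[1, 2, 10], [3, 4, 5]]

# Average of the per-batch medians: (2 + 4) / 2 = 3
mean_of_batch_medians = mean(median(b) for b in batches)

# Median over all six values 1, 2, 3, 4, 5, 10: (3 + 4) / 2 = 3.5
full_set_median = median([v for b in batches for v in b])

print(mean_of_batch_medians, full_set_median)
```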

So far, I have found two options:

A) Add a Callback that runs a function performing a pass over the validation set to calculate the metric. I know it is possible to log the metric to loggers from there (e.g. TensorBoard). But how would I pass it to the checkpoint callback? Is it possible to write to the EvalResult object from other callbacks?

B) Use on_validation_epoch_end(). The same question as in A) applies.

What is the best way to implement this?

Many thanks for your help!

Could you compute the metric in on_validation_epoch_end and return it in that method's output dictionary? If that is possible, I think it would work, as long as you also set the checkpoint callback's monitor to that metric's key.

Check the second example here. You can access all predictions in epoch_end and compute the metric there.