Is there a way to only log on epoch end using the new Result APIs?

Asked by Brian Chen.

Is there a way to only log on epoch end using the new Result APIs? I was able to do this with the old dict interface by forcing step to be self.current_epoch (essentially ignoring global_step), but TrainResult doesn’t appear to have a step field?

.log(..., on_step=False, on_epoch=True)

In general:

on_step = log at each step
on_epoch = aggregate and log at the epoch

combine them how you wish!
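To make the two flags concrete, here is a minimal plain-Python sketch of the behaviour described above. It is not Lightning code: EpochAggregator, end_epoch and the emitted list are hypothetical names made up for illustration, and using the mean as the epoch-level reduction is an assumption.

```python
from collections import defaultdict
from statistics import mean


class EpochAggregator:
    """Hypothetical stand-in for on_step / on_epoch logging; not Lightning code."""

    def __init__(self):
        self._buffer = defaultdict(list)  # values collected during the current epoch
        self.emitted = []                 # (when, name, value) records a logger would receive

    def log(self, name, value, on_step=True, on_epoch=False):
        if on_step:
            # on_step: emit the raw value immediately
            self.emitted.append(("step", name, value))
        if on_epoch:
            # on_epoch: hold the value back until the epoch ends
            self._buffer[name].append(value)

    def end_epoch(self):
        # Assumption: the epoch-level value is the mean of the buffered values
        for name, values in self._buffer.items():
            self.emitted.append(("epoch", name, mean(values)))
        self._buffer.clear()


agg = EpochAggregator()
for loss in [4.0, 2.0, 3.0]:
    agg.log("train_loss", loss, on_step=False, on_epoch=True)
agg.end_epoch()
print(agg.emitted)
```

With on_step=False, on_epoch=True only a single aggregated record is emitted per epoch; setting both flags to True would emit the per-step values as well.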

Unfortunately, that still counts the number of training steps elapsed (I assume because it records global_step as the step). See e.g. the TensorBoard output:

The datapoints logged should be at step 0 or 1 (the last epoch) instead of 22 (the number of training batches). In other words, setting on_step=False, on_epoch=True does not affect how steps are logged.

Edit: here’s the relevant training_step implementation:

    def training_step(self, batch, _):
        x, y = batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        res = pl.TrainResult(loss)
        res.log("train_loss", loss, on_epoch=True, on_step=False)
        return res

No training_epoch_end is defined.

You can customize what happens at the end of a training epoch by defining training_epoch_end (see the documentation). You can add an EvalResult logger in it:

    def training_epoch_end(self, training_step_outputs):
        print('training steps', training_step_outputs)
        # average the per-step losses collected over the epoch
        avg_loss = training_step_outputs.loss.mean()
        result = pl.EvalResult(checkpoint_on=avg_loss)
        result.log('train_loss', avg_loss, on_epoch=True, prog_bar=True)
        # return the result so the logged values are picked up
        return result

In the training step, you can store the loss on the result so that it is available for later operations:

result.loss = loss

OK, good point. I guess we can map the x-axis to epochs instead of steps here. Mind submitting a GitHub issue and linking back here?

Interesting, I never thought about returning an EvalResult from the training step methods! Unfortunately that still doesn’t log the correct step, but I was able to do so by adding result.log('step', self.current_epoch, ...).
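One plausible reading of why this workaround has an effect, sketched in plain Python: it assumes the logging code treats a metric named step as the x-axis value and falls back to global_step otherwise. That behaviour is an assumption here, not something this thread confirms, and resolve_step is a hypothetical helper, not a Lightning function.

```python
def resolve_step(metrics, global_step):
    """Hypothetical helper: choose the x-axis value for one logging call.

    Assumption: a metric named "step" overrides global_step and is
    removed from the metrics that actually get plotted.
    """
    metrics = dict(metrics)  # copy so the caller's dict is left untouched
    step = metrics.pop("step", global_step)
    return int(step), metrics


# Without an explicit step, the point lands at global_step
print(resolve_step({"train_loss": 0.7}, global_step=22))

# With step set to the current epoch, the point lands at that epoch instead
print(resolve_step({"train_loss": 0.7, "step": 1}, global_step=22))
```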


Oh wait… yeah, the point of EvalResult is that you don’t need an epoch_end method…