Unfortunately, that still logs against the number of training steps elapsed (I assume because it records `global_step` as the step). See e.g. the TensorBoard output:
The data points should be logged at step 0 or 1 (the last epoch) instead of 22 (the number of training batches). In other words, setting `on_step=False` does not affect which step value the points are logged at.
Edit: here’s the relevant code:
import pytorch_lightning as pl
import torch.nn.functional as F

def training_step(self, batch, _):
    x, y = batch
    logits = self.forward(x)
    loss = F.cross_entropy(logits, y)
    res = pl.TrainResult(loss)
    # Aggregate over the epoch only; disable per-step logging.
    res.log("train_loss", loss, on_epoch=True, on_step=False)
    return res  # the TrainResult must be returned for logging to take effect
(`train_epoch_end` is also defined.)
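To make the mismatch concrete, here is a minimal framework-free sketch (hypothetical counters, not Lightning internals) of why an epoch-aggregated metric keyed by the global step lands at x=22 rather than at the epoch index:

    # Hypothetical sketch: a trainer's step counters over one epoch of 22 batches.
    # `global_step` advances once per training batch; `current_epoch` once per epoch.
    num_batches = 22
    global_step = 0
    current_epoch = 0

    for _ in range(num_batches):
        global_step += 1  # one increment per batch

    # An epoch-aggregated metric logged at epoch end but keyed by global_step
    # appears at x = 22 (the batch count), whereas keying it by current_epoch
    # would place it at x = 0 (the epoch just finished).
    print(global_step, current_epoch)  # 22 0

So the metric value itself is correctly aggregated over the epoch; only the x-axis key is taken from the batch counter.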