Training for a set number of iterations without setting epochs?

Can we train for a set number of iterations without setting max_epochs? I now know there are max_steps and limit_train_batches to control how many training steps are taken overall and how many are taken per epoch, respectively. However, max_steps runs into issues if max_steps * batch_size > len(train_dataset). What would break if epochs were an optional argument to the Trainer? Is there an alternative, recommended way to convert training loops of this style to Lightning?
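To see why a step budget does not map cleanly onto a single epoch, here is a small arithmetic sketch in plain Python (no Lightning involved). It assumes a map-style dataset, a fixed batch_size, one optimizer step per batch (no gradient accumulation) and, by default, drop_last=False. The helper names are hypothetical.

```python
import math


def batches_per_epoch(dataset_len: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of training steps in one full pass over the dataset (one step per batch)."""
    if drop_last:
        return dataset_len // batch_size
    return math.ceil(dataset_len / batch_size)


def epochs_needed(max_steps: int, dataset_len: int, batch_size: int, drop_last: bool = False) -> int:
    """Smallest max_epochs that still lets training reach max_steps."""
    per_epoch = batches_per_epoch(dataset_len, batch_size, drop_last)
    return math.ceil(max_steps / per_epoch)


# 10_000 samples with batch size 32 gives 313 batches per epoch, so a budget of
# 5_000 steps spans 16 epochs; with max_epochs=1 training would end at step 313.
print(batches_per_epoch(10_000, 32))     # 313
print(epochs_needed(5_000, 10_000, 32))  # 16
```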

A super, super hacky solution would be to set max_epochs to an outrageously large value and set max_steps to my desired iteration count. Then we’d break out of the training loop according to max_steps, since we’d hit it before max_epochs.
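The workaround relies on training ending at whichever limit is reached first. The simulation below is plain Python, not Lightning's actual fit loop (`run` and its arguments are hypothetical), but it shows that with a huge max_epochs only max_steps ever triggers. In Lightning this roughly corresponds to passing both limits, e.g. `Trainer(max_epochs=10**6, max_steps=5000)`.

```python
from typing import Tuple


def run(max_epochs: int, max_steps: int, batches_per_epoch: int) -> Tuple[int, int]:
    """Simulate a loop that stops at whichever limit is hit first.

    Returns (epochs_started, global_step) at the moment training stops.
    """
    global_step = 0
    for epoch in range(max_epochs):  # range is lazy, so 10**6 costs nothing up front
        for _ in range(batches_per_epoch):
            global_step += 1  # one optimizer step per batch
            if global_step >= max_steps:
                return epoch + 1, global_step
    # Only reached if every epoch finished before max_steps was hit.
    return max_epochs, global_step


print(run(10**6, 5000, 313))  # max_steps wins: (16, 5000)
print(run(1, 5000, 313))      # max_epochs wins: (1, 313)
```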

As suggested by @awaelchli:

Maybe the solution is simple:
  • make both optional (default None)
  • if both are unset: set max_epochs=1000 (the current default)
  • if only max_steps is set: use that one (keep max_epochs=None)
  • if both are set: stop based on whichever condition is met first
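The proposed rules are small enough to write down directly. This is a sketch of the proposal only, not how the Trainer currently resolves its arguments; `resolve_limits`, `should_stop` and `DEFAULT_MAX_EPOCHS` are hypothetical names.

```python
from typing import Optional, Tuple

DEFAULT_MAX_EPOCHS = 1000  # the default quoted in the proposal


def resolve_limits(max_epochs: Optional[int] = None,
                   max_steps: Optional[int] = None) -> Tuple[Optional[int], Optional[int]]:
    """Apply the proposed defaulting rules to the two optional limits."""
    if max_epochs is None and max_steps is None:
        # Neither is set: fall back to the epoch-based default.
        return DEFAULT_MAX_EPOCHS, None
    # Otherwise keep what the user passed; an unset limit stays None.
    return max_epochs, max_steps


def should_stop(epochs_done: int, global_step: int,
                max_epochs: Optional[int], max_steps: Optional[int]) -> bool:
    """Stop as soon as any limit that is set has been reached."""
    if max_epochs is not None and epochs_done >= max_epochs:
        return True
    if max_steps is not None and global_step >= max_steps:
        return True
    return False
```

With `resolve_limits(max_steps=5000)` the epoch limit stays None, so `should_stop` can only fire on the step count, which is exactly the behavior the original question asks for.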

What issues?

Aside from this, I agree with @awaelchli; perhaps you can file an issue (or even a PR). In the meantime, I agree the only workaround is setting an absurdly high max_epochs.

What issues?

You need to wrap around the dataloader. If max_epochs=1, the end of that single epoch could signal to Lightning that training is done, and you never reach max_steps.
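In a hand-written, iteration-based loop, "wrapping around" the dataloader usually means starting a fresh pass whenever it is exhausted, so only the step count bounds training. A minimal sketch, assuming any re-iterable loader (a plain list stands in for a torch DataLoader here); `infinite_batches` and `train` are hypothetical helpers, not Lightning APIs. Calling `iter()` again on each pass, rather than using `itertools.cycle` (which caches and replays the first pass), lets a shuffling loader produce a new order every pass.

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def infinite_batches(loader: Iterable[T]) -> Iterator[T]:
    """Yield batches forever, starting a new pass each time the loader is exhausted."""
    while True:
        empty = True
        for batch in loader:  # a new iter() per pass, so a shuffling loader reshuffles
            empty = False
            yield batch
        if empty:
            # Guard against looping forever on a loader with no batches.
            raise ValueError("loader yielded no batches")


def train(loader: Iterable[T], max_steps: int) -> int:
    """Run exactly max_steps steps, regardless of how many passes that takes."""
    steps_done = 0
    for batch in islice(infinite_batches(loader), max_steps):
        # Placeholder for forward pass, loss.backward() and optimizer.step().
        steps_done += 1
    return steps_done


print(list(islice(infinite_batches([1, 2, 3]), 7)))  # [1, 2, 3, 1, 2, 3, 1]
```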

Other things that I’m not sure about:

  • what happens to the callback events for on_epoch_start/end?
  • what happens to dataloader resetting every epoch?
  • what happens to the epoch count for checkpointing? Would resuming from a checkpoint still work as expected? Is it possible to resume at the current step instead of at the previous epoch? This requires mid-epoch checkpointing support.

This is only an issue if you set max_epochs to 1. If you set it arbitrarily high (or perhaps to None in the future), Lightning will only stop when the step count reaches max_steps.

The on_epoch_start/end callbacks still execute as usual at the start and end of each epoch.

If you are reloading the dataloaders via reload_dataloaders_every_epoch, they will still be reloaded every epoch.

You can use val_check_interval to check your validation metrics and checkpoint every n steps.

Filed https://github.com/PyTorchLightning/pytorch-lightning/issues/3521 to add this to the Trainer.