Resume from checkpoint with elastic training

I use PyTorch Lightning with TorchElastic. My training function looks like this:

import pytorch_lightning as pl
# Each train() call runs as a single worker (one process)
def train(config: InputConfig):
    checkpoint_callback = pl.callbacks.ModelCheckpoint(...)
    module = MyLightningModule(config)
    trainer = pl.Trainer(num_nodes=..., gpus=..., checkpoint_callback=checkpoint_callback, ...)
    trainer.fit(module)
    return Results(...)

and I launch it with TorchElastic as follows:

import torchelastic.distributed.local_launch as pet
from uuid import uuid4
...

def elastic_train(config: InputConfig):
    lc = pet.LocalLaunchConfig(
        # Assuming devgpu testing, min = max nodes = 1
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=config.trainer.gpus if config.trainer.gpus else 1,
        # run_id just has to be globally unique
        run_id=f"your_run_identifier_{uuid4()}",
        # for fault tolerance; for testing set it to 0 (no fault tolerance)
        max_restarts=0,
        function_start_method="spawn",
    )
    # The "train" function is called inside elastic_launch
    ret = pet.elastic_launch(lc, fn=train)(config)
    print(f"Rank 0 results = {ret[0]}")

def main(config: InputConfig):
    elastic_train(config)

Sometimes training fails. At that point, I’d like to resume training from the latest checkpoint. However, since my train function is wrapped by TorchElastic, I don’t know the path to the latest checkpoint ahead of time. This means resume_from_checkpoint may not work for this use case: when training fails I don’t know the full path to the checkpoint, or whether a valid checkpoint even exists.

I don’t have any experience with TorchElastic, but perhaps you could pass your own ModelCheckpoint callback with a defined filepath, so you always know where the checkpoint is saved.

This doesn’t work for us: on the first training run, there is no checkpoint to resume from. However, if training fails for some reason, TorchElastic will transparently restart the training job, and at that point we want to pick up the latest checkpoint and resume training.

We could hardcode this. One approach would be to determine the checkpoint directory ahead of time. If it exists, we would look for the latest checkpoint inside it to resume from, likely based on file naming or creation time (a rough sketch of this is below). But this feels a bit clunky. I’m open to suggestions for how to write the training code so that it is as independent as possible from restarting after failures.
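
For concreteness, here is a minimal sketch of that directory-scan idea, assuming checkpoints are plain .ckpt files in a known directory (find_latest_checkpoint and the glob pattern are illustrative, not part of our actual setup):

from pathlib import Path
from typing import Optional

def find_latest_checkpoint(ckpt_dir: str) -> Optional[str]:
    # Return the most recently modified .ckpt file in ckpt_dir, or None if none exist yet
    candidates = list(Path(ckpt_dir).glob("*.ckpt"))
    if not candidates:
        return None
    return str(max(candidates, key=lambda p: p.stat().st_mtime))

train() could then pass the returned path (or None) to the Trainer, so a fresh run and an elastic restart go through the same code path.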

I see. You could do something like this:

from pathlib import Path

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

CHECKPOINT = "hardcode/path/to/checkpoint.ckpt"

def train(model):
    # Register the callback in both cases so new checkpoints keep landing at CHECKPOINT
    checkpoint_callback = ModelCheckpoint(filepath=CHECKPOINT)
    # resume_from_checkpoint is a Trainer constructor argument, not a classmethod
    resume_path = CHECKPOINT if Path(CHECKPOINT).exists() else None
    trainer = pl.Trainer(callbacks=[checkpoint_callback], resume_from_checkpoint=resume_path)
    trainer.fit(model)

I agree that this does seem a little clunky; I’ll think about whether there is a better approach.

@teddy I read the checkpoint callback docs again and had completely missed that save_last works for this use case! We can rely on that path (if it exists) with resume_from_checkpoint. That should be straightforward for us to adopt.
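
For reference, here is a minimal sketch of the save_last approach as we understand it (CKPT_DIR is an assumed directory name; by default save_last writes a last.ckpt copy into the checkpoint directory, and older Lightning versions take filepath instead of dirpath):

import os
from pathlib import Path

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

CKPT_DIR = "checkpoints"
LAST_CKPT = os.path.join(CKPT_DIR, "last.ckpt")  # fixed path written when save_last=True

def train(config):
    checkpoint_callback = ModelCheckpoint(dirpath=CKPT_DIR, save_last=True)
    module = MyLightningModule(config)
    trainer = pl.Trainer(
        callbacks=[checkpoint_callback],
        # None on the first run; the fixed last.ckpt path after an elastic restart
        resume_from_checkpoint=LAST_CKPT if Path(LAST_CKPT).exists() else None,
    )
    trainer.fit(module)

Since the last.ckpt path never changes, the same train() works unmodified whether TorchElastic is starting fresh or restarting after a failure.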

Awesome! Don’t hesitate to ask if you have any other issues 🙂