ModelCheckpoint filename

Hi,
I have set up ModelCheckpoint as follows.

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    filename='{epoch}-{step}',
    save_top_k=-1,
    every_n_train_steps=50000)

The resulting file names look like this:

epoch=249-step=49999.ckpt
epoch=499-step=99999.ckpt
epoch=749-step=149999.ckpt

I think this is because the epoch and step counters start at 0. How can I name the files like this instead?

epoch=250-step=50000.ckpt
epoch=500-step=100000.ckpt
epoch=750-step=150000.ckpt
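As a rough check of the off-by-one explanation, the arithmetic below assumes 0-based epoch and step counters and 200 training batches per epoch. The 200 is an assumption inferred from the file names (50,000 steps landing in epoch 249), not something stated in the question, and zero_based_counters is a hypothetical helper, not part of Lightning.

```python
def zero_based_counters(batches_done, batches_per_epoch):
    """Hypothetical helper: 0-based (epoch, step) of the last batch after `batches_done` batches."""
    last_step = batches_done - 1            # the global step counter starts at 0
    epoch = last_step // batches_per_epoch  # the epoch counter also starts at 0
    return epoch, last_step


BATCHES_PER_EPOCH = 200  # assumption, inferred from 50000 steps / 250 epochs

for batches_done in (50_000, 100_000, 150_000):
    epoch, step = zero_based_counters(batches_done, BATCHES_PER_EPOCH)
    print(f"as saved:  epoch={epoch}-step={step}.ckpt")
    print(f"as wanted: epoch={epoch + 1}-step={step + 1}.ckpt")
```

Under these assumptions the "as saved" lines reproduce the names above, and adding 1 to both counters gives the desired names.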

You can customize the checkpoint callback a little by overriding format_checkpoint_name:

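from copy import deepcopy

from pytorch_lightning.callbacks import ModelCheckpoint


class CustomModelCheckpoint(ModelCheckpoint):
    def format_checkpoint_name(self, metrics, filename=None, ver=None):
        # Shift a copy of the 0-based counters so the file name shows 1-based values
        increm_metrics = deepcopy(metrics)
        increm_metrics['epoch'] += 1
        increm_metrics['step'] += 1
        return super().format_checkpoint_name(increm_metrics, filename, ver)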

trainer = Trainer(..., callbacks=[CustomModelCheckpoint(...)])

Also, we have moved the discussions to GitHub Discussions. You might want to post there instead for a quicker response. The forums will be marked read-only after some time.

Thank you


Thank you for your advice!
The problem has been solved.

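import copy

from pytorch_lightning.callbacks import ModelCheckpoint


class CustomModelCheckpoint(ModelCheckpoint):
    def format_checkpoint_name(self, metrics, filename=None, ver=None):
        # Shift a copy of the counters so the caller's metrics are not modified
        increm_metrics = copy.deepcopy(metrics)
        increm_metrics['epoch'] += 1
        increm_metrics['step'] += 1
        return super().format_checkpoint_name(increm_metrics, filename, ver)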
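To try the override pattern without a Trainer, here is a minimal sketch. BaseNamer and ShiftedNamer are hypothetical stand-ins: BaseNamer only mimics the "{epoch}-{step}" to "epoch=...-step=..." naming seen above and is not Lightning's actual implementation, and plain ints replace the tensors Lightning passes in. ShiftedNamer uses the same subclass-and-call-super pattern as CustomModelCheckpoint.

```python
import copy
import re


class BaseNamer:
    """Hypothetical stand-in for ModelCheckpoint's file naming (not Lightning's code)."""

    def format_checkpoint_name(self, metrics, filename=None, ver=None):
        filename = filename or "{epoch}-{step}"
        # Expand each "{key}" into "key=<value>", mimicking the names shown above
        name = re.sub(r"\{(\w+)\}", lambda m: f"{m.group(1)}={metrics[m.group(1)]}", filename)
        if ver is not None:
            name += f"-v{ver}"
        return name + ".ckpt"


class ShiftedNamer(BaseNamer):
    """Same override pattern as CustomModelCheckpoint: shift a copy, then defer to super()."""

    def format_checkpoint_name(self, metrics, filename=None, ver=None):
        increm_metrics = copy.deepcopy(metrics)  # leave the caller's dict untouched
        increm_metrics["epoch"] += 1
        increm_metrics["step"] += 1
        return super().format_checkpoint_name(increm_metrics, filename, ver)


metrics = {"epoch": 249, "step": 49999}
print(BaseNamer().format_checkpoint_name(metrics))     # epoch=249-step=49999.ckpt
print(ShiftedNamer().format_checkpoint_name(metrics))  # epoch=250-step=50000.ckpt
print(metrics)                                         # {'epoch': 249, 'step': 49999}
```

Incrementing the deep copy rather than the incoming dict means only the file name changes; whatever else reads the same metrics still sees the original 0-based values.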