How to get path where checkpoints are saved in a Callback?

Goal

Save an epoch=x.pt (TorchScript) file in the same directory as epoch=x.ckpt.

In more detail

Besides saving the weights, I also want to save a TorchScript version of the model when on_save_checkpoint is called (e.g. when a lower val_loss has been achieved). However, I cannot work out where this checkpoint is being saved.

Example code

Say I have the following code:

class OnCheckpointTorchScript(Callback):
    def on_save_checkpoint(self, trainer, pl_module):
        file_path = ...  # HOW? trainer._weights_save_path is not enough
        pl_module.to_torchscript(file_path=f"{file_path[:-4]}pt", method='trace')

How do I set file_path to match weights_save_path/lightning_logs/version_x/checkpoints/epoch=2.ckpt?
trainer._weights_save_path only contains the directory your code is called from (when not changed), but I need the full path, including version_x and the epoch number.

Question

How do I get the location of epoch=2.ckpt in a Callback?

Working on a Colab example, stay tuned.

  • trainer.logger.log_dir will give you something like '/content/lightning_logs/version_1'
  • trainer.logger.version will give you '1' (if on version 1)

Using that logic, you could probably grab the epoch from trainer.current_epoch and then save your file in on_save_checkpoint.
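To make the path assembly concrete, here is a minimal sketch; the SimpleNamespace objects are hypothetical stand-ins for the real trainer, just to show how trainer.logger.log_dir and trainer.current_epoch combine into the checkpoint path:

```python
from types import SimpleNamespace

# Hypothetical stand-in for a real Trainer, mimicking only the two
# attributes we need: logger.log_dir and current_epoch.
trainer = SimpleNamespace(
    logger=SimpleNamespace(log_dir="/content/lightning_logs/version_1"),
    current_epoch=2,
)

# Compose the path the same way Lightning lays out its checkpoints dir.
ckpt_path = f"{trainer.logger.log_dir}/checkpoints/epoch={trainer.current_epoch}.ckpt"
print(ckpt_path)  # → /content/lightning_logs/version_1/checkpoints/epoch=2.ckpt
```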

Does this solve your issue? If so, I can go ahead and make the Colab… just don’t want to spend time writing it up if this doesn’t answer your question.

Working callback function

trainer.logger.log_dir is what I needed! Thank you. Now the callback function to save the TorchScript becomes:

class OnCheckpointTorchScript(Callback):
    def on_save_checkpoint(self, trainer, pl_module):
        # Save the TorchScript export next to the .ckpt file.
        file_path = f"{trainer.logger.log_dir}/checkpoints/epoch={trainer.current_epoch}.pt"
        pl_module.to_torchscript(file_path=file_path, method='trace')

I already looked through trainer.__dict__ and trainer.logger.__dict__. While I could find current_epoch and trainer._weights_save_path, log_dir did not show up. How did you know about this variable?

An extra question: Automatically delete the .pt file

PyTorch Lightning automatically keeps only the best epoch (or the top k epochs if you set save_top_k).
How do I apply the same logic to my TorchScript .pt files? I want a .pt file to be deleted when its matching .ckpt file is deleted.
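(For context, the retention behaviour I mean comes from Lightning’s ModelCheckpoint callback; a minimal configuration sketch, with monitor and save_top_k as in the PyTorch Lightning API:)

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep only the single best checkpoint by validation loss;
# older .ckpt files are deleted automatically as better ones appear.
checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=1)
```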

You can also use trainer.checkpoint_callback.dirpath.

Maybe do a brute-force search after checkpoints are saved and delete any .pt file that no longer has a .ckpt with a matching name.
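The brute-force idea could be sketched like this (prune_orphan_torchscript is a hypothetical helper, not part of Lightning; trainer.checkpoint_callback.dirpath would be passed as checkpoints_dir):

```python
import glob
import os

def prune_orphan_torchscript(checkpoints_dir):
    """Delete .pt files that no longer have a matching .ckpt file."""
    # Stems (e.g. "epoch=2") of all checkpoints Lightning still keeps.
    kept = {os.path.splitext(os.path.basename(p))[0]
            for p in glob.glob(os.path.join(checkpoints_dir, "*.ckpt"))}
    for pt in glob.glob(os.path.join(checkpoints_dir, "*.pt")):
        stem = os.path.splitext(os.path.basename(pt))[0]
        if stem not in kept:
            os.remove(pt)
```

You could call this at the end of on_save_checkpoint, after exporting the TorchScript file, so the .pt files stay in sync with whatever save_top_k leaves behind.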