Loading fine-tuned model built from pretrained subnetworks

Hello everyone,

I would like to confirm that I am getting the expected behaviour, and to ask whether there are best practices for handling the following situation.

I have two LightningModules, call them model_1 and model_2, which I pretrain separately. After saving them, I have ckpt_1, yaml_1 and ckpt_2, yaml_2, which hold their trained parameters and hyperparameters.

Now I put them together in a single model, combined_model, and fine-tune them on the task of combined_model.

class combined_model(pl.LightningModule):
    def __init__(self, ckpt_1="", yaml_1="", ckpt_2="", yaml_2="", …):
        super().__init__()
        # Load each pretrained submodule from its own checkpoint and hparams file.
        self.model_1 = model_1.load_from_checkpoint(checkpoint_path=ckpt_1, hparams_file=yaml_1, map_location='cpu')
        self.model_2 = model_2.load_from_checkpoint(checkpoint_path=ckpt_2, hparams_file=yaml_2, map_location='cpu')

At the beginning of the fine-tuning I build the model as:
→ combined_model then optimizes the trainable parameters of model_1 and model_2, starting from the pretrained checkpoints, right?
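
One way to check this is to look at which parameters the combined module exposes. The sketch below uses plain torch.nn.Module stand-ins (SubNet and Combined are hypothetical names, not the real model_1, model_2 or combined_model) and assumes the submodules are assigned as attributes, as in the class above.

```python
import torch
from torch import nn


class SubNet(nn.Module):
    """Hypothetical stand-in for model_1 / model_2."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)


class Combined(nn.Module):
    """Hypothetical stand-in for combined_model."""

    def __init__(self):
        super().__init__()
        # Assigning a module as an attribute registers it as a submodule.
        self.model_1 = SubNet()
        self.model_2 = SubNet()


combined = Combined()

# The parent module exposes the parameters of both submodules.
param_names = [name for name, _ in combined.named_parameters()]

# An optimizer built from combined.parameters() therefore sees all of them.
optimizer = torch.optim.SGD(combined.parameters(), lr=0.1)
n_optimized = sum(len(group["params"]) for group in optimizer.param_groups)
```

Parameters with requires_grad=False would still be listed here, but they receive no gradient and so are not updated.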

After the fine-tuning is done, I have ckpt_3 and yaml_3, which hold the fine-tuned parameters and the paths of the pretrained checkpoints used to build combined_model.

Usually I could just restore the fine-tuned model as

The problem is that I work with remote servers, and the paths change between the fine-tuning run and a later test run. As a result, yaml_3 points to the wrong paths for ckpt_1, yaml_1 and ckpt_2, yaml_2 when I want to restore the fine-tuned combined_model.

What I do then is manually specify the new paths ckpt_1bis, yaml_1bis and ckpt_2bis, yaml_2bis in
→ In this case, am I definitely loading the fine-tuned weights from ckpt_3 and not the pretrained weights from ckpt_1bis and ckpt_2bis?

I think so, but I would like to be sure. Also, are there any recommended ways to handle this situation better?

Thanks !

hey @AdrienB

Yes, this looks correct. You can also customize the loading process a little so that ckpt_1 and ckpt_2 are not needed on your remote servers.

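class combined_model(pl.LightningModule):
    def __init__(self, ckpt_1="", yaml_1="", ckpt_2="", yaml_2="", …):
        super().__init__()
        if not ckpt_1:
            # No pretrained checkpoint given (None or ""): build an untrained model_1.
            # Its weights are then filled in from the combined checkpoint.
            self.model_1 = model_1()
        else:
            self.model_1 = model_1.load_from_checkpoint(checkpoint_path=ckpt_1, hparams_file=yaml_1, map_location='cpu')
        # same for model_2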

and during reloading on server:
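
The snippet below is not the original reload call. It is a minimal sketch of why the approach works. As far as I know, load_from_checkpoint first instantiates the class (with the saved or overridden hyperparameters) and then calls load_state_dict with the weights stored in the checkpoint, so whatever __init__ loaded gets overwritten. The example mimics that order with plain torch.nn.Module stand-ins. SubNet, Combined and same_weights are hypothetical names, and in-memory state dicts stand in for the checkpoint files.

```python
import copy

import torch
from torch import nn


class SubNet(nn.Module):
    """Hypothetical stand-in for model_1 / model_2."""

    def __init__(self, pretrained_state=None):
        super().__init__()
        self.layer = nn.Linear(4, 2)
        if pretrained_state is not None:
            # Mimics loading ckpt_1 / ckpt_2 inside __init__.
            self.load_state_dict(pretrained_state)


class Combined(nn.Module):
    """Hypothetical stand-in for combined_model."""

    def __init__(self, state_1=None, state_2=None):
        super().__init__()
        self.model_1 = SubNet(state_1)
        self.model_2 = SubNet(state_2)


def same_weights(state_a, state_b):
    """True if both state dicts hold the same keys and equal tensors."""
    return state_a.keys() == state_b.keys() and all(
        torch.equal(state_a[k], state_b[k]) for k in state_a
    )


torch.manual_seed(0)
pretrained_1 = SubNet().state_dict()
pretrained_2 = SubNet().state_dict()

# "Fine-tuning": start from the pretrained weights, then change them.
finetuned = Combined(pretrained_1, pretrained_2)
with torch.no_grad():
    for p in finetuned.parameters():
        p.add_(1.0)
ckpt_3_state = copy.deepcopy(finetuned.state_dict())

# Reload path A: __init__ loads the pretrained weights first,
# then the fine-tuned state dict overwrites them.
restored_a = Combined(pretrained_1, pretrained_2)
restored_a.load_state_dict(ckpt_3_state)

# Reload path B: __init__ skips the pretrained checkpoints entirely.
restored_b = Combined()
restored_b.load_state_dict(ckpt_3_state)
```

Both reload paths end with the fine-tuned weights. The same same_weights check can be applied to a restored model and the state_dict entry of ckpt_3 to confirm this on a real run.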


Also, we have moved the discussions to GitHub Discussions. You might want to check that out instead to get a quick response. The forums will be marked read-only soon.

Thank you