Best practice to partially train to a different function prior to full training

We have a model where we want to train for an epoch or two toward a cheap-to-calculate initial condition, and then continue using those weights as the starting point for the real training/validation/tuning/etc.
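For concreteness, inside the LightningModule this could just be a flag that switches the loss in training_step. The sketch below is hypothetical (the network and both losses are placeholders), and assumes the flag lands in self.hparams via save_hyperparameters():

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def __init__(self, fit_initial_condition=False, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.net = torch.nn.Linear(10, 1)  # placeholder network

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.net(x)
        if self.hparams.fit_initial_condition:
            # Warm-up phase: fit the cheap-to-calculate initial
            # condition (placeholder target here).
            loss = F.mse_loss(y_hat, torch.zeros_like(y_hat))
        else:
            # Real objective.
            loss = F.mse_loss(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)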
Let's say the original code was something like

from argparse import ArgumentParser
import pytorch_lightning as pl

parser = ArgumentParser()
parser = MyModel.add_model_specific_args(parser)
parser = pl.Trainer.add_argparse_args(parser)
args, unknown = parser.parse_known_args()
dict_args = vars(args)
model = MyModel(**dict_args)
trainer = pl.Trainer.from_argparse_args(args)
trainer.tune(model)
trainer.fit(model)
trainer.test()

What is the best pattern to do this on Grid without cluttering the logs, adding extra checkpoints, etc., or messing up the Lightning and Grid command-line arguments?

For example, let's say my model has a flag fit_initial_condition, which defaults to False, and I want to train toward the initial condition for up to fit_initial_epochs epochs. Is something like the following pattern a good idea?

from argparse import ArgumentParser
import pytorch_lightning as pl

parser = ArgumentParser()
parser = MyModel.add_model_specific_args(parser)
parser = pl.Trainer.add_argparse_args(parser)
args, unknown = parser.parse_known_args()
dict_args = vars(args)

# Fit to initial condition
if dict_args["fit_initial_epochs"] > 0:
    # Model-level flag for the warm-up phase
    ic_model_params = {"fit_initial_condition": True}
    # Trainer-level overrides: a short run with no sanity check,
    # validation pushed out of reach, no logger, and no checkpoints
    ic_trainer_params = {"max_epochs": dict_args["fit_initial_epochs"],
                         "num_sanity_val_steps": 0,
                         "check_val_every_n_epoch": 1000,
                         "logger": None,
                         "checkpoint_callback": False,
                         }
    ic_dict_args = dict_args.copy()
    ic_dict_args.update(ic_model_params)
    ic_model = MyModel(**ic_dict_args)
    # Trainer rejects unknown kwargs, so build it from the parsed args
    # and apply only the Trainer-level overrides on top.
    ic_trainer = pl.Trainer.from_argparse_args(args, **ic_trainer_params)
    ic_trainer.fit(ic_model)

    # Get the weights from the warm-up model
    ic_weights = ic_model.state_dict()
else:
    ic_weights = None

# Now create the proper model and trainer
model = MyModel(**dict_args)
trainer = pl.Trainer.from_argparse_args(args)

# Load weights if trained with the initial condition.
if ic_weights is not None:
    model.load_state_dict(ic_weights)

trainer.tune(model)
trainer.fit(model)
trainer.test()

The hope is that the warm-up run won't save checkpoints, write to TensorBoard logs, etc. If this is a reasonable way to solve the problem, are there settings that are missing?
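For what it's worth, here is a variant I would consider (a sketch only, assuming Lightning 1.x, where the same LightningModule instance can be passed to two separate Trainers and hparams can be mutated between runs). It reuses one model object across both phases, so there is no state_dict copy, and the real run gets a completely fresh Trainer for logging and checkpointing:

model = MyModel(**dict_args)

if dict_args["fit_initial_epochs"] > 0:
    # Phase 1: quiet warm-up toward the cheap initial condition.
    model.hparams.fit_initial_condition = True
    warmup_trainer = pl.Trainer.from_argparse_args(
        args,
        max_epochs=dict_args["fit_initial_epochs"],
        num_sanity_val_steps=0,
        limit_val_batches=0,        # skip validation instead of deferring it
        logger=False,               # nothing written to TensorBoard
        checkpoint_callback=False,  # no .ckpt files
    )
    warmup_trainer.fit(model)

# Phase 2: the real run, with a fresh Trainer so logging, checkpointing,
# and the epoch counter all start clean.
model.hparams.fit_initial_condition = False
trainer = pl.Trainer.from_argparse_args(args)
trainer.tune(model)
trainer.fit(model)
trainer.test(model)

On the "missing settings" question: limit_val_batches=0 may be a more direct way to silence validation than check_val_every_n_epoch=1000, and progress_bar_refresh_rate=0 would also hide the warm-up progress bar if that counts as clutter.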

If I’m not mistaken, it won't be possible to do both tune and fit in the same script when running on multiple GPUs.
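A possible workaround, just a guess on my part (assuming PL 1.x DDP semantics): run tune on a throwaway single-GPU Trainer so that auto_lr_find / auto_scale_batch_size write their results back onto the model, then build the multi-GPU Trainer only for the fit:

# Sketch only: tune on a single GPU first, then fit on all GPUs.
tune_trainer = pl.Trainer.from_argparse_args(
    args, gpus=1, logger=False, checkpoint_callback=False
)
tune_trainer.tune(model)  # e.g. writes model.lr / model.batch_size in place

# ddp_spawn launches workers from this process instead of re-running
# the whole script, so the tune above only happens once.
fit_trainer = pl.Trainer.from_argparse_args(
    args, gpus=-1, accelerator="ddp_spawn"
)
fit_trainer.fit(model)
fit_trainer.test(model)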

@tchaton, please weigh in here when you get the chance.

Ah, I see. If using a single GPU, is there a chance it would work?