We have a model where we want to train for an epoch or two toward a cheap-to-calculate initial condition, and then use the resulting weights as the starting point for the real training/validation/tuning/etc.
Let's say the original code was something like
from argparse import ArgumentParser
import pytorch_lightning as pl

parser = ArgumentParser()
parser = MyModel.add_model_specific_args(parser)
parser = pl.Trainer.add_argparse_args(parser)
args, unknown = parser.parse_known_args()
dict_args = vars(args)
model = MyModel(**dict_args)
trainer = pl.Trainer.from_argparse_args(args)
trainer.tune(model)
trainer.fit(model)
trainer.test()
What is the best pattern to do this on Grid without cluttering the logs, adding checkpoints, etc., or messing up the Lightning and Grid command-line arguments?
For example, let's say my model has a flag fit_initial_condition, which defaults to False, and I want to train toward the initial condition for up to fit_initial_epochs epochs (a sketch of how the model consumes these flags is at the bottom of this post). Is something like the following pattern a good idea?
parser = ArgumentParser()
parser = MyModel.add_model_specific_args(parser)
parser = pl.Trainer.add_argparse_args(parser)
args, unknown = parser.parse_known_args()
dict_args = vars(args)

# Fit to initial condition
if dict_args["fit_initial_epochs"] > 0:
    # Trainer-only overrides for the warm-up run
    trainer_overrides = {
        "max_epochs": dict_args["fit_initial_epochs"],
        "num_sanity_val_steps": 0,
        "check_val_every_n_epoch": 1000,  # effectively skip validation
        "logger": False,                  # no TensorBoard logs for the warm-up
        "checkpoint_callback": False,     # no checkpoints for the warm-up
    }
    ic_dict_args = dict_args.copy()
    ic_dict_args.update(trainer_overrides, fit_initial_condition=True)
    ic_model = MyModel(**ic_dict_args)
    # Build from the CLI args so the Lightning/Grid flags still apply, then
    # override only the warm-up settings (fit_initial_condition is a model
    # arg, so it must not be passed to the Trainer).
    ic_trainer = pl.Trainer.from_argparse_args(args, **trainer_overrides)
    ic_trainer.fit(ic_model)
    # Get the weights
    ic_weights = ic_model.state_dict()
else:
    ic_weights = None

# Now create the proper model and trainer
model = MyModel(**dict_args)
trainer = pl.Trainer.from_argparse_args(args)
# Load weights if trained with initial condition.
if ic_weights is not None:
    model.load_state_dict(ic_weights)
trainer.tune(model)
trainer.fit(model)
trainer.test()
The hope is that the warm-up run stays quiet while the real run still saves checkpoints, writes TensorBoard logs, etc. If this is a reasonable way to solve the problem, are there settings that are missing?
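For reference, here is a minimal sketch of how the model consumes the flags. The network, loss, and cheap_initial_condition target are placeholders; only the flag handling matters:

from argparse import ArgumentParser
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

def cheap_initial_condition(x):
    # Placeholder for the cheap-to-calculate target; problem-specific.
    return torch.zeros(x.shape[0], 1, device=x.device)

class MyModel(pl.LightningModule):
    def __init__(self, fit_initial_condition=False, fit_initial_epochs=0,
                 learning_rate=1e-3, **kwargs):
        super().__init__()
        self.save_hyperparameters()  # exposes the flags via self.hparams
        self.net = torch.nn.Linear(10, 1)  # placeholder network

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # Swap the target when warming up toward the initial condition.
        target = cheap_initial_condition(x) if self.hparams.fit_initial_condition else y
        return F.mse_loss(y_hat, target)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--fit_initial_condition", action="store_true")
        parser.add_argument("--fit_initial_epochs", type=int, default=0)
        parser.add_argument("--learning_rate", type=float, default=1e-3)
        return parser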