Saving/loading LightningModule with injected network

My problem

I’d like to ask the community for recommendations on how to implement save/load operations for LightningModules that are organized in a particular pattern. The pattern I’m talking about comes from the docs’ recommendation on how to set up models for production:

import pytorch_lightning as pl


class ClassificationTask(pl.LightningModule):

    def __init__(self, model):
        super().__init__()
        # The network is injected rather than built inside the module
        self.model = model

    # The rest of the module defines `__step` computations and optimizers

Here, the LightningModule describes computations rather than a network; the network itself is injected in the module. I really like this approach (that is quite modular and easily configurable IMO), but I’m a bit puzzled on how to make saving/loading modules work with it.

What I tried

I figure that I have to rebuild model myself and inject it when loading ClassificationTask, but I’m hitting a wall when trying to make it work in practice. What I’m doing is something like this:

class ClassificationTask(pl.LightningModule):

    def __init__(self, model):
        super().__init__()
        self.model = model

    @classmethod
    def load_from_checkpoint(cls, ...):
        # Rebuild `model` from the configuration stored in the checkpoint
        model = ...

        # Here is the tricky part
        return super().load_from_checkpoint(..., model=model)

Issues with my solution so far

When calling super().load_from_checkpoint(...), I thought I could just inject model there and be done with it; it would be forwarded to the class’ __init__ and all would be well. However, digging a little deeper in the code, I came across this bit in the load_from_checkpoint base implementation in LightningModule:

# for past checkpoint need to add the new key
if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
    checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
# override the hparams with values that were passed in
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)

model = cls._load_model_state(checkpoint, strict=strict, **kwargs)

In my case, kwargs would be {'model': model}. The _load_model_state call means it will indeed get forwarded to the class __init__ eventually, however it’s also getting added to checkpoint’s hyperparameters, which I would want to avoid polluting with the injected model.
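To make the mechanism concrete, here is a minimal pure-Python sketch that imitates those lines without Lightning. `FakeTask`, `DummyNetwork`, `fake_load_from_checkpoint`, the `"hyper_parameters"` key string, and the `lr` hyperparameter are all stand-ins I made up for illustration, and the last step is only a simplified guess at what `_load_model_state` does, not its real implementation.

```python
from typing import Any, Dict

# Assumed stand-in for cls.CHECKPOINT_HYPER_PARAMS_KEY
HYPER_PARAMS_KEY = "hyper_parameters"


class DummyNetwork:
    """Hypothetical injected network (a real one would be an nn.Module)."""


class FakeTask:
    """Hypothetical stand-in for the LightningModule; only the constructor matters here."""

    def __init__(self, model: Any, lr: float = 1e-3) -> None:
        self.model = model
        self.lr = lr


def fake_load_from_checkpoint(checkpoint: Dict[str, Any], **kwargs: Any) -> FakeTask:
    # Same two steps as the quoted snippet: make sure the key exists...
    if HYPER_PARAMS_KEY not in checkpoint:
        checkpoint[HYPER_PARAMS_KEY] = {}
    # ...then merge every keyword argument into the stored hyperparameters.
    checkpoint[HYPER_PARAMS_KEY].update(kwargs)
    # Simplified stand-in for `_load_model_state`: build the module from the hparams.
    return FakeTask(**checkpoint[HYPER_PARAMS_KEY])


checkpoint = {HYPER_PARAMS_KEY: {"lr": 0.01}}
network = DummyNetwork()
task = fake_load_from_checkpoint(checkpoint, model=network)

print(task.model is network)                 # the injection itself works
print(sorted(checkpoint[HYPER_PARAMS_KEY]))  # but "model" now sits next to "lr"
```

In this toy version the network reaches `__init__` as intended, but only because it was first written into the checkpoint’s hyperparameter dict, which is exactly the side effect I’d like to avoid.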

Thus follows my question: would you have any recommendation on how to modify my setup so that I can inject the network inside the LightningModule upon loading it, without having the network be added to the checkpoint’s hyperparams?

Additional context

You might be wondering why I’m so insistent about not having model be added to the checkpoint’s hyperparams. Well, I’m trying to set up a project seed merging Lightning and Hydra. Thus, my hyperparams are a typed structured config that straight up refuses to receive a custom class. The error message I receive is something like:

omegaconf.errors.UnsupportedValueType: Value 'model.__class__' is not a supported primitive type

Furthermore, as a general principle, I think it’s best to save/load strict hyperparams, to ease reproducibility and portability. I’ve had previous bad experiences with internal states that varied a little depending on whether they were built from scratch or loaded from a checkpoint.

Thanks in advance for your help/recommendations!

EDIT: I’d have added more links to Hydra’s documentation to give better context, but as a new user of the forum I’m unfortunately limited to 2 links by post :cry: