Regularization-based Continual Learning

I’m new to PyTorch Lightning. I’m planning to experiment with different methods for continual learning (CL), i.e., training the same model continually on a sequence of tasks. CL methods roughly fall into two classes:

Memory-based methods keep a subset of data from old tasks in a memory and replay those data points when training on new tasks. These seem straightforward to implement in Lightning by wrapping the Trainer.fit call in some code that handles the subsampling and feeds the combined data into it.

Regularization-based methods (the prime example being elastic weight consolidation) add regularizers on the model parameters that encourage the model to remember old tasks. These regularizers need to be updated after each new task has been processed. Since Lightning defines the loss in the LightningModule.training_step method, I’m having a hard time figuring out how best to handle a regularizer that has to change dynamically between tasks. Does anyone have any ideas?
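To make the memory-based case concrete, this is roughly the wrapper I have in mind. It is only a minimal sketch: ReplayMemory is a made-up helper, reservoir sampling is just one assumed way to do the subsampling, the task data is dummy data, and the actual Trainer.fit call is left as a comment.

```python
import random


class ReplayMemory:
    """Hypothetical fixed-size memory, filled by reservoir sampling."""

    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self.items = []   # examples currently kept for replay
        self.seen = 0     # total number of examples offered so far
        self.rng = random.Random(seed)

    def add(self, example):
        self.seen += 1
        if len(self.items) < self.capacity:
            self.items.append(example)
        else:
            # Replace a stored example with probability capacity / seen,
            # so every example seen so far is equally likely to be kept.
            j = self.rng.randrange(self.seen)
            if j < self.capacity:
                self.items[j] = example

    def extend(self, examples):
        for example in examples:
            self.add(example)


memory = ReplayMemory(capacity=100)

# Dummy stand-ins for two tasks' datasets.
tasks = [
    [("task0", i) for i in range(500)],
    [("task1", i) for i in range(500)],
]

for task_data in tasks:
    # New task data plus whatever has been kept from earlier tasks.
    train_data = task_data + memory.items
    # trainer.fit(model, DataLoader(train_data, ...)) would go here.
    memory.extend(task_data)
```

The training loop itself never needs to know about the memory; only the data passed to each fit call changes.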
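For the regularization-based case, the best I have come up with so far is to keep the regularizer's state in a separate object that training_step only reads and that gets refreshed after each task. The sketch below uses plain PyTorch so it runs without Lightning. EWCPenalty and its methods are names I made up, the squared gradient of the batch loss is only a crude stand-in for the diagonal Fisher information, and summing the estimates across tasks is one assumed choice among several.

```python
import torch
from torch import nn


class EWCPenalty:
    """Hypothetical helper holding EWC anchor weights and importance estimates."""

    def __init__(self):
        self.anchors = {}  # parameter name -> weights saved after the last task
        self.fisher = {}   # parameter name -> accumulated importance estimate

    def update(self, model, batches, loss_fn):
        """Call once after a task is finished, with a few (x, y) batches from it."""
        params = {n: p for n, p in model.named_parameters() if p.requires_grad}
        sq_grads = {n: torch.zeros_like(p) for n, p in params.items()}
        n_batches = 0
        for x, y in batches:
            model.zero_grad()
            loss_fn(model(x), y).backward()
            for n, p in params.items():
                if p.grad is not None:
                    sq_grads[n] += p.grad.detach() ** 2
            n_batches += 1
        model.zero_grad()
        for n, p in params.items():
            estimate = sq_grads[n] / max(n_batches, 1)
            # Assumption: accumulate importance over tasks by summing.
            self.fisher[n] = self.fisher.get(n, torch.zeros_like(p)) + estimate
            self.anchors[n] = p.detach().clone()

    def __call__(self, model):
        """Quadratic penalty around the anchors; zero before the first update."""
        penalty = next(model.parameters()).new_zeros(())
        for n, p in model.named_parameters():
            if n in self.anchors:
                penalty = penalty + (self.fisher[n] * (p - self.anchors[n]) ** 2).sum()
        return penalty


torch.manual_seed(0)
model = nn.Linear(4, 2)
ewc = EWCPenalty()

# Pretend a task has just finished: refresh the regularizer from its data.
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
ewc.update(model, [(x, y)], nn.functional.cross_entropy)

# Inside LightningModule.training_step the loss would then be roughly:
#   loss = task_loss + ewc_lambda * self.ewc(self)
# where ewc_lambda is a hyperparameter I would have to tune.
```

With this layout, training_step stays the same for every task, and the dynamic part reduces to calling ewc.update(...) after each Trainer.fit call returns. I don't know whether this is the idiomatic way to do it in Lightning.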