Initialize model with data before training

Background:
For anomaly detection I use a model with a feature memory bank.
Before training, the memory bank needs to be initialized with the mean of all features extracted from the training set. Afterwards, the memory bank can be trained like a “normal” network over several epochs.
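To make the initialization step concrete, here is a minimal sketch in plain PyTorch of the kind of computation I mean. The names `mean_feature`, `extractor`, and `device` are hypothetical and only for illustration; my actual model differs. The sketch assumes the feature extractor already lives on `device` and returns features of shape (batch, feature_dim).

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def mean_feature(extractor: nn.Module, loader: DataLoader, device: torch.device) -> torch.Tensor:
    """Return the mean feature vector over all samples yielded by `loader`.

    Hypothetical helper: assumes `extractor` is already on `device`.
    """
    total, count = None, 0
    for batch in loader:
        # Accept both `x` and `(x, label, ...)` batches (an assumption about the dataset).
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        # Move each batch to the target device by hand before the forward pass.
        feats = extractor(x.to(device))  # shape: (batch, feature_dim)
        batch_sum = feats.sum(dim=0)
        total = batch_sum if total is None else total + batch_sum
        count += feats.shape[0]
    if count == 0:
        raise ValueError("loader yielded no samples")
    return total / count


# Toy usage on the CPU, with an identity "extractor" standing in for the real model.
features = torch.arange(12, dtype=torch.float32).reshape(6, 2)
loader = DataLoader(TensorDataset(features), batch_size=4)
bank = mean_feature(nn.Identity(), loader, torch.device("cpu"))
```

Accumulating per-batch sums and dividing once at the end keeps the result correct even when the last batch is smaller than the others.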

Question:
When using a Trainer, how can I initialize the memory bank before training? My initial thought was to do it in the on_fit_start callback hook, but when I do so I struggle to get the data onto the GPU.

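from pytorch_lightning import Callback

class InitCallback(Callback):
    def on_fit_start(self, trainer, pl_module):
        # Fetch the training dataloader from the attached datamodule.
        data_module = trainer.datamodule
        data_loader = data_module.train_dataloader()
        # Batches from this loader are not moved to the GPU automatically.
        pl_module.init_memory_bank(data_loader)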

Is there a better way or place to do the initialization, and is there a way to manually move the data to the GPU inside the callback?