How to add a pretrain step in training_step

Hi there, I have finished my codebase using PyTorch Lightning. I understand that a LightningModule is organized into steps such as training_step, validation_step, and test_step. My question is how and where to add a pretraining step that runs for a target number of epochs. Here is my code; thanks so much in advance!

```python
import numpy as np
import pytorch_lightning as pl
import wandb
from sklearn.mixture import GaussianMixture
from torch.optim import Adam

# GMVAENet, LossFunctions, evaluate, summary_bin_list_from_batch, log_ag_graph,
# log_knn_graph, and log_tsne_figure are project-local helpers from my codebase.


class GMVAEModel(pl.LightningModule):
    def __init__(
        self,
        input_size=104,
        gaussian_size=10,
        num_classes=10,
        lr=0.0001,
        w_cat=None,
        w_gauss=None,
        w_rec=None,
        rec_type="mse",
        processed_zarr_dataset_path=None,
        plot_graph_size=200,
        log_path="",
        k=15,
        use_gmm=True,
        *args,
        **kwargs,
    ):
        """GMVAEModel, inheriting the LightningModule, need to implement
        the training_step, validation_step, test_step and optimizers.
        
        Args:
            input_size (int): size of input vector, 104 in current settings.
            gaussian_size (int): size of latent space vector.
            num_classes (int): size of the gaussian mixture.
            lr (float): learning rate of the models.
            k (int): k neighbors used when plotting KNN graph.
            w_cat (float): categorical loss weights.
            w_gaussian (float): gaussian loss weights.
            w_rec (float): reconstruction loss weights.
            rec_type (string): reconstruction loss type, supporting 'mse', 'bce' currently.
            processed_zarr_dataset_path (string): path of processed zarr dataset,
                for logging the ag graph.
            plot_graph_size (int): size of logging graph in wandb.
            log_path (string): path to save the logging result.
            use_gmm (boolean): whether use gmm fit the latent vector z.

        Attrs:
        """
        super().__init__()
        self.network = GMVAENet(
            x_dim=input_size,
            z_dim=gaussian_size,
            y_dim=num_classes,
        )
        self.num_classes = num_classes
        self.lr = lr
        self.w_cat = w_cat
        self.w_gauss = w_gauss
        self.w_rec = w_rec
        self.rec_type = rec_type
        self.k = k
        self.losses = LossFunctions()
        self.processed_zarr_dataset_path = processed_zarr_dataset_path
        self.plot_graph_size = plot_graph_size
        self.log_path = log_path
        self.use_gmm = use_gmm
        self.gmm = GaussianMixture(n_components=num_classes) if self.use_gmm else None
        
    def unlabeled_loss(self, data, out_net):
        z, data_recon = out_net["gaussian"], out_net["x_rec"]
        logits, prob_cat = out_net["logits"], out_net["prob_cat"]
        y_mu, y_var = out_net["y_mean"], out_net["y_var"]
        mu, var = out_net["mean"], out_net["var"]

        loss_rec = self.losses.reconstruction_loss(data, data_recon, rec_type=self.rec_type)
        loss_gauss = self.losses.gaussian_loss(z, mu, var, y_mu, y_var)
        # The -log(1/K) term assumes a uniform categorical prior over the
        # num_classes mixture components.
        loss_cat = -self.losses.entropy(logits, prob_cat) - np.log(1.0 / self.num_classes)
        loss_total = self.w_rec * loss_rec + self.w_gauss * loss_gauss + self.w_cat * loss_cat

        predicted_clusters = prob_cat.argmax(-1)
        highest_probs = prob_cat.max(-1).values
        loss_dict = {
            "total": loss_total,
            "predicted_clusters": predicted_clusters,
            "reconstruction": loss_rec * self.w_rec,
            "gaussian": loss_gauss,
            "categorical": loss_cat,
            "highest_prob": highest_probs,
        }
        return loss_dict

    def forward(self, x):
        # Delegate to the underlying network so the module is directly callable.
        return self.network(x)
        
    def training_step(self, batch, batch_idx):        
        attributes = batch["feature"]
        out_net = self.network(attributes)
        loss_dict = self.unlabeled_loss(attributes, out_net)

        loss = loss_dict["total"]
        reconstruction_loss = loss_dict["reconstruction"]
        gaussian_loss = loss_dict["gaussian"]
        categorical_loss = loss_dict["categorical"]
        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("train/reconstruction_loss", reconstruction_loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("train/gaussian_loss", gaussian_loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("train/categorical_loss", categorical_loss, on_step=False, on_epoch=True, prog_bar=False)
        return {"loss": loss}
    
    def test_step(self, batch, batch_idx):
        # Not implemented yet; Lightning passes (batch, batch_idx) here.
        pass
    
    def validation_step(self, batch, batch_idx):
        attributes = batch["feature"]
        out_net = self.network(attributes)
        prob_cat = out_net["prob_cat"]
        latent = out_net["gaussian"]
        bin_tensor = prob_cat.argmax(-1)
        gd_bin_list, result_bin_list, non_labeled_id_list = summary_bin_list_from_batch(batch, bin_tensor)

        # Compute metrics.
        precision, recall, ARI, F1 = evaluate(
            gd_bin_list=gd_bin_list,
            result_bin_list=result_bin_list,
            non_labeled_id_list=non_labeled_id_list,
            unclassified=0)

        # Plot graphs for visualization.
        contig_id_list = [int(contig_id) for contig_id in batch["id"]]
        plotting_contig_list = contig_id_list[:self.plot_graph_size]
        gd_ag_graph_path, result_ag_graph_path = log_ag_graph(
            plotting_graph_size=self.plot_graph_size,
            processed_zarr_dataset_path=self.processed_zarr_dataset_path,
            plotting_contig_list=plotting_contig_list,
            log_path=self.log_path,
            gd_bin_list=gd_bin_list,
            result_bin_list=result_bin_list,
        )
        gd_knn_graph_path, result_knn_graph_path = log_knn_graph(
            plotting_graph_size=self.plot_graph_size,
            plotting_contig_list=plotting_contig_list,
            k=self.k,
            batch=batch,
            log_path=self.log_path,
            gd_bin_list=gd_bin_list,
            result_bin_list=result_bin_list,
        )

        # Fit a GMM to the latent vectors; wrapped in log_gmm.
        if self.use_gmm:
            self.log_gmm(batch, latent)
        
        # Visualize latent space.
        result_tsne_figure_path = log_tsne_figure(
            batch=batch,
            latent=latent,
            log_path=self.log_path,
        )

        self.log("val/acc", attributes.shape[0], on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/precision", precision, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/recall", recall, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/F1", F1, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/ARI", ARI, on_step=False, on_epoch=True, prog_bar=False)
        wandb.log({"val/ground_truth_ag_subgraph": wandb.Image(gd_ag_graph_path)})
        wandb.log({"val/result_ag_subgraph": wandb.Image(result_ag_graph_path)})
        wandb.log({"val/ground_truth_knn_subgraph": wandb.Image(gd_knn_graph_path)})
        wandb.log({"val/result_knn_subgraph": wandb.Image(result_knn_graph_path)})
        wandb.log({"val/tsne_figure": wandb.Image(result_tsne_figure_path)})

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.lr)
```

Not sure what exactly you wanna do, but have a look at the available hooks here:

https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/core/hooks.html

There is also something like on_pretrain_routine_start.
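If the goal is a pretraining phase that lasts for a target number of epochs, you can also keep everything inside training_step and branch on self.current_epoch, which the Trainer updates for you. Below is a minimal sketch of that pattern; the pretrain_epochs argument, the toy encoder/decoder, and the extra_terms helper are illustrative stand-ins, not part of the GMVAE code above:

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl


class TwoPhaseModel(pl.LightningModule):
    """Reconstruction-only pretraining for the first `pretrain_epochs`
    epochs, then the full objective. All names here are illustrative."""

    def __init__(self, input_size=104, latent_size=10, pretrain_epochs=10, lr=1e-4):
        super().__init__()
        self.encoder = nn.Linear(input_size, latent_size)
        self.decoder = nn.Linear(latent_size, input_size)
        self.pretrain_epochs = pretrain_epochs
        self.lr = lr

    def training_step(self, batch, batch_idx):
        x = batch["feature"]
        x_rec = self.decoder(self.encoder(x))
        rec_loss = nn.functional.mse_loss(x_rec, x)

        # self.current_epoch is maintained by the Trainer, so switching
        # phases is a plain epoch check.
        if self.current_epoch < self.pretrain_epochs:
            loss = rec_loss                               # pretraining phase
        else:
            loss = rec_loss + self.extra_terms(x_rec, x)  # full objective

        self.log("train/loss", loss, on_step=False, on_epoch=True)
        return loss

    def extra_terms(self, x_rec, x):
        # Stand-in for the additional GMVAE terms (gaussian + categorical).
        return 0.1 * (x_rec - x).abs().mean()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
```

Another option is two separate Trainer.fit calls (a short pretraining run followed by the main run on the same module). Note that on_pretrain_routine_start fires only once, before training begins, so it suits one-off setup better than a multi-epoch phase.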

Thanks so much! That helps me a lot.