Hi there, I have finished my codebase using the PyTorch Lightning module. I understand that a LightningModule is divided into steps such as training_step, validation_step, and test_step. My question is how and where to add a pretraining step that runs for a target number of epochs. Here is the relevant piece of my code. Thanks so much in advance!
class GMVAEModel(pl.LightningModule):
    """LightningModule that trains and validates a GMVAENet."""

    def __init__(
        self,
        input_size=104,
        gaussian_size=10,
        num_classes=10,
        lr=0.0001,
        w_cat=None,
        w_gauss=None,
        w_rec=None,
        rec_type="mse",
        processed_zarr_dataset_path=None,
        plot_graph_size=200,
        log_path="",
        k=15,
        use_gmm=True,
        *args,
        **kwargs,
    ):
        """GMVAEModel, inheriting the LightningModule, need to implement
        the training_step, validation_step, test_step and optimizers.

        Args:
            input_size (int): size of input vector, 104 in current settings.
            gaussian_size (int): size of latent space vector.
            num_classes (int): size of the gaussian mixture.
            lr (float): learning rate of the models.
            w_cat (float): categorical loss weight; None falls back to 1.0.
            w_gauss (float): gaussian loss weight; None falls back to 1.0.
            w_rec (float): reconstruction loss weight; None falls back to 1.0.
            rec_type (string): reconstruction loss type, supporting 'mse',
                'bce' currently.
            processed_zarr_dataset_path (string): path of processed zarr
                dataset, for logging the ag graph.
            plot_graph_size (int): size of logging graph in wandb.
            log_path (string): path to save the logging result.
            k (int): k neighbors used when plotting KNN graph.
            use_gmm (boolean): whether use gmm fit the latent vector z.
            *args: accepted for call compatibility; not used here.
            **kwargs: accepted for call compatibility; not used here.
        """
        super().__init__()
        self.network = GMVAENet(
            x_dim=input_size,
            z_dim=gaussian_size,
            y_dim=num_classes,
        )
        # Kept so the categorical loss offset can follow the mixture size.
        self.num_classes = num_classes
        self.lr = lr
        # The weights multiply loss terms in unlabeled_loss; leaving them as
        # None would raise a TypeError there, so None means "unweighted".
        self.w_cat = 1.0 if w_cat is None else w_cat
        self.w_gauss = 1.0 if w_gauss is None else w_gauss
        self.w_rec = 1.0 if w_rec is None else w_rec
        self.rec_type = rec_type
        self.k = k
        self.losses = LossFunctions()
        self.processed_zarr_dataset_path = processed_zarr_dataset_path
        self.plot_graph_size = plot_graph_size
        self.log_path = log_path
        self.use_gmm = use_gmm
        self.gmm = GaussianMixture(n_components=num_classes) if self.use_gmm else None

    def unlabeled_loss(self, data, out_net):
        """Compute the weighted total loss and its components.

        Args:
            data: input batch that was fed to the network.
            out_net (dict): network output; the keys "gaussian", "x_rec",
                "logits", "prob_cat", "y_mean", "y_var", "mean" and "var"
                are read.

        Returns:
            dict: with keys "total", "predicted_clusters", "reconstruction",
                "gaussian", "categorical" and "highest_prob".
        """
        z, data_recon = out_net["gaussian"], out_net["x_rec"]
        logits, prob_cat = out_net["logits"], out_net["prob_cat"]
        y_mu, y_var = out_net["y_mean"], out_net["y_var"]
        mu, var = out_net["mean"], out_net["var"]

        loss_rec = self.losses.reconstruction_loss(data, data_recon, rec_type=self.rec_type)
        loss_gauss = self.losses.gaussian_loss(z, mu, var, y_mu, y_var)
        # The offset used to be the literal log(0.1), which matches
        # log(1 / num_classes) only for the default of 10 classes.
        # Presumably this is the log of a uniform prior over the classes --
        # TODO confirm against LossFunctions.entropy.
        loss_cat = -self.losses.entropy(logits, prob_cat) - np.log(1.0 / self.num_classes)
        loss_total = self.w_rec * loss_rec + self.w_gauss * loss_gauss + self.w_cat * loss_cat

        predicted_clusters = prob_cat.argmax(-1)
        highest_probs = prob_cat.max(-1).values

        # NOTE(review): only the reconstruction term is reported weighted;
        # the gaussian and categorical terms are reported unweighted.
        # Confirm this asymmetry is intended.
        loss_dict = {
            "total": loss_total,
            "predicted_clusters": predicted_clusters,
            "reconstruction": loss_rec * self.w_rec,
            "gaussian": loss_gauss,
            "categorical": loss_cat,
            "highest_prob": highest_probs,
        }
        return loss_dict

    def forward(self):
        """Placeholder; the step methods call self.network directly."""
        pass

    def training_step(self, batch, batch_idx):
        """Run one training batch and log the epoch-level loss terms.

        Args:
            batch (dict): batch whose "feature" entry is fed to the network.
            batch_idx (int): index of the batch; not used here.

        Returns:
            dict: {"loss": total loss} for the optimizer step.
        """
        attributes = batch["feature"]
        out_net = self.network(attributes)
        loss_dict = self.unlabeled_loss(attributes, out_net)
        loss = loss_dict["total"]
        reconstruction_loss = loss_dict["reconstruction"]
        gaussian_loss = loss_dict["gaussian"]
        categorical_loss = loss_dict["categorical"]
        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("train/reconstruction_loss", reconstruction_loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("train/gaussian_loss", gaussian_loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("train/categorical_loss", categorical_loss, on_step=False, on_epoch=True, prog_bar=False)
        return {"loss": loss}

    def test_step(self, batch=None, batch_idx=None):
        """No-op test step.

        The batch arguments are accepted (with defaults, so existing
        zero-argument calls keep working) because the trainer passes them
        when it runs the test loop.
        """
        pass

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and log metrics and figures.

        Args:
            batch (dict): batch whose "feature" and "id" entries are read
                here; the whole batch is also passed to the logging helpers.
            batch_idx (int): index of the batch; not used here.
        """
        attributes = batch["feature"]
        out_net = self.network(attributes)
        prob_cat = out_net["prob_cat"]
        latent = out_net["gaussian"]
        bin_tensor = prob_cat.argmax(-1)
        gd_bin_list, result_bin_list, non_labeled_id_list = summary_bin_list_from_batch(batch, bin_tensor)

        # Compute metrics.
        precision, recall, ARI, F1 = evaluate(
            gd_bin_list=gd_bin_list,
            result_bin_list=result_bin_list,
            non_labeled_id_list=non_labeled_id_list,
            unclassified=0,
        )

        # plotting graph for visualization here.
        contig_id_list = [int(contig_id) for contig_id in batch["id"]]
        plotting_contig_list = contig_id_list[:self.plot_graph_size]
        gd_ag_graph_path, result_ag_graph_path = log_ag_graph(
            plotting_graph_size=self.plot_graph_size,
            processed_zarr_dataset_path=self.processed_zarr_dataset_path,
            plotting_contig_list=plotting_contig_list,
            log_path=self.log_path,
            gd_bin_list=gd_bin_list,
            result_bin_list=result_bin_list,
        )
        gd_knn_graph_path, result_knn_graph_path = log_knn_graph(
            plotting_graph_size=self.plot_graph_size,
            plotting_contig_list=plotting_contig_list,
            k=self.k,
            batch=batch,
            log_path=self.log_path,
            gd_bin_list=gd_bin_list,
            result_bin_list=result_bin_list,
        )

        # add gmm to the latent vector, wrap a function here.
        # NOTE(review): log_gmm is not defined in the visible part of this
        # class; it has to be provided elsewhere for use_gmm=True to work.
        if self.use_gmm:
            self.log_gmm(batch, latent)

        # Visualize latent space.
        result_tsne_figure_path = log_tsne_figure(
            batch=batch,
            latent=latent,
            log_path=self.log_path,
        )

        # NOTE(review): "val/acc" receives the batch size
        # (attributes.shape[0]); confirm whether an accuracy was intended.
        self.log("val/acc", attributes.shape[0], on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/precision", precision, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/recall", recall, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/F1", F1, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/ARI", ARI, on_step=False, on_epoch=True, prog_bar=False)

        # One wandb.log call per image, in the same order as before.
        figure_paths = (
            ("val/ground_truth_ag_subgraph", gd_ag_graph_path),
            ("val/result_ag_subgraph", result_ag_graph_path),
            ("val/ground_truth_knn_subgraph", gd_knn_graph_path),
            ("val/result_knn_subgraph", result_knn_graph_path),
            ("val/tsne_figure", result_tsne_figure_path),
        )
        for key, figure_path in figure_paths:
            wandb.log({key: wandb.Image(figure_path)})

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with rate self.lr."""
        return Adam(self.parameters(), lr=self.lr)