Monitored metric from ModelCheckpoint does not match metric from validation_epoch_end

Hello! With the code below, I get different results for "val_accuracy" depending on where I read it: the metric monitored by the ModelCheckpoint callback seems to use some reduction other than the simple average computed in validation_epoch_end.

from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from transformers import AutoModel, AutoTokenizer

# BaseMetrics is a project-local metrics helper; its import is omitted here

class BIGRU(LightningModule):

    def __init__(self, transformer, rnn_hiddenSize, numClasses, lr=None, weight_decay=None):
        super().__init__()

        self.lr = lr if lr is not None else 0.01
        self.weight_decay = weight_decay if weight_decay is not None else 0.1
        self.loss_func = torch.nn.CrossEntropyLoss()
        #self.torchDevice = device
        print("Running on", self.device)

        # metrics
        self.metricsMethods = BaseMetrics(["text", "signature"], [])
        self.modelMetrics = {'accuracy': self.metricsMethods.getAccuracy, 'f1score': self.metricsMethods.get_f1Score}

        # define encoder transformer model
        self.encoderTokenizer = AutoTokenizer.from_pretrained(transformer)
        self.encoderModel = AutoModel.from_pretrained(transformer)

        # define rnn
        encoder_embedding_size = self.encoderModel.config.hidden_size
        self.rnn = nn.GRU(encoder_embedding_size, rnn_hiddenSize, batch_first=True, bidirectional=True)

        # define fully connected layers
        self.classify = torch.nn.Linear(rnn_hiddenSize*2, numClasses)
        torch.nn.init.xavier_uniform_(self.classify.weight)

    def forward(self, sentences):

        # embedding
        encodedSentences = self.encoderTokenizer(sentences, truncation=True, padding=True, return_tensors="pt")
        encodedSentences = {key:value.to(self.device) for (key,value) in encodedSentences.items()} # convert tensors to device
        with torch.no_grad():
            encoder_output = self.encoderModel(**encodedSentences)
        sentences_embeddings = self.encoderMeanPooling(encoder_output, encodedSentences['attention_mask'])

        # rnn
        sentences_embeddings = sentences_embeddings.unsqueeze(dim=0)
        outputRnn, _ = self.rnn(sentences_embeddings)

        # classify
        outputRnn = outputRnn.view(len(sentences), -1) # probably does not work with batch size > 1; flatten before passing to the loss function
        classification = self.classify(outputRnn)
        return classification
    
    #Mean Pooling - Take attention mask into account for correct averaging
    def encoderMeanPooling(self, model_output, attention_mask):
        token_embeddings = model_output[0] #First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def configure_optimizers(self):
        optim = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return optim

    
    def training_step(self, batch, batch_idx):
        x, y = self._preprocessBatch(batch)
        preds = self(x)
        loss = self.loss_func(preds, y)
        
        metrics = self._calculateMetrics(preds, y, phase="train")
        metrics["loss"] = loss
        self.log_dict(metrics, on_step=True, on_epoch=True)

        return metrics

    def validation_step(self, batch, batch_idx):
        x, y = self._preprocessBatch(batch)
        preds = self(x)
        loss = self.loss_func(preds, y)
        
        metrics = self._calculateMetrics(preds, y, phase="val")
        metrics["val_loss"] = loss
        self.log_dict(metrics)  

        return metrics

    def validation_epoch_end(self, outputs):
        metricsAvg = {}
        for key in outputs[0].keys():
            metricsAvg[key] = torch.stack([x[key].float() for x in outputs]).mean()

        print(metricsAvg)
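        # NOTE: as far as I understand, returning this dict does not log it;
        # only values passed through self.log / self.log_dict reach
        # trainer.callback_metrics, so the checkpoint never sees this simple mean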
        return metricsAvg

    def _preprocessBatch(self, batch):
        x,y = batch

        # fix DataLoader items (the default collate turns a list of strings into a list of tuples)
        x = [item[0] for item in x]
        y = torch.tensor(y)

        return x, y

    def _calculateMetrics(self, predictions, labels, phase):
        metrics = {}

        bestProbabilities, bestPredictions  = self._getBestPredictions(predictions)
        for metricName, functionToCall in self.modelMetrics.items():
            metricValue = torch.tensor(functionToCall(bestPredictions, labels.cpu()))
            metrics[phase + '_' + metricName] = metricValue
        
        return metrics

    def _getBestPredictions(self, predictions: torch.Tensor) -> Tuple[np.ndarray,np.ndarray]:
        probabilities = torch.softmax(predictions, 1)
        bestProbabilities, bestPredictions = torch.max(probabilities, 1)
        bestProbabilities = bestProbabilities.detach().cpu().numpy()
        bestPredictions = bestPredictions.detach().cpu().numpy()

        return bestProbabilities, bestPredictions

wandb_logger = WandbLogger(project="emailPreprocessing", log_model=False)
wandb_logger.watch(model, log='all', log_freq=100)
#wandb_logger.log_hyperparams(hparams)

# the monitored metric differs from the one in validation_epoch_end; maybe PyTorch Lightning applies a batch-size-weighted average (as it does for the loss)
checkpoint_callback = ModelCheckpoint(monitor='val_accuracy', mode="max", verbose=True)
early_stop_callback = EarlyStopping(monitor='val_loss', mode="min", patience=5, verbose=True)
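# NOTE (my assumption): both callbacks read their monitored keys from
# trainer.callback_metrics, i.e. the epoch-reduced values logged in
# validation_step, not the dict returned by validation_epoch_end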

gpus = 0
epochs = 3
grad_accumulation_batches = 16

trainer = Trainer(gpus=gpus,
            max_epochs=epochs, 
            accumulate_grad_batches=grad_accumulation_batches,
            # gradient_clip_val=1.0,
            # amp_backend='apex',
            # amp_level='O2',
            precision=32,
            logger=wandb_logger,
            checkpoint_callback=True,
            callbacks=[early_stop_callback, checkpoint_callback])

trainer.fit(model, trainDataloader, testDataloader)
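
For reference, the behaviour I expected corresponds to logging the simple mean myself so that ModelCheckpoint monitors exactly that value. A minimal sketch of what I have in mind (untested; it assumes self.log may be called inside validation_epoch_end in my Lightning version), replacing the two validation methods above:

    def validation_step(self, batch, batch_idx):
        x, y = self._preprocessBatch(batch)
        preds = self(x)
        metrics = self._calculateMetrics(preds, y, phase="val")
        metrics["val_loss"] = self.loss_func(preds, y)
        return metrics  # no per-step logging here

    def validation_epoch_end(self, outputs):
        # log the simple (unweighted) mean over batches, so the 'val_accuracy'
        # seen by ModelCheckpoint equals the value printed in my version
        for key in outputs[0].keys():
            self.log(key, torch.stack([o[key].float() for o in outputs]).mean())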

RESULTS

validation_epoch_end print: {'val_accuracy': tensor(0.8599), 'val_f1score': tensor(0.8599), 'val_loss': tensor(0.3385)}

ModelCheckpoint verbose: Epoch 1, global step 9: val_accuracy reached 0.96450 (best 0.96450), saving model to "c:\wandb\run-20210621_183543-13eb6ofv\files\emailPreprocessing\13eb6ofv\checkpoints\epoch=1-step=9.ckpt" as top 1

In the epoch shown above the difference is huge; most of the time the two values are similar, but they are never equal. Also, when I log these results to WandB, the logged values match the ModelCheckpoint verbose output exactly.

As discussed here, I imagine that some weighted average over my batch samples is being applied; but since my batches contain variable-length inputs, this weighting does not help me select the actual best model.
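
To illustrate that suspicion with made-up numbers (not taken from the run above): with one full batch and one small leftover batch, a batch-size-weighted mean diverges noticeably from the simple mean:

import torch

accs  = torch.tensor([0.97, 0.50])   # hypothetical per-batch accuracies
sizes = torch.tensor([30.0, 2.0])    # batch sizes; the last batch is smaller

simple   = accs.mean()                          # 0.7350, as in validation_epoch_end
weighted = (accs * sizes).sum() / sizes.sum()   # 0.9406, a weighted reduction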

Any clarification regarding this result difference is appreciated! Thanks!