Am I validating and testing my data incorrectly?

I have train, validation, and test DataLoader objects.

I wrote this PyTorch Lightning module:

class GraphLevelGNN(pl.LightningModule):
    """Graph-level classification module wrapping ``GraphGNNModel``.

    ``model_kwargs`` must contain at least ``c_out``, ``optimizer_name`` and
    ``learning_rate``. All kwargs are also forwarded to ``GraphGNNModel``.
    ``c_out == 1`` selects ``BCEWithLogitsLoss`` (one logit per graph);
    anything else selects ``CrossEntropyLoss``.

    Precision, recall and F1 for the "val" and "test" stages are held in
    torchmetrics objects. These accumulate counts over every batch and are
    computed once per epoch when logged. Logging plain per-batch F1/precision/
    recall tensors with ``on_epoch=True`` only averages the per-batch scores.
    That average is not the dataset-level score and can be badly skewed by
    small or single-class batches.
    """

    # Stages whose precision/recall/F1 are accumulated across batches.
    _METRIC_STAGES = ("val", "test")

    def __init__(self, **model_kwargs):
        super().__init__()
        # Populates self.hparams from model_kwargs; the loss selection below
        # and forward() both read self.hparams.c_out.
        self.save_hyperparameters()
        self.model = GraphGNNModel(**model_kwargs)
        self.loss_module = (
            nn.BCEWithLogitsLoss() if self.hparams.c_out == 1 else nn.CrossEntropyLoss()
        )
        self.optimizer_name = model_kwargs['optimizer_name']
        self.learning_rate = model_kwargs['learning_rate']

        # One set of metric objects per stage, so validation and test counts
        # never mix. Assigning them as attributes registers them as
        # submodules, so they follow the module's device; no .to(device) needed.
        # NOTE: these are binary metrics. For more than two classes, use the
        # multiclass variants with num_classes=c_out.
        for stage in self._METRIC_STAGES:
            setattr(self, f"{stage}_f1", BinaryF1Score())
            setattr(self, f"{stage}_precision", BinaryPrecision())
            setattr(self, f"{stage}_recall", BinaryRecall())

    def forward(self, data, mode="train"):
        """Run the GNN on a batch and compute loss, accuracy and predictions.

        Args:
            data: batch exposing ``x``, ``edge_index``, ``batch`` and ``y``.
                ``y`` is presumably one label per graph; verify against the loader.
            mode: "train", "val" or "test". For "val"/"test", this batch is
                added to that stage's precision/recall/F1 metric objects.

        Returns:
            ``(loss, acc, f1, precision, recall, preds)``. ``acc`` is the
            accuracy of this batch. ``f1``/``precision``/``recall`` are the
            stage's accumulating metric objects, or ``None`` for other modes.
            Log those objects directly to get dataset-level values.
        """
        x, edge_index, batch_idx = data.x, data.edge_index, data.batch
        logits = self.model(x, edge_index, batch_idx)
        logits = logits.squeeze(dim=-1)

        if self.hparams.c_out == 1:
            # Single logit: BCEWithLogitsLoss needs float targets. A logit > 0
            # is the same decision as sigmoid(logit) > 0.5.
            target = data.y.float()
            preds = (logits > 0).long()
        else:
            # Several logits: CrossEntropyLoss needs integer class indices.
            target = data.y.long()
            preds = logits.argmax(dim=-1)

        loss = self.loss_module(logits, target)

        # Integer labels for comparison/metrics; data.y itself is left untouched.
        labels = data.y.long()
        acc = (preds == labels).float().mean()

        if mode not in self._METRIC_STAGES:
            return loss, acc, None, None, None, preds

        f1 = getattr(self, f"{mode}_f1")
        precision = getattr(self, f"{mode}_precision")
        recall = getattr(self, f"{mode}_recall")
        for metric in (f1, precision, recall):
            # update() only accumulates counts; the score over all batches is
            # computed when the metric object is logged at epoch end.
            metric.update(preds, labels)
        return loss, acc, f1, precision, recall, preds

    def configure_optimizers(self):
        # NOTE(review): self.optimizer_name is stored but not used here; SGD is
        # always used. If optimizer_name is part of a tuning search space, it
        # currently has no effect.
        learning_rate = self.learning_rate
        optimizer = optim.SGD(self.parameters(), lr=learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        loss, acc, _, _, _, preds = self.forward(batch, mode="train")
        # Use the real number of graphs in this batch, so the epoch mean is
        # weighted correctly even when a batch is smaller than the loader's
        # nominal batch size.
        n = preds.size(0)
        self.log('train_loss', loss, on_epoch=True, logger=True, batch_size=n)
        self.log('train_acc', acc, on_epoch=True, logger=True, batch_size=n)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc, f1, precision, recall, preds = self.forward(batch, mode="val")
        n = preds.size(0)
        self.log('val_acc', acc, on_epoch=True, logger=True, batch_size=n)
        self.log('val_loss', loss, on_epoch=True, logger=True, batch_size=n)
        # Metric objects: Lightning computes and resets them at epoch end.
        self.log('val_f1', f1, on_step=False, on_epoch=True, logger=True)
        self.log('val_precision', precision, on_step=False, on_epoch=True, logger=True)
        self.log('val_recall', recall, on_step=False, on_epoch=True, logger=True)

    def test_step(self, batch, batch_idx):
        loss, acc, f1, precision, recall, preds = self.forward(batch, mode="test")
        n = preds.size(0)
        self.log('test_loss', loss, on_epoch=True, logger=True, batch_size=n)
        self.log('test_acc', acc, on_epoch=True, logger=True, batch_size=n)
        # Metric objects: Lightning computes and resets them at epoch end.
        self.log('test_f1', f1, on_step=False, on_epoch=True, logger=True)
        self.log('test_precision', precision, on_step=False, on_epoch=True, logger=True)
        self.log('test_recall', recall, on_step=False, on_epoch=True, logger=True)

To calculate the validation and test metrics, I wrote this function:

def evaluate_model(model, graph_test_loader, graph_val_loader, output_file='/home/output.txt'):
    """Evaluate ``model`` on the validation and test sets and append the metrics to a file.

    Intended for the best model returned by Ray Tune. Unlike
    ``parse_logger_file()``, this does not work per epoch, and it also covers
    the test set.

    Both loaders go through ``trainer.test``, so both result dicts carry the
    keys logged by the model's ``test_step``. The section headers written to
    the file say which set each block belongs to.

    Args:
        model: the trained LightningModule to evaluate.
        graph_test_loader: loader for the test set.
        graph_val_loader: loader for the validation set.
        output_file: path of the text file the results are appended to.

    Returns:
        ``(validation_result, test_result)``: the metric dicts written to the file.
    """
    model.eval()  # switch to evaluation behaviour, e.g. no dropout, frozen batch-norm stats
    trainer = pl.Trainer()  # (accelerator='gpu', devices=-1)
    # trainer.test returns a list with one dict per dataloader; one loader is passed.
    test_result = trainer.test(model, graph_test_loader, verbose=False)[0]
    validation_result = trainer.test(model, graph_val_loader, verbose=False)[0]

    # The context manager guarantees the file is flushed and closed, even if a
    # write fails. The original handle was never closed.
    with open(output_file, 'a') as handle:
        _write_metrics_section(handle, 'validation results', validation_result)
        handle.write('******' + '\n')
        _write_metrics_section(handle, 'test results', test_result)

    return validation_result, test_result


def _write_metrics_section(handle, title, results):
    """Write ``title`` and then one ``name<TAB>value`` line per entry of ``results``."""
    handle.write(title + '\n')
    # !s keeps the original str(value) formatting exactly.
    handle.writelines(f"{name}\t{value!s}\n" for name, value in results.items())


The output shows high metrics for my validation set (e.g. precision and recall > 0.89) but very low metrics for my test set (< 0.5, across several different datasets).

I'm wondering whether I have implemented the calculation of the validation and test metrics incorrectly. Could someone show me how to change the evaluate_model function so that it correctly returns overall precision, recall, and accuracy for the validation and test sets?