RAM Crash after validation

After training and validation for the 1st epoch, my Colab runtime crashes with a RAM-full notification. I've been searching for ways to reduce memory usage, but nothing has worked. Here's my code for the LightningModule:

import torch
import torch.nn as nn
import pytorch_lightning as pl
import timm
from torch.optim import AdamW
from transformers import BertModel

class Bert(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained("/content/bert-base-uncased-hate")
        # Freeze BERT: .trainable is a Keras attribute and does nothing in PyTorch
        for param in self.bert.parameters():
            param.requires_grad = False
        self.vit = timm.create_model("vit_base_patch16_224_in21k", pretrained=True, num_classes=0)
        self.bFc1 = nn.Bilinear(768, 768, 512)
        self.classifier = nn.Sequential(nn.BatchNorm1d(512),
                                        nn.Linear(512, 1))

    def forward(self, input_ids, images):
        img = self.vit(images)                          # (B, 768) pooled ViT features
        text = self.bert(input_ids=input_ids)
        fused = self.bFc1(img, text["pooler_output"])   # bilinear fusion -> (B, 512)
        return self.classifier(fused)                   # logits, shape (B, 1)

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        images = batch["images"]
        labels = batch["labels"]
        outputs = self(input_ids=input_ids, images=images)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, labels)
        # Outputs are logits, so threshold at 0 (equivalent to sigmoid > 0.5)
        acc = ((outputs > 0).int() == labels).sum() / labels.size(0)
        self.log("train_loss", loss)
        self.log("train_accuracy", acc)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        images = batch["images"]
        labels = batch["labels"]
        outputs = self(input_ids=input_ids, images=images)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, labels)
        acc = ((outputs > 0).int() == labels).sum() / labels.size(0)

        self.log("val_loss", loss)
        self.log("val_accuracy", acc)

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=1e-5)

model = Bert()
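
For reference, my loaders feed batches as dicts of tensors. This is a simplified sketch of the format (the MemeDataset name and the dummy tensors are placeholders for my real preprocessing, which loads everything from disk):

import torch
from torch.utils.data import DataLoader, Dataset

class MemeDataset(Dataset):  # placeholder name for my actual dataset class
    def __init__(self, input_ids, images, labels):
        self.input_ids = input_ids  # LongTensor (N, seq_len) from the BERT tokenizer
        self.images = images        # FloatTensor (N, 3, 224, 224) for the ViT
        self.labels = labels        # FloatTensor (N, 1), binary targets

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {"input_ids": self.input_ids[idx],
                "images": self.images[idx],
                "labels": self.labels[idx]}

# Dummy tensors just to illustrate the shapes
data = MemeDataset(torch.randint(0, 30522, (100, 128)),
                   torch.rand(100, 3, 224, 224),
                   torch.rand(100, 1).round())
train_loader = DataLoader(data, batch_size=8, shuffle=True, num_workers=2)
val_loader = DataLoader(data, batch_size=8, num_workers=2)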

And here’s my trainer:
trainer = pl.Trainer(gpus=1, max_epochs=10, accumulate_grad_batches=4, num_sanity_val_steps=10)
trainer.fit(model, train_loader, val_loader)
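
To see where the memory goes, I've been printing host RAM around training with a small helper, and trying the usual cleanup calls between runs, with no luck (a sketch; psutil should be preinstalled on Colab):

import gc
import psutil
import torch

def print_ram(tag=""):
    # The Colab notification is about system RAM, so watch host memory, not GPU memory
    mem = psutil.virtual_memory()
    print(f"{tag}: {mem.used / 1e9:.2f} GB used of {mem.total / 1e9:.2f} GB")

print_ram("before fit")
gc.collect()                # release objects Python is still holding
torch.cuda.empty_cache()    # return cached GPU memory to the driver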

Can anyone tell me what in this code is causing the crash?