AttributeError: 'SSD' object has no attribute 'detect'

I am trying to convert my PyTorch code to PyTorch lightning but I am getting an error.
In init, I compute the detect by Detect but it is not recognized (?) in the forward module. Do I need to include the Detect inside of the lightning module somehow?

import pytorch_lightning as pl 
from utils.make import make_vgg, make_extras, make_loc_conf
from utils.others import L2Norm, DBox, Detect
from utils.loss import MultiBoxLoss

class SSD(pl.LightningModule):
    
    def __init__(self, hparams, phase, cfg): 
        super(SSD, self).__init__()
        
        self.save_hyperparameters(hparams)
        self.phase = phase
        self.num_classes = cfg["num_classes"]
        
        self.vgg = make_vgg()
        self.extras = make_extras()
        self.L2Norm = L2Norm()
        self.loc, self.conf = make_loc_conf(
        cfg["num_classes"], cfg["bbox_aspect_num"]
        )
        
        dbox = DBox(cfg)
        self.dbox_list = dbox.make_dbox_list()
        
        if self.phase == 'inference':
            self.detect = Detect.apply 
    
    def loss_function(self, outputs, targets): 
        criterion = MultiBoxLoss(jaccard_thresh=0.5, neg_pos = 3) 
        # Loss function = loss_l(position of bbox) + loss_c(classification)
        print('Type of output is: ')
        print(type(outputs))
        
        loss_l, loss_c = criterion(outputs, targets)
        return loss_l + loss_c
    
    def forward(self, x):
        sources = list()
        loc = list()
        conf = list()
        
        for k in range(23):
            x = self.vgg[k](x)
            
        source1 = self.L2Norm(x)
        sources.append(source1)
        
        for k in range(23,len(self.vgg)):
            x = self.vgg[k](x)
            
        sources.append(x)
        
        for k, v in enumerate(self.extras):
            x = F.relu(v(x),inplace=True)
            if k%2==1:
                sources.append(x)
                
        for (x,l,c) in zip(sources,self.loc,self.conf):
            loc.append(l(x).permute(0,2,3,1).contiguous())
            conf.append(c(x).permute(0,2,3,1).contiguous())
        
        loc = torch.cat([o.view(o.size(0),-1) for o in loc], 1)
        conf = torch.cat([o.view(o.size(0),-1) for o in conf], 1)
            
        loc = loc.view(loc.size(0), -1, 4)
        conf = conf.view(conf.size(0), -1, self.num_classes)
            
        output = (loc, conf, self.dbox_list)
            
        if self.phase == "inference":
            return self.detect(output[0],output[1],output[2])
        else:
            return output

The class Detect is defined like this:

class Detect(Function):
    
    def __init__(self,conf_thresh=0.01, top_k = 200, nms_thresh=0.45):
        self.softmax = nn.Softmax(dim=-1)
        self.conf_thresh = conf_thresh
        self.top_k = top_k
        self.nms_thresh = nms_thresh # uncommented by Daigo 
    
    def forward(self, loc_data, conf_data, dbox_list):
        num_batch = loc_data.size(0)
        num_dbox = loc_data.size(1)
        num_classes = conf_data.size(2)
        
        conf_data = self.softmax(conf_data)
        
        output = torch.zeros(num_batch, num_classes, self.top_k, 5)
        
        conf_preds = conf_data.transpose(2,1)
        
        for i in range(num_batch):
            decoded_boxes = decode(loc_data[i], dbox_list)
            
            conf_scores = conf_preds[i].clone()
            
            for cl in range(1,num_classes):
            
                c_mask = conf_scores[cl].gt(self.conf_thresh)
            
                scores = conf_scores[cl][c_mask]
            
                if scores.nelement()==0:
                    continue
                
                l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
            
                boxes = decoded_boxes[l_mask].view(-1,4)
            
                ids, count = nm_suppression(
                    boxes, scores, self.nms_thresh, self.top_k 
                )
            
                output[i,cl,:count] = torch.cat((scores[ids[:count]].unsqueeze(1), boxes[ids[:count]]), 1)
        
        return output

I got the following error:

File “C:\Users\user\Object Detection - spyder\script_Sandia_0608_lightning\ssd_light.py”, line 226, in
trainer.fit(ssd)

File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py”, line 458, in fit
self._run(model)

File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py”, line 756, in _run
self.dispatch()

File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py”, line 797, in dispatch
self.accelerator.start_training(self)

File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\accelerators\accelerator.py”, line 96, in start_training
self.training_type_plugin.start_training(trainer)

File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py”, line 144, in start_training
self._results = trainer.run_stage()

File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py”, line 807, in run_stage
return self.run_train()

File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py”, line 842, in run_train
self.run_sanity_check(self.lightning_module)

File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py”, line 1107, in run_sanity_check
self.run_evaluation()

File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py”, line 962, in run_evaluation
output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)

File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py”, line 174, in evaluation_step
output = self.trainer.accelerator.validation_step(args)

File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\accelerators\accelerator.py”, line 226, in validation_step
return self.training_type_plugin.validation_step(*args)

File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py”, line 161, in validation_step
return self.lightning_module.validation_step(*args, **kwargs)

File “C:\Users\user\Object Detection - spyder\script_Sandia_0608_lightning\ssd_light.py”, line 136, in validation_step
outputs = self(images)

File “C:\Users\user\anaconda3\lib\site-packages\torch\nn\modules\module.py”, line 889, in _call_impl
result = self.forward(*input, **kwargs)

File “C:\Users\user\Object Detection - spyder\script_Sandia_0608_lightning\ssd_light.py”, line 116, in forward
return self.detect(output[0],output[1],output[2])

File “C:\Users\user\anaconda3\lib\site-packages\torch\nn\modules\module.py”, line 947, in getattr
raise AttributeError("’{}’ object has no attribute ‘{}’".format(

AttributeError: ‘SSD’ object has no attribute ‘detect’