Using a neural network's output within a Numba-jitted function


I have a Numba-jitted function in which I need to use the output of a neural network trained with PyTorch Lightning. The following pseudocode should make this clearer:

    while True:
        x = sample_from_model()    # NumPy type, hence compatible with Numba
        out = NN(torch.Tensor(x))  # PyTorch call, incompatible with Numba

Is there a way around this? The first thing that comes to mind is to extract the weights manually and compute the forward pass myself inside the jitted function.

Thanks in advance,