torch.utils.checkpoint not compatible with mixed precision

I just migrated to PyTorch Lightning today. I use torch.utils.checkpoint (not to be confused with model saving) on two or three ResNet blocks in order to reduce memory overhead.

The problem appears only when I activate mixed precision with precision=16 (it does not happen in plain PyTorch).
I debugged the model, and the error is raised only when the activation maps are fed to the checkpointed chunk.
It seems that PyTorch Lightning is not casting the checkpointed weights to half precision.
This is the error I get:

```
return F.conv2d(input, weight, self.bias, self.stride,
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
```
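For context, here is a minimal sketch of the pattern that triggers it. The module, layer sizes, and names (`Net`, `stem`, `block`, `head`) are illustrative, not my real model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import pytorch_lightning as pl

class Net(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 64, 3, padding=1)
        # ResNet-style chunk whose activations we trade for recompute
        self.block = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
        )
        self.head = nn.Linear(64, 10)

    def forward(self, x):
        x = self.stem(x)
        # activations of self.block are recomputed in backward
        # instead of being stored
        x = checkpoint(self.block, x)
        return self.head(x.mean(dim=(2, 3)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

# trainer = pl.Trainer(gpus=1, precision=16)
# with precision=16, F.conv2d inside the checkpointed block raises
# the HalfTensor/FloatTensor mismatch above
```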

I am using the latest version of Lightning (1.1.7), PyTorch 1.7.1, CUDA 11, on Ubuntu 20.04.
Half precision works fine without checkpointing, and checkpointing works fine without half precision.
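A possible workaround, assuming the cause is that the checkpointed segment runs outside the autocast context, would be to re-enter autocast inside the checkpointed function. This is only a sketch adapting the hypothetical `Net.forward` above, not a confirmed Lightning fix:

```python
import torch
from torch.utils.checkpoint import checkpoint

def forward(self, x):
    x = self.stem(x)

    def run_block(inp):
        # re-enter autocast so the conv weights are cast to half,
        # matching the half-precision activations; this also applies
        # during the backward-pass recomputation of the segment
        with torch.cuda.amp.autocast():
            return self.block(inp)

    x = checkpoint(run_block, x)
    return self.head(x.mean(dim=(2, 3)))
```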

Hello, my apologies for the late reply. We are slowly converging on deprecating this forum in favor of the built-in GitHub version… Could we kindly ask you to recreate your question there: Lightning Discussions.