Training on two large datasets combined with 16-bit precision

Hi, I am facing a weird issue on Lightning 0.8.5. When I train my model on 1,064,708 samples, it trains fine, and when I train on 148,642 samples, it also trains well. But when I combine these two datasets, I run into NaN values during validation after only 2 or 3 epochs, while the training loss looks fine with no NaNs. When I dug deeper and inspected the model weights, I found that only the BatchNorm running_mean and running_var buffers become NaN; everything else seems fine. Is there any way to debug this? It seems very weird to me. All experiments were run with mixed precision (opt level O1, 16-bit precision).
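
One way to narrow this down is to check the BatchNorm buffers after each step and to log which module first produces a non-finite output. A plausible mechanism (an assumption, not confirmed in this thread): in train mode BatchNorm updates its running statistics during the forward pass, so a single overflowing batch can corrupt running_mean/running_var even if the loss scaler then skips that optimizer step. Training normalizes with per-batch statistics, while validation (eval mode) uses the running ones, which would match NaNs showing up only in validation. The helper names below (`find_nonfinite_bn_stats`, `add_nonfinite_output_hooks`) are hypothetical; this is a debugging sketch, not a Lightning API.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)


def find_nonfinite_bn_stats(model: nn.Module) -> list:
    """Return names of BatchNorm running_mean/running_var buffers holding NaN or Inf."""
    bad = []
    for name, module in model.named_modules():
        if isinstance(module, BN_TYPES) and module.track_running_stats:
            for buf_name in ("running_mean", "running_var"):
                buf = getattr(module, buf_name)
                if buf is not None and not torch.isfinite(buf).all():
                    bad.append(f"{name}.{buf_name}")
    return bad


def add_nonfinite_output_hooks(model: nn.Module, log: list) -> list:
    """Append a module's name to `log` whenever its forward output contains NaN/Inf.

    Hooks fire in execution order, so log[0] is the earliest offending module.
    Returns the hook handles so they can be removed with handle.remove().
    """
    handles = []
    for name, module in model.named_modules():
        def hook(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                log.append(name)
        handles.append(module.register_forward_hook(hook))
    return handles
```

Calling `find_nonfinite_bn_stats(model)` after every training step (for example from a batch-end callback) shows the first step where the buffers go bad; the hook log for that batch then points at the layer where values first overflow.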

Maybe try gradient clipping? But this is not a PL issue.
This is an issue with using 16-bit precision for research, which is inherently unstable.

Looks like it is indeed an issue with 16-bit AMP. Is it also unstable with the native AMP newly included in PyTorch?
BTW, native AMP doesn't use any of the Ox opt levels (O0 to O3).
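
Whichever AMP backend is used, float16 values still have a narrow range: the largest finite float16 is 65504. A tiny illustration with hypothetical activation values: squaring, which is what a variance-style statistic does, already overflows to inf in float16 at a magnitude of 300, while float32 handles it easily. This is offered as one possible contributor to the instability discussed above, not as a diagnosis of this specific issue.

```python
import torch

# Largest finite float16 value.
FP16_MAX = torch.finfo(torch.float16).max

# Hypothetical activation magnitudes.
acts = [300.0, 10.0]
x16 = torch.tensor(acts, dtype=torch.float16)
x32 = torch.tensor(acts, dtype=torch.float32)

sq16 = x16 * x16  # 300 * 300 = 90000 > 65504, so the first element becomes inf
sq32 = x32 * x32  # both squares fit comfortably in float32

print(FP16_MAX)
print(sq16)
print(sq32)
```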