Gradient checkpointing + DDP = NaN

I have a model that uses gradient checkpointing and DDP. It works fine when I train it on a single GPU. It also works fine if I turn off checkpointing. However, with multiple GPUs the loss initially looks innocent, but then suddenly becomes NaN:

            checkpointing    no checkpointing
gpus = 1    works            works
gpus = 4    fails            works

The only part of the model that uses checkpointing is:

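import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint  # assumed source of checkpoint()

class MergeLayer(nn.Module):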

    def apply_forward(self, inputs):
        x = torch.cat(inputs, 1)  # concatenate along the channel dimension
        assert x.size(1) == self.in_channels
        x = F.leaky_relu(x)
        x = self.conv(x)
        x = F.leaky_relu(x)
        assert x.size(1) == self.out_channels
        return x

    def _apply_forward_splat(self, *inputs):
        # checkpoint() does not handle a list argument well, so take the tensors as separate positional arguments
        return self.apply_forward(inputs)

    def forward(self, inputs):
        assert total_channels(inputs) == self.in_channels
        if self.save_memory and any_requires_grad(inputs):
            x = checkpoint(self._apply_forward_splat, *inputs)
        else:
            x = self.apply_forward(inputs)
        assert x.size(1) == self.out_channels
        return x
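
For anyone who wants to run the layer in isolation, here is a minimal single-process sketch. The constructor, the `nn.Conv2d` settings, the tensor shapes, and the helpers `total_channels` and `any_requires_grad` are not shown in the post, so everything below marked "assumed" or "hypothetical" is a guess. It only checks that the checkpointed and plain paths agree on one device. It does not reproduce the multi-GPU failure.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def total_channels(inputs):
    # Hypothetical helper: sum of the channel dimension (dim 1) over all inputs.
    return sum(t.size(1) for t in inputs)


def any_requires_grad(inputs):
    # Hypothetical helper: True if at least one input tensor tracks gradients.
    return any(t.requires_grad for t in inputs)


class MergeLayer(nn.Module):
    def __init__(self, in_channels, out_channels, save_memory=False):
        super().__init__()
        # Assumed constructor; the original post does not show it.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.save_memory = save_memory
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def apply_forward(self, inputs):
        x = torch.cat(inputs, 1)
        x = F.leaky_relu(x)
        x = self.conv(x)
        return F.leaky_relu(x)

    def _apply_forward_splat(self, *inputs):
        return self.apply_forward(inputs)

    def forward(self, inputs):
        assert total_channels(inputs) == self.in_channels
        if self.save_memory and any_requires_grad(inputs):
            # Newer PyTorch releases warn unless use_reentrant is passed explicitly.
            return checkpoint(self._apply_forward_splat, *inputs)
        return self.apply_forward(inputs)


def run_once(save_memory):
    # Same seed for both runs, so weights and inputs are identical.
    torch.manual_seed(0)
    layer = MergeLayer(5, 4, save_memory=save_memory)
    a = torch.randn(2, 2, 8, 8, requires_grad=True)
    b = torch.randn(2, 3, 8, 8, requires_grad=True)
    out = layer([a, b])
    out.sum().backward()
    return out.detach(), layer.conv.weight.grad.clone(), a.grad.clone()


out_plain, w_plain, a_plain = run_once(save_memory=False)
out_ckpt, w_ckpt, a_ckpt = run_once(save_memory=True)
print(torch.allclose(out_plain, out_ckpt), torch.allclose(w_plain, w_ckpt))
```

If both paths agree here, the layer itself is probably fine on one device, which would match the table above and point toward the interaction with DDP.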

Any ideas?

Hi, I am quite suspicious of what checkpoint(...) does. Would you mind sharing a full example to reproduce this? Possibly also open an issue on PL and link it here…

I will try to reduce the example and then post it.
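
While reducing the example, it may help to know which parameters receive a non-finite gradient first. Below is a minimal sketch. `nonfinite_grad_names` is a hypothetical helper name, not part of PyTorch or PL, and the `nn.Linear` model is only a stand-in for the real one. Calling it right after `backward()` in each process is an assumption about how the training loop is structured.

```python
import torch
import torch.nn as nn


def nonfinite_grad_names(model):
    """Return the names of parameters whose gradient contains NaN or Inf."""
    bad = []
    for name, param in model.named_parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            bad.append(name)
    return bad


# Stand-in model: gradients are finite after a normal backward pass.
model = nn.Linear(3, 2)
model(torch.ones(1, 3)).sum().backward()
clean = nonfinite_grad_names(model)

# Inject a NaN by hand to show what the helper reports.
model.weight.grad[0, 0] = float("nan")
dirty = nonfinite_grad_names(model)
print(clean, dirty)
```

Logging the result per rank together with the step number would show whether the NaN starts in the checkpointed layer or arrives from elsewhere.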