Gradient checkpointing + ddp = NaN

I have a model that uses gradient checkpointing and DDP. It works fine when I train it on a single GPU. It also works fine if I turn checkpointing off. With multiple GPUs, however, the loss initially looks innocent but then suddenly becomes NaN:

             checkpointing   no checkpointing
    gpus = 1     works           works
    gpus = 4     fails           works

The only part of the model that uses checkpointing is:

class MergeLayer(nn.Module):
    ...

    def apply_forward(self, inputs):
        x = torch.cat(inputs, 1)
        assert x.size(1) == self.in_channels
        x = F.leaky_relu(x)
        x = self.conv(x)
        x = F.leaky_relu(x)
        assert x.size(1) == self.out_channels
        return x

    def _apply_forward_splat(self, *inputs):
        # checkpoint() takes positional tensor arguments, not a list
        return self.apply_forward(inputs)


    def forward(self, inputs):
        assert total_channels(inputs) == self.in_channels
        if self.save_memory and any_requires_grad(inputs):
            x = checkpoint(self._apply_forward_splat, *inputs)
        else:
            x = self.apply_forward(inputs)
        assert x.size(1) == self.out_channels
        return x
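
For reference, here is a standalone sketch of the same checkpoint-splat pattern (channel sizes and shapes are made up for illustration). On a single process, the checkpointed and plain paths produce identical gradients, which is why the problem only shows up under DDP:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


class MergeLayer(nn.Module):
    """Minimal sketch of the layer above; sizes are hypothetical."""

    def __init__(self, in_channels=8, out_channels=4, save_memory=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.save_memory = save_memory
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def apply_forward(self, inputs):
        x = torch.cat(inputs, 1)
        x = F.leaky_relu(x)
        x = self.conv(x)
        return F.leaky_relu(x)

    def _apply_forward_splat(self, *inputs):
        # checkpoint() takes positional tensor arguments, not a list
        return self.apply_forward(inputs)

    def forward(self, inputs):
        if self.save_memory and any(t.requires_grad for t in inputs):
            return checkpoint(self._apply_forward_splat, *inputs)
        return self.apply_forward(inputs)


torch.manual_seed(0)
layer = MergeLayer()
a = torch.randn(2, 5, 4, 4, requires_grad=True)
b = torch.randn(2, 3, 4, 4, requires_grad=True)

# checkpointed backward
layer.save_memory = True
layer([a, b]).sum().backward()
g_ckpt = a.grad.clone()

# plain backward on the same inputs
a.grad = None
layer.save_memory = False
layer([a, b]).sum().backward()
g_plain = a.grad.clone()

print(torch.allclose(g_ckpt, g_plain))  # expect True on one process
```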

Any ideas?

Hi, I am quite suspicious of what the checkpoint(...) does. Would you mind sharing a full example to reproduce? Eventually, maybe open an issue on PL and link it here…

I will try to reduce the example and then post it.

Hi @jw3126,

I’ve observed training freezes when using DDP, gradient checkpointing and SyncBatchNorm together. After following this solution: Training gets stuck when using SyncBN · Issue #105 · NVIDIA/apex · GitHub, training no longer freezes, but the loss becomes NaN after the first iteration (sometimes after several iterations).
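
One possible mechanism (a sketch, not a confirmed diagnosis for this thread): checkpointing re-runs the forward pass during backward, so any BatchNorm inside the checkpointed segment updates its running statistics twice per step; with SyncBatchNorm the recompute can also re-issue a collective, which could explain both the freeze and, once worked around, diverging stats across ranks. The double update is easy to see even on a single process:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

bn = nn.BatchNorm2d(3)  # training mode by default
x = torch.randn(2, 3, 4, 4, requires_grad=True)

# plain forward: running stats are updated once
bn.num_batches_tracked.zero_()
bn(x)
print(bn.num_batches_tracked.item())  # expect 1

# checkpointed forward + backward: the forward is recomputed
# during backward, so running stats are updated twice
bn.num_batches_tracked.zero_()
y = checkpoint(bn, x)
y.sum().backward()
print(bn.num_batches_tracked.item())  # expect 2
```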

I am wondering, have you solved this issue? Or is there anything interesting you’ve discovered? Thanks!