Gradient checkpointing + DDP = NaN

I have a model that uses gradient checkpointing and DDP. It works fine when I train it on a single GPU. It also works fine if I turn off checkpointing. However, with multiple GPUs the loss initially looks innocent but then suddenly becomes NaN:

|            | checkpointing | no checkpointing |
|------------|---------------|------------------|
| gpus = 1   | works         | works            |
| gpus = 4   | fails         | works            |

The only part of the model that uses checkpointing is:

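```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

# total_channels() and any_requires_grad() are helpers not shown here.


class MergeLayer(nn.Module):
    ...

    def apply_forward(self, inputs):
        x = torch.cat(inputs, 1)
        assert x.size(1) == self.in_channels
        x = F.leaky_relu(x)
        x = self.conv(x)
        x = F.leaky_relu(x)
        assert x.size(1) == self.out_channels
        return x

    def _apply_forward_splat(self, *inputs):
        # checkpoint() passes tensors as positional arguments, not as a list,
        # so pack them back into a tuple for apply_forward
        return self.apply_forward(inputs)

    def forward(self, inputs):
        assert total_channels(inputs) == self.in_channels
        if self.save_memory and any_requires_grad(inputs):
            # recompute apply_forward during backward instead of storing activations
            x = checkpoint(self._apply_forward_splat, *inputs)
        else:
            x = self.apply_forward(inputs)
        assert x.size(1) == self.out_channels
        return x
```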
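To make the single-device case from the table concrete, here is a minimal self-contained sketch of the same layer. The constructor (a 1x1 `Conv2d`), the bodies of the helpers `total_channels` and `any_requires_grad`, the function `grads_match`, and the channel sizes are all hypothetical, since the post elides them. `use_reentrant=True` (PyTorch >= 1.11) only makes the historical default explicit. The sketch compares the conv weight gradient with and without checkpointing in one process; it does not reproduce the multi-GPU failure.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def total_channels(inputs):
    # Hypothetical helper: sum of the channel dims of the tensors to concatenate.
    return sum(t.size(1) for t in inputs)


def any_requires_grad(inputs):
    # Hypothetical helper: True if any input is part of the autograd graph.
    return any(t.requires_grad for t in inputs)


class MergeLayer(nn.Module):
    # Assumed constructor: a 1x1 conv from in_channels to out_channels.
    def __init__(self, in_channels, out_channels, save_memory=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.save_memory = save_memory
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def apply_forward(self, inputs):
        x = torch.cat(inputs, 1)
        x = F.leaky_relu(x)
        x = self.conv(x)
        return F.leaky_relu(x)

    def _apply_forward_splat(self, *inputs):
        # checkpoint() passes tensors positionally, so repack them into a tuple.
        return self.apply_forward(inputs)

    def forward(self, inputs):
        assert total_channels(inputs) == self.in_channels
        if self.save_memory and any_requires_grad(inputs):
            # Reentrant checkpointing, matching the old default behavior.
            x = checkpoint(self._apply_forward_splat, *inputs, use_reentrant=True)
        else:
            x = self.apply_forward(inputs)
        assert x.size(1) == self.out_channels
        return x


def grads_match(seed=0):
    """Compare conv weight gradients with and without checkpointing."""
    torch.manual_seed(seed)
    layer = MergeLayer(in_channels=6, out_channels=4)
    a = torch.randn(2, 2, 5, 5, requires_grad=True)
    b = torch.randn(2, 4, 5, 5, requires_grad=True)

    grads = []
    for save_memory in (True, False):
        layer.zero_grad()
        layer.save_memory = save_memory
        layer([a, b]).sum().backward()
        grads.append(layer.conv.weight.grad.clone())
    return torch.allclose(grads[0], grads[1], atol=1e-6)
```

A `True` from `grads_match()` only shows that the recomputation is numerically transparent on one device; how DDP's gradient hooks interact with the recomputed forward is a separate question.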

Any ideas?

Hi, I am quite suspicious of what checkpoint(...) does. Would you mind sharing a full example that reproduces the problem? If possible, also open an issue on PL (PyTorch Lightning) and link it here.

I will try to reduce the example and then post it.

https://github.com/PyTorchLightning/pytorch-lightning/issues/4788

Hi @jw3126,

I’ve observed training freezing when using DDP, gradient checkpointing and SyncBatchNorm. After applying the solution from "Training gets stuck when using SyncBN" (NVIDIA/apex issue #105 on GitHub), training no longer freezes, but the loss becomes NaN after the first iteration (sometimes after several iterations).

Have you solved this issue, or discovered anything interesting along the way? Thanks!

Not really. My “solution” was to use a single V100 instead of 4 K40 GPUs.

Got it, thanks for the response!

For anyone stumbling on this: the issue can be fixed in PyTorch >= 1.10 with DDP's _set_static_graph() call. To use it in PyTorch Lightning, one can do:

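```python
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins import DDPPlugin  # Lightning 1.5.x; renamed DDPStrategy in 1.6


class CustomDDPPlugin(DDPPlugin):
    def configure_ddp(self):
        self.pre_configure_ddp()
        self._model = self._setup_model(LightningDistributedModule(self.model))
        self._register_ddp_hooks()
        self._model._set_static_graph()  # THIS IS THE MAGIC LINE
```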

Then call it as usual:

```python
from pytorch_lightning import Trainer
trainer = Trainer(gpus=4, strategy=CustomDDPPlugin())
```
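For readers not using Lightning: `_set_static_graph()` is a private DDP method, and from PyTorch 1.11 the same behavior is, as far as I know, exposed as the public `static_graph=True` argument of `DistributedDataParallel`. The DDP docstring lists activation checkpointing among the cases it is meant to support. Newer Lightning docs describe passing it through as `DDPStrategy(static_graph=True)`; check against your installed version. The sketch below is a single-process CPU run (gloo backend, `world_size=1`) with a hypothetical toy model `CheckpointedNet` whose middle block uses reentrant checkpointing. It shows the wiring only and does not reproduce the 4-GPU NaN from this thread.

```python
import socket

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint


class CheckpointedNet(nn.Module):
    """Hypothetical toy model: plain stem, checkpointed block, linear head."""

    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.block = nn.Sequential(nn.LeakyReLU(), nn.Conv2d(8, 8, kernel_size=1))
        self.head = nn.Linear(8, 1)

    def forward(self, x):
        # The stem output requires grad, so the reentrant checkpoint below
        # records a graph for self.block's parameters.
        x = self.stem(x)
        x = checkpoint(self.block, x, use_reentrant=True)
        return self.head(x.mean(dim=(2, 3)))


def free_port():
    # Ask the OS for an unused local TCP port for the rendezvous.
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def train_static_graph(steps=3):
    """Run a few DDP steps with static_graph=True in a 1-process CPU group."""
    dist.init_process_group(
        "gloo", init_method=f"tcp://127.0.0.1:{free_port()}", rank=0, world_size=1
    )
    try:
        torch.manual_seed(0)
        # static_graph=True (PyTorch >= 1.11): DDP assumes the same set of
        # used/unused parameters in every iteration.
        model = DDP(CheckpointedNet(), static_graph=True)
        opt = torch.optim.SGD(model.parameters(), lr=0.01)
        losses = []
        for _ in range(steps):
            x, y = torch.randn(4, 3, 8, 8), torch.randn(4, 1)
            loss = nn.functional.mse_loss(model(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())
        return losses
    finally:
        dist.destroy_process_group()
```

In a real multi-GPU job each rank would instead call `init_process_group` with its own rank and world size and pass `device_ids=[local_rank]` to DDP; only the `static_graph=True` part is the relevant change.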