I have a model that uses gradient checkpointing and DDP. It works fine when I train it on a single GPU, and it also works fine if I turn off checkpointing. However, with multiple GPUs the loss initially looks innocent, but then suddenly becomes NaN:
| | checkpointing | no checkpointing |
|---|---|---|
| gpus = 1 | works | works |
| gpus = 4 | fails | works |
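
For context, the training loop is roughly like this (a simplified sketch; the actual loss, data loader, and process-group setup are omitted and the names here are placeholders):

```python
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

def train(rank, model, loader, optimizer):
    # assumes torch.distributed.init_process_group(...) has already run in this process
    model = DDP(model.to(rank), device_ids=[rank])
    for batch, target in loader:
        optimizer.zero_grad()
        loss = F.mse_loss(model(batch.to(rank)), target.to(rank))
        loss.backward()
        optimizer.step()
```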
The only part of the model that uses checkpointing is:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


class MergeLayer(nn.Module):
    ...

    def apply_forward(self, inputs):
        x = torch.cat(inputs, 1)
        assert x.size(1) == self.in_channels
        x = F.leaky_relu(x)
        x = self.conv(x)
        x = F.leaky_relu(x)
        assert x.size(1) == self.out_channels
        return x

    def _apply_forward_splat(self, *inputs):
        # checkpoint() does not like to consume a list, so unpack and repack
        return self.apply_forward(inputs)

    def forward(self, inputs):
        assert total_channels(inputs) == self.in_channels
        if self.save_memory and any_requires_grad(inputs):
            x = checkpoint(self._apply_forward_splat, *inputs)
        else:
            x = self.apply_forward(inputs)
        assert x.size(1) == self.out_channels
        return x
```
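
`total_channels` and `any_requires_grad` are small helpers, roughly equivalent to this (reconstructed here only so the snippet is self-contained):

```python
# assumed reconstruction of the helpers, shown for completeness
def total_channels(inputs):
    # sum of channel dimensions across the tensors to be concatenated
    return sum(t.size(1) for t in inputs)

def any_requires_grad(inputs):
    # checkpointing is only used when at least one input needs grad
    return any(t.requires_grad for t in inputs)
```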
Any ideas?