Error when changing gpus=1 --> gpus=4 on AWS SageMaker

On a 4-GPU instance on AWS SageMaker (PyTorch 1.8 kernel), the following code installs pytorch_lightning 1.5.0. With gpus=1 it runs fine; with gpus=4 it crashes with the output shown below the code, and it also seems to hang the kernel. Do you know what causes this, or how I can use PyTorch Lightning with SageMaker?
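
In case it helps with diagnosis, here is what I can check from the same kernel before running the code below. NCCL_DEBUG is a standard NCCL environment variable; assuming SageMaker forwards NCCL's output to the notebook, setting it before the first NCCL call should make the "unhandled system error" at the bottom of the traceback more verbose:

# check GPU visibility and turn on verbose NCCL logging
# (NCCL_DEBUG must be set before the first NCCL call)
import os
os.environ["NCCL_DEBUG"] = "INFO"

import torch
print(torch.__version__)
print(torch.cuda.device_count())   # expect 4 on this instance
print(torch.cuda.nccl.version())   # NCCL version bundled with this PyTorch build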

# install pytorch_lightning:
import sys
!{sys.executable} -m pip install pytorch_lightning
import pytorch_lightning as pl

# generate data:
import torch
from torch.utils.data import DataLoader, TensorDataset

n=2**13
bs = 16
x = torch.rand(n, 4)

labels = torch.tensor([ torch.norm(xv)<=1. for xv in x ], dtype=torch.float)
print(labels)

dl_train = DataLoader(TensorDataset(x, labels), batch_size=bs, shuffle=False, num_workers=4)

# loss and regressor:
loss_func = torch.nn.MSELoss()

class Regressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        #self.save_hyperparameters()
        
        self.lin_hidden = torch.nn.Linear(4, 8)
        self.lin_out = torch.nn.Linear(8, 1)
    def forward(self, data):
        return self.lin_out(torch.tanh(self.lin_hidden(data)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = loss_func(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = loss_func(y_hat, y)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        print('Val', avg_loss)
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.00001)
    
from pytorch_lightning import Trainer, seed_everything
seed_everything(0)

# model
model = Regressor()
trainer = Trainer(accelerator='gpu', gpus=4, max_epochs=5, strategy='dp')
trainer.fit(model, dl_train, dl_train)
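
For comparison, the gpus=1 case that runs is the same script with only the Trainer construction changed, roughly:

trainer = Trainer(accelerator='gpu', gpus=1, max_epochs=5)

The gpus=4 run fails with the traceback below.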

RuntimeError Traceback (most recent call last)
in
57 model = Regressor()
58 trainer = Trainer(accelerator='gpu', gpus=4, max_epochs=5, strategy='dp')
---> 59 trainer.fit(model, dl_train, dl_train)

/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
734 train_dataloaders = train_dataloader
735 self._call_and_handle_interrupt(
---> 736 self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
737 )
738

/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
680 """
681 try:
---> 682 return trainer_fn(*args, **kwargs)
683 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
684 except KeyboardInterrupt as exception:

/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
768 # TODO: ckpt_path only in v1.7
769 ckpt_path = ckpt_path or self.resume_from_checkpoint
---> 770 self._run(model, ckpt_path=ckpt_path)
771
772 assert self.state.stopped

/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
1191
1192 # dispatch start_training or start_evaluating or start_predicting
---> 1193 self._dispatch()
1194
1195 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in _dispatch(self)
1270 self.training_type_plugin.start_predicting(self)
1271 else:
---> 1272 self.training_type_plugin.start_training(self)
1273
1274 def run_stage(self):

/opt/conda/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
200 def start_training(self, trainer: "pl.Trainer") -> None:
201 # double dispatch to initiate the training loop
---> 202 self._results = trainer.run_stage()
203
204 def start_evaluating(self, trainer: "pl.Trainer") -> None:

/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
1280 if self.predicting:
1281 return self._run_predict()
---> 1282 return self._run_train()
1283
1284 def _pre_training_routine(self):

/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
1302 self.progress_bar_callback.disable()
1303
---> 1304 self._run_sanity_check(self.lightning_module)
1305
1306 # enable train mode

/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in _run_sanity_check(self, ref_model)
1366 # run eval step
1367 with torch.no_grad():
---> 1368 self._evaluation_loop.run()
1369
1370 self.call_hook("on_sanity_check_end")

/opt/conda/lib/python3.6/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
---> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False

/opt/conda/lib/python3.6/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py in advance(self, *args, **kwargs)
107 dl_max_batches = self._max_batches[dataloader_idx]
108
---> 109 dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
110
111 # store batch level output per dataloader

/opt/conda/lib/python3.6/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
---> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False

/opt/conda/lib/python3.6/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in advance(self, data_fetcher, dataloader_idx, dl_max_batches, num_dataloaders)
121 # lightning module methods
122 with self.trainer.profiler.profile("evaluation_step_and_end"):
---> 123 output = self._evaluation_step(batch, batch_idx, dataloader_idx)
124 output = self._evaluation_step_end(output)
125

/opt/conda/lib/python3.6/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in _evaluation_step(self, batch, batch_idx, dataloader_idx)
213 self.trainer.lightning_module._current_fx_name = "validation_step"
214 with self.trainer.profiler.profile("validation_step"):
---> 215 output = self.trainer.accelerator.validation_step(step_kwargs)
216
217 return output

/opt/conda/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py in validation_step(self, step_kwargs)
234 """
235 with self.precision_plugin.val_step_context():
---> 236 return self.training_type_plugin.validation_step(*step_kwargs.values())
237
238 def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:

/opt/conda/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/dp.py in validation_step(self, *args, **kwargs)
102
103 def validation_step(self, *args, **kwargs):
---> 104 return self.model(*args, **kwargs)
105
106 def test_step(self, *args, **kwargs):

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
916 result = self._slow_forward(*input, **kwargs)
917 else:
---> 918 result = self.forward(*input, **kwargs)
919 for hook in itertools.chain(
920 _global_forward_hooks.values(),

/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
164 if len(self.device_ids) == 1:
165 return self.module(*inputs[0], **kwargs[0])
---> 166 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
167 outputs = self.parallel_apply(replicas, inputs, kwargs)
168 return self.gather(outputs, self.output_device)

/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in replicate(self, module, device_ids)
169
170 def replicate(self, module, device_ids):
---> 171 return replicate(module, device_ids, not torch.is_grad_enabled())
172
173 def scatter(self, inputs, kwargs, device_ids):

/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/replicate.py in replicate(network, devices, detach)
89 params = list(network.parameters())
90 param_indices = {param: idx for idx, param in enumerate(params)}
---> 91 param_copies = _broadcast_coalesced_reshape(params, devices, detach)
92
93 buffers = list(network.buffers())

/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/replicate.py in _broadcast_coalesced_reshape(tensors, devices, detach)
65 from ._functions import Broadcast
66 if detach:
---> 67 return comm.broadcast_coalesced(tensors, devices)
68 else:
69 # Use the autograd function to broadcast if not detach

/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/comm.py in broadcast_coalesced(tensors, devices, buffer_size)
56 devices = [_get_device_index(d) for d in devices]
57 tensors = [_handle_complex(t) for t in tensors]
---> 58 return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
59
60

RuntimeError: NCCL Error 2: unhandled system error
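
To narrow down whether this is a Lightning problem or an NCCL/instance problem, I could try a minimal reproduction of the call that fails at the bottom of the traceback (DataParallel's replicate ends in comm.broadcast_coalesced from GPU 0 to all visible GPUs). This sketch uses only public torch APIs and no Lightning:

# standalone repro of the failing broadcast, outside pytorch_lightning
import torch
from torch.nn.parallel import comm

devices = list(range(torch.cuda.device_count()))  # [0, 1, 2, 3] on this instance
t = torch.rand(8, 4, device="cuda:0")             # small tensor on the source GPU

# this is the call that raises "NCCL Error 2" in the traceback above
copies = comm.broadcast_coalesced([t], devices)
print([c[0].device for c in copies])              # one replica per GPU if it succeeds

If this standalone call raises the same NCCL error, the problem is below Lightning (NCCL, driver, or instance configuration) rather than in the Trainer setup.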