Multi-GPU training fails on second execution in the same Jupyter kernel: `ProcessExitedException: process 0 terminated with signal SIGSEGV`

I get a segmentation fault (SIGSEGV) when I execute my code a second time. The error below occurs only from the second run onwards: the first run works, and every later run in the same kernel fails. If I restart the kernel and run the code again, it works perfectly. So it seems I have to restart the kernel before every run; otherwise it ends in a segmentation fault. Can you please help me understand why?

I'm running my code in a Jupyter notebook that is hosted on Kubernetes.

Python: 3.10.11
CUDA: 12.2
PyTorch: 2.0.1+cu117 (i.e. the wheel is built against CUDA 11.7)

Here’s the error

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name     | Type       | Params
----------------------------------------
0 | layer1   | Sequential | 728   
1 | layer2   | Sequential | 39.3 K
2 | drop_out | Dropout    | 0     
3 | fc1      | Linear     | 2.7 M 
4 | fc2      | Linear     | 10.0 K
----------------------------------------
2.8 M     Trainable params
0         Non-trainable params
2.8 M     Total params
11.180    Total estimated model params size (MB)
---------------------------------------------------------------------------
ProcessExitedException                    Traceback (most recent call last)
Cell In[2], line 64
     60     trainer.fit(model)
     63 if __name__ == "__main__":
---> 64     run()

Cell In[2], line 60, in run()
     57 trainer = Trainer(accelerator="gpu",
     58     devices=[0, 1], max_epochs=5)
     59 # Train the model
---> 60 trainer.fit(model)

File /opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:543, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    541 self.state.status = TrainerStatus.RUNNING
    542 self.training = True
--> 543 call._call_and_handle_interrupt(
    544     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    545 )

File /opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:43, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     41 try:
     42     if trainer.strategy.launcher is not None:
---> 43         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     44     return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:

File /opt/conda/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/multiprocessing.py:144, in _MultiProcessingLauncher.launch(self, function, trainer, *args, **kwargs)
    136 process_context = mp.start_processes(
    137     self._wrapping_function,
    138     args=process_args,
   (...)
    141     join=False,  # we will join ourselves to get the process references
    142 )
    143 self.procs = process_context.processes
--> 144 while not process_context.join():
    145     pass
    147 worker_output = return_queue.get()

File /opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py:140, in ProcessContext.join(self, timeout)
    138 if exitcode < 0:
    139     name = signal.Signals(-exitcode).name
--> 140     raise ProcessExitedException(
    141         "process %d terminated with signal %s" %
    142         (error_index, name),
    143         error_index=error_index,
    144         error_pid=failed_process.pid,
    145         exit_code=exitcode,
    146         signal_name=name
    147     )
    148 else:
    149     raise ProcessExitedException(
    150         "process %d terminated with exit code %d" %
    151         (error_index, exitcode),
   (...)
    154         exit_code=exitcode
    155     )

ProcessExitedException: process 1 terminated with signal SIGSEGV

Here’s my code

import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
from lightning.pytorch import LightningModule, Trainer
 
class LitCNN(LightningModule):
    """Small two-conv-layer CNN that classifies MNIST digits into 10 classes.

    The module owns its data pipeline: ``prepare_data`` downloads MNIST once,
    and ``train_dataloader`` builds the normalized training loader that each
    training process reads from.

    Args:
        data_dir: Directory MNIST is downloaded to / read from. ``None`` (the
            default) means "the current working directory at the time the
            data is accessed", which matches the previous hard-coded behavior.
        batch_size: Batch size of the training DataLoader.
        num_workers: Number of DataLoader worker processes. NOTE(review): when
            the training processes are themselves forked (e.g. multi-GPU runs
            launched from a notebook), trying ``num_workers=0`` is a cheap way
            to rule out worker-process issues.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, data_dir=None, batch_size=64, num_workers=4, lr=0.001):
        super().__init__()
        # Plain (non-parameter) settings; defaults reproduce the original
        # hard-coded values.
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr

        # Each conv keeps the spatial size (5x5 kernel, stride 1, padding 2)
        # and each 2x2 max-pool halves it, so a 28x28 MNIST image becomes
        # 14x14 after layer1 and 7x7 after layer2 -- hence fc1's 7 * 7 * 56
        # input width.
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 28, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(28, 56, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.drop_out = nn.Dropout()
        self.fc1 = nn.Linear(7 * 7 * 56, 1000)
        self.fc2 = nn.Linear(1000, 10)

    def _resolved_data_dir(self):
        """Return the MNIST root: ``data_dir`` if set, else the current cwd."""
        return os.getcwd() if self.data_dir is None else self.data_dir

    def forward(self, x):
        """Return unnormalized class scores (logits) of shape ``(batch, 10)``."""
        out = self.layer1(x)
        out = self.layer2(out)
        # Flatten everything except the batch dimension for the dense layers.
        out = out.reshape(out.size(0), -1)
        out = self.drop_out(out)
        out = self.fc1(out)
        return self.fc2(out)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one ``(x, y)`` batch."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Optimize all parameters with Adam at learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def prepare_data(self):
        """Download MNIST once before any dataloader is built.

        Lightning invokes this hook on local rank 0 only and then waits on a
        barrier. As a result, the download no longer happens concurrently in
        every DDP process writing into the same directory.
        """
        MNIST(self._resolved_data_dir(), download=True, train=True)

    def train_dataloader(self):
        """Build the shuffled, normalized MNIST training DataLoader."""
        # ToTensor scales pixels to [0, 1]; the Normalize constants are
        # presumably the MNIST training-set mean / std.
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

        # download=True is kept so this still works if prepare_data was not
        # run; torchvision skips the download when the files already exist,
        # so after prepare_data this is only an existence check.
        dataset = MNIST(self._resolved_data_dir(), download=True, train=True, transform=transform)
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
 

def run():
    """Train a freshly constructed LitCNN for 5 epochs on CUDA devices 0 and 1.

    A new model and a new ``Trainer`` are created on every call, so repeated
    calls do not share model weights or trainer objects.

    NOTE(review): when called from a notebook kernel, Lightning's docs
    describe the multi-GPU workers as being forked from the kernel process
    (the ``ddp_notebook`` strategy). Anything the kernel accumulated during an
    earlier call is then inherited by the next launch. If a second in-kernel
    call crashes while a fresh kernel works, running this file as a standalone
    script (e.g. ``python train.py``) is a common workaround. This has not
    been confirmed as the root cause here.
    """
    lit_model = LitCNN()
    gpu_trainer = Trainer(
        accelerator="gpu",
        devices=[0, 1],
        max_epochs=5,
    )
    gpu_trainer.fit(lit_model)


# Entry point. Jupyter executes notebook cells with __name__ == "__main__", so
# this guard does NOT stop run() from executing inside a notebook cell; it only
# prevents run() from firing when this file is imported as a module.
if __name__ == "__main__":
    run()