Error running on TPU in Google Colab

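```python
from typing import Dict

import pytorch_lightning as pl
from torch import optim
from torchvision.models import resnet18

# pytorch_neg_multi_log_likelihood_batch is defined elsewhere in the notebook.

class TrialModel(pl.LightningModule):

    def __init__(self, cfg: Dict, pretrained=True):
        super().__init__()  # nn.Module must be initialised before submodules are assigned
        self.backbone = resnet18(pretrained=pretrained, progress=True)
        # ... head and logit layers (see the model summary below) omitted in the original post

    def forward(self, x):
        # ... body omitted in the original post
        return pred, confidences

    def training_step(self, batch, batch_idx):
        # Lightning already moves the batch to this process's device, so the
        # tensors can be used directly instead of being re-wrapped in torch.tensor().
        targets = batch["target_positions"]
        data = batch["image"]

        outputs, confidence = self(data)
        loss = pytorch_neg_multi_log_likelihood_batch(targets, outputs, confidence)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)
```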

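```python
import os

from l5kit.data import ChunkedDataset, LocalDataManager
from l5kit.dataset import AgentDataset
from l5kit.rasterization import build_rasterizer
from torch.utils.data import DataLoader

# data loader
os.environ["DATA_FOLDER"] = DIR_INPUT
dm = LocalDataManager(None)

train_cfg = cfg["train_data_loader"]

rasterizer = build_rasterizer(cfg, dm)  # cfg is a predefined dictionary

train_zarr = ChunkedDataset(dm.require(train_cfg["key"])).open()
train_dataset = AgentDataset(cfg, train_zarr, rasterizer, min_frame_future=10)
train_dataloader = DataLoader(train_dataset, ...)  # remaining arguments were cut off in the original post

model = TrialModel(cfg, pretrained=False)
trainer = pl.Trainer(tpu_cores=8, max_steps=500)
trainer.fit(model, train_dataloader)
```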


```text
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
training on 8 TPU cores
INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None
INIT TPU local core: 6, global rank: 6 with XLA_USE_BF16=None
INIT TPU local core: 5, global rank: 5 with XLA_USE_BF16=None
INIT TPU local core: 7, global rank: 7 with XLA_USE_BF16=None
INIT TPU local core: 3, global rank: 3 with XLA_USE_BF16=None
INIT TPU local core: 1, global rank: 1 with XLA_USE_BF16=None
INIT TPU local core: 4, global rank: 4 with XLA_USE_BF16=None
INIT TPU local core: 2, global rank: 2 with XLA_USE_BF16=None

  | Name     | Type       | Params
0 | backbone | ResNet     | 11 M
1 | head     | Sequential | 2 M
2 | logit    | Linear     | 1 M
Epoch 0: 0%
0/88561 [00:00<?, ?it/s]
Exception                                 Traceback (most recent call last)
<ipython-input-38-0766566ac519> in <module>()
     24 trainer = pl.Trainer(tpu_cores=8, max_steps=500)#,
---> 26,train_dataloader)

5 frames
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/ in wrapped_fn(self, *args, **kwargs)
     46             if entering is not None:
     47                 self.state = entering
---> 48             result = fn(self, *args, **kwargs)
     50             # The INTERRUPTED state can be set inside the run function. To indicate that run was interrupted

/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/ in fit(self, model, train_dataloader, val_dataloaders, datamodule)
   1076             self.accelerator_backend = TPUBackend(self)
   1077             self.accelerator_backend.setup()
-> 1078             self.accelerator_backend.train(model)
   1079             self.accelerator_backend.teardown(model)

/usr/local/lib/python3.6/dist-packages/pytorch_lightning/accelerators/ in train(self, model)
     85                 args=(model, self.trainer, self.mp_queue),
     86                 nprocs=self.trainer.tpu_cores,
---> 87                 start_method=self.start_method
     88             )

/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/ in spawn(fn, args, nprocs, join, daemon, start_method)
    393         join=join,
    394         daemon=daemon,
--> 395         start_method=start_method)

/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/ in start_processes(fn, args, nprocs, join, daemon, start_method)
    156     # Loop on join until it returns True or raises an exception.
--> 157     while not context.join():
    158         pass

/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/ in join(self, timeout)
    105                 raise Exception(
    106                     "process %d terminated with signal %s" %
--> 107                     (error_index, name)
    108                 )
    109             else:

Exception: process 6 terminated with signal SIGKILL
```

I get the above error when running on TPU. There seems to be an issue with data loading (the `start_method` of `xmp.spawn` / `torch.multiprocessing`). The same code runs without any errors on a single GPU. Am I loading the data correctly, or is there some other issue?
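For context on the last line of the traceback: `torch.multiprocessing` reports a worker that died from a signal by reading its negative exit code and looking up the signal name. Below is a minimal standard-library sketch of that logic. The helper name `describe_exit` is hypothetical and is not torch's actual code. A child cannot catch SIGKILL, so the signal comes from outside the process. The most common source is the kernel's out-of-memory killer, which would point toward memory pressure rather than the start method itself, though other causes are possible.

```python
import signal
import subprocess
import sys
from typing import Optional


def describe_exit(rank: int, exitcode: Optional[int]) -> Optional[str]:
    """Format a child's exit status like the message in the traceback.

    Hypothetical helper (a sketch, not torch's implementation).
    POSIX convention: a process killed by signal N reports exit code -N.
    """
    if exitcode is None or exitcode == 0:
        return None  # still running, or exited cleanly
    if exitcode < 0:
        name = signal.Signals(-exitcode).name  # e.g. 9 -> "SIGKILL"
        return "process %d terminated with signal %s" % (rank, name)
    return "process %d terminated with exit status %d" % (rank, exitcode)


if __name__ == "__main__":
    # Demo (POSIX only): a child that SIGKILLs itself, as the OOM killer would.
    child = subprocess.run(
        [sys.executable, "-c", "import os, signal; os.kill(os.getpid(), signal.SIGKILL)"]
    )
    print(describe_exit(6, child.returncode))
    # prints: process 6 terminated with signal SIGKILL
```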


According to the PyTorch Lightning documentation:

There are cases in which it is NOT possible to use DDP. Examples are:

  • Jupyter Notebook, Google COLAB, Kaggle, etc.
  • You have a nested script without a root package
  • Your script needs to invoke both .fit and .test, or one of them multiple times

In these situations you should use dp or ddp_spawn instead.

If you request multiple GPUs or nodes without setting a mode, DDP will be automatically used.

Can you try setting `distributed_backend='ddp_spawn'` explicitly (or `accelerator='ddp_spawn'`, as the argument is now called)?
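A minimal sketch of the quoted rule, under stated assumptions. `running_in_notebook` and `distributed_mode` are hypothetical helper names. Detecting a notebook by checking whether `ipykernel` or `google.colab` has been imported is a heuristic, not a Lightning API. The returned string is what you would pass as `distributed_backend` (or `accelerator` in the Lightning version discussed here), and the helper itself does not import Lightning. The rule mainly concerns multi-GPU runs. The TPU run in the traceback already starts its workers through `xmp.spawn`.

```python
import sys
from typing import Optional


def running_in_notebook() -> bool:
    # Heuristic (assumption): Jupyter, Colab and Kaggle kernels import ipykernel;
    # Colab also imports google.colab.
    return "ipykernel" in sys.modules or "google.colab" in sys.modules


def distributed_mode(num_devices: int, in_notebook: Optional[bool] = None) -> Optional[str]:
    """Pick a multi-device mode following the quoted docs rule (hypothetical helper)."""
    if num_devices <= 1:
        return None  # single device: no distributed mode needed
    if in_notebook is None:
        in_notebook = running_in_notebook()
    # The quoted docs rule out plain DDP in notebooks, so use ddp_spawn there;
    # outside notebooks keep DDP, which Lightning picks by default for multiple GPUs.
    return "ddp_spawn" if in_notebook else "ddp"


# Usage (not executed here):
# trainer = pl.Trainer(gpus=2, accelerator=distributed_mode(2))
```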

@sujay_khandekar could you share the notebook?

I think this error is caused by Colab running out of memory. Your `batch_size` is 24 and you are using 8 cores, so the total effective batch size on the TPU is 24 * 8 = 192, which is too much for Colab to handle. Your problem should be solved if you use a batch size much smaller than 24.
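The arithmetic in the reply above can be written as a small sketch. With Lightning on TPU, each of the 8 processes loads its own batch of `batch_size` samples, so the global batch is the per-core size times the number of cores. The helper names below are hypothetical. Dividing a target global batch across the cores is one simple way to choose a smaller per-core `batch_size`. This assumes the single-GPU run used the same `batch_size=24` and that you want to keep that global batch.

```python
def global_batch_size(per_core: int, num_cores: int) -> int:
    """Samples processed per optimizer step across all cores."""
    return per_core * num_cores


def per_core_batch_size(target_global: int, num_cores: int) -> int:
    """Split a target global batch evenly across cores (rounded down, at least 1)."""
    return max(1, target_global // num_cores)


if __name__ == "__main__":
    print(global_batch_size(24, 8))    # 192, the figure from the reply
    print(per_core_batch_size(24, 8))  # 3: keeps the single-GPU global batch of 24
```

The resulting per-core value is what you would pass as `batch_size` to the `DataLoader`.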