Training works with 1 TPU core, but raises ProcessExitedException when I try to use 8 cores

Hi all,

I’m Adam - high school sophomore and newbie to ML and PyTorch (and, to a lesser extent, computer programming in general), so please excuse any dumb things I may say hahah.

For the past few days, I’ve been trying to figure out how to use a TPU with PyTorch Lightning - I’ve gotten closer, but I still have a major problem I don’t know the source of.

My latest issue is:

My program runs just fine on 1 TPU core, but when I try to run it on 8, I get the error "process 0 terminated with exit code 17".

The code that triggers the error is:

from pytorch_lightning import Trainer, seed_everything
seed_everything(0)

# set the model 

model = EmpidModel()

trainer = Trainer(tpu_cores=8, max_epochs=2)
trainer.fit(model, trainDL, valDL)

More specifically, the trainer.fit line. Here’s the entire stack trace. It’s pretty monstrous.

/usr/local/lib/python3.6/dist-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: You have set progress_bar_refresh_rate < 20 on Google Colab. This may crash. Consider using progress_bar_refresh_rate >= 20 in Trainer.
  warnings.warn(*args, **kwargs)
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
training on 8 TPU cores
Exception in device=TPU:0: Cannot replicate if number of devices (1) is different from 8
Traceback (most recent call last):
Exception in device=TPU:1: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 322, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 315, in _setup_replication
    xm.set_replication(device, [device])
Exception in device=TPU:2: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 317, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 322, in _start_fn
    _setup_replication()
Exception in device=TPU:3: Cannot replicate if number of devices (1) is different from 8
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 315, in _setup_replication
    xm.set_replication(device, [device])
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 287, in xla_replication_devices
    format(len(local_devices), len(kind_devices)))
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 317, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
RuntimeError: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 287, in xla_replication_devices
    format(len(local_devices), len(kind_devices)))
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 322, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 322, in _start_fn
    _setup_replication()
Exception in device=TPU:4: Cannot replicate if number of devices (1) is different from 8
RuntimeError: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 315, in _setup_replication
    xm.set_replication(device, [device])
Exception in device=TPU:5: Cannot replicate if number of devices (1) is different from 8
Traceback (most recent call last):
Exception in device=TPU:6: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 322, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 317, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 315, in _setup_replication
    xm.set_replication(device, [device])
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 315, in _setup_replication
    xm.set_replication(device, [device])

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 287, in xla_replication_devices
    format(len(local_devices), len(kind_devices)))
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 317, in set_replication
    replication_devices = xla_replication_devices(devices)
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 287, in xla_replication_devices
    format(len(local_devices), len(kind_devices)))
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 317, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
RuntimeError: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 322, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 287, in xla_replication_devices
    format(len(local_devices), len(kind_devices)))
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 322, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 315, in _setup_replication
    xm.set_replication(device, [device])
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 315, in _setup_replication
    xm.set_replication(device, [device])
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 317, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 317, in set_replication
    replication_devices = xla_replication_devices(devices)
RuntimeError: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 287, in xla_replication_devices
    format(len(local_devices), len(kind_devices)))
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 287, in xla_replication_devices
    format(len(local_devices), len(kind_devices)))
RuntimeError: Cannot replicate if number of devices (1) is different from 8
RuntimeError: Cannot replicate if number of devices (1) is different from 8
RuntimeError: Cannot replicate if number of devices (1) is different from 8
Exception in device=TPU:7: Cannot replicate if number of devices (1) is different from 8
Traceback (most recent call last):

  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 322, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 315, in _setup_replication
    xm.set_replication(device, [device])
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 317, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 287, in xla_replication_devices
    format(len(local_devices), len(kind_devices)))
RuntimeError: Cannot replicate if number of devices (1) is different from 8
---------------------------------------------------------------------------
ProcessExitedException                    Traceback (most recent call last)
<ipython-input-27-a434c45797df> in <module>()
      8 
      9 trainer = Trainer(tpu_cores=8, max_epochs=2)
---> 10 trainer.fit(model, trainDL, valDL)

4 frames
/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py in join(self, timeout)
    158                     error_index=error_index,
    159                     error_pid=failed_process.pid,
--> 160                     exit_code=exitcode
    161                 )
    162 

ProcessExitedException: process 0 terminated with exit code 17

I install XLA and PyTorch like so:

PyTorch:
!pip install torch==1.5.0 torchvision==0.6.0

XLA:
VERSION = "20200325" #@param ["1.5" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

import torch_xla
import torch_xla.core.xla_model as xm

I installed an earlier version of PyTorch because I couldn’t import XLA with any later version, for some reason.

Then I pip-installed PyTorch Lightning and imported everything else.

Adding to the confusion, I’m pretty confident that I can connect to all 8 of the TPU cores - I ran a bit of test code to see if I could do some trivial tensor operation on all 8 of the cores, and I could.

for i in range(8):
    t = torch.randn(2, 2, device=xm.xla_device(i+1)) 
    print(t.device)
    print(t)

Returns:

xla:1
tensor([[ 0.3319, -0.9197],
        [ 0.2877,  0.8347]], device='xla:1')
xla:2
tensor([[ 0.3319, -0.9197],
        [ 0.2877,  0.8347]], device='xla:2')
xla:3
tensor([[ 0.3319, -0.9197],
        [ 0.2877,  0.8347]], device='xla:3')
xla:4
tensor([[ 0.3319, -0.9197],
        [ 0.2877,  0.8347]], device='xla:4')
xla:5
tensor([[ 0.3319, -0.9197],
        [ 0.2877,  0.8347]], device='xla:5')
xla:6
tensor([[ 0.3319, -0.9197],
        [ 0.2877,  0.8347]], device='xla:6')
xla:7
tensor([[ 0.3319, -0.9197],
        [ 0.2877,  0.8347]], device='xla:7')
xla:8
tensor([[ 0.3319, -0.9197],
        [ 0.2877,  0.8347]], device='xla:8')

I tried to reproduce this error with a simpler model, to no avail. One attempt did fail, but for a different reason.

I really, really need to connect to a TPU, and this seems like the best way in PyTorch - I’m just pretty clueless on what I should try next.

My Colab document is here. I’m very sorry about the length.

I’d be so very thankful if someone could help me - I’m in a bit of a rut!

It seems this is a bug on our end. Would you mind submitting an issue? :rabbit:

This would occur if xm.xla_device() is called outside of xmp.spawn, on the main process, i.e. if xm.xla_device(number) was called before trainer.fit().

@Lezwon, can you please elaborate a little more on your comment? I am also getting the same error and am not able to resolve it.

@Premal_Matalia refer to this: