I am having issues with distributed training: when one of the nodes fails, the error takes time to propagate to the other nodes. I am using the NCCL_ASYNC_ERROR_HANDLING flag, and I do see failures after the timeout is reached. The problem is that I need a high timeout value (sometimes one of the workers in my workflow needs to do extra work, and I do not want the other workers to die during that processing), but I would like errors to be propagated almost instantly. Is there a way to do so? I'd be ok with doing it manually, e.g. a try / except around my training loop and something in the except that communicates the error to the other nodes. I tried destroying the process group and forcing a SIGKILL, to no avail.
@nicomng I am not aware of any straightforward way to do what you are asking for in all cases.
However, you might want to look at torch elastic (now called torchrun), which ships with recent versions of PyTorch. If you run your script with the torchrun command (refer to the docs for how to set the number of processes, nodes, etc.), it will monitor your processes, and if one of them fails, torchrun will kill the others and report back to you. Here is an example of how to run it:
torchrun --nproc_per_node=4 --nnodes=2 --master_port ... --master_addr ... train_script.py
(double check for typos)
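To make the behavior concrete, here is a minimal stdlib-only sketch of the fail-fast supervision pattern that a launcher like torchrun implements: a parent process spawns the workers, polls them, and as soon as one exits with an error it terminates the survivors instead of letting them block until the collective timeout. This is an illustration of the idea, not torchrun's actual code; the worker function and the failure on rank 0 are made up for the demo.

```python
import multiprocessing as mp
import time


def worker(rank):
    # Hypothetical worker: rank 0 simulates a crashed node, the others
    # would otherwise block for a long time (like waiting on NCCL).
    if rank == 0:
        raise RuntimeError(f"rank {rank} failed")
    time.sleep(60)


def supervise(world_size=4):
    """Spawn workers and return the exit code of the first failure (or 0)."""
    ctx = mp.get_context("fork")  # fork keeps this sketch simple on Linux/macOS
    procs = [ctx.Process(target=worker, args=(r,)) for r in range(world_size)]
    for p in procs:
        p.start()
    while True:
        for p in procs:
            # A non-zero exit code means this worker died with an error:
            # kill the remaining workers immediately instead of waiting.
            if p.exitcode not in (None, 0):
                for other in procs:
                    if other.is_alive():
                        other.terminate()
                for other in procs:
                    other.join()
                return p.exitcode
        if all(p.exitcode == 0 for p in procs):
            return 0
        time.sleep(0.1)
```

Running `supervise()` here returns almost immediately with the failing worker's exit code, even though the healthy workers were still sleeping, which is exactly the "propagate the error instantly" behavior you asked about.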
If train_script.py uses a Lightning Trainer, it will parse the environment variables that torchrun sets automatically for you. All you have to do is make sure that your Trainer settings (devices and num_nodes) match what you pass in the torchrun command above.
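For reference, these are the main environment variables torchrun sets for each worker, which Lightning reads so you don't have to wire up ranks by hand. A small sketch to inspect them inside your script (the `describe_worker` helper is just for illustration):

```python
import os


def describe_worker():
    # Per-worker environment variables set by torchrun.
    keys = ["RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"]
    return {k: os.environ.get(k) for k in keys}


# Inside train_script.py you could log this at startup:
#   print(describe_worker())
# and with Lightning, the matching Trainer config for the command above
# would be something like (assuming 4 GPUs per node, 2 nodes):
#   trainer = Trainer(accelerator="gpu", devices=4, num_nodes=2)
```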
Let me know how it goes!