forked from deepspeedai/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcommon.py
More file actions
100 lines (82 loc) · 3.8 KB
/
common.py
File metadata and controls
100 lines (82 loc) · 3.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import time
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import pytest
# Worker timeout (in seconds) *after* the first worker has completed: each
# worker's join() waits at most this long, and a worker that is still running
# afterwards is terminated and reported as hung.
DEEPSPEED_UNIT_WORKER_TIMEOUT = 10
def distributed_test(world_size=2, backend='nccl'):
    """A decorator for executing a function (e.g., a unit test) in a distributed manner.

    This decorator manages the spawning and joining of processes, initialization of
    torch.distributed, and catching of errors.

    Usage example:
        @distributed_test(world_size=[2,3])
        def my_test():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            assert(rank < world_size)

    Arguments:
        world_size (int or list): number of ranks to spawn. Can be a list to spawn
            multiple tests.
        backend (str): backend name handed to ``dist.init_process_group``.

    Raises (when the decorated function is called):
        TypeError: if ``world_size`` is neither an int nor a list.
        ValueError: if a requested number of ranks is smaller than 1.
    """
    def dist_wrap(run_func):
        """Second-level decorator for dist_test. This actually wraps the function. """
        def dist_init(local_rank, num_procs, *func_args, **func_kwargs):
            """Initialize torch.distributed and execute the user function.

            Runs inside each spawned worker: points the env:// rendezvous at a
            fixed local address/port, joins the process group as rank
            ``local_rank`` of ``num_procs``, selects the CUDA device with the
            same index when CUDA is available, then calls the wrapped function.
            """
            os.environ['MASTER_ADDR'] = '127.0.0.1'
            os.environ['MASTER_PORT'] = '29500'
            dist.init_process_group(backend=backend,
                                    init_method='env://',
                                    rank=local_rank,
                                    world_size=num_procs)

            if torch.cuda.is_available():
                torch.cuda.set_device(local_rank)

            run_func(*func_args, **func_kwargs)

        def dist_launcher(num_procs, *func_args, **func_kwargs):
            """Launch processes and gracefully handle failures.

            Spawns ``num_procs`` workers, waits for the first one to exit, gives
            the rest DEEPSPEED_UNIT_WORKER_TIMEOUT seconds each to finish, and
            fails the calling test (via ``pytest.fail``) for the lowest-ranked
            worker that hung, was killed by a signal, or exited non-zero.
            """
            # With no workers the wait loop below could never observe an exit,
            # so reject the request up front instead of hanging forever.
            if num_procs < 1:
                raise ValueError(
                    f'world_size must be at least 1, but got {num_procs}.')

            # Spawn all workers on subprocesses.
            processes = []
            for local_rank in range(num_procs):
                p = Process(target=dist_init,
                            args=(local_rank,
                                  num_procs,
                                  *func_args),
                            kwargs=func_kwargs)
                p.start()
                processes.append(p)

            # Wait for the first worker to complete. Poll with a short sleep so
            # the launcher does not burn a full core while the workers run.
            while all(p.is_alive() for p in processes):
                time.sleep(0.01)

            # Wait for all other processes to complete
            for p in processes:
                p.join(DEEPSPEED_UNIT_WORKER_TIMEOUT)

            # Snapshot the exit codes once so cleanup and reporting agree.
            exit_codes = [p.exitcode for p in processes]

            # pytest.fail() raises, so every hung worker has to be terminated
            # *before* the first failure is reported; otherwise the remaining
            # hung workers would be leaked.
            for p, code in zip(processes, exit_codes):
                if code is None:
                    p.terminate()
                    p.join(DEEPSPEED_UNIT_WORKER_TIMEOUT)

            # Report the first failing worker, in rank order.
            for rank, code in enumerate(exit_codes):
                if code is None:
                    pytest.fail(f'Worker {rank} hung.', pytrace=False)
                if code < 0:
                    pytest.fail(f'Worker {rank} killed by signal {-code}',
                                pytrace=False)
                if code > 0:
                    pytest.fail(f'Worker {rank} exited with code {code}',
                                pytrace=False)

        def run_func_decorator(*func_args, **func_kwargs):
            """Entry point for @distributed_test(). """
            if isinstance(world_size, int):
                dist_launcher(world_size, *func_args, **func_kwargs)
            elif isinstance(world_size, list):
                for procs in world_size:
                    dist_launcher(procs, *func_args, **func_kwargs)
                    # Brief pause between consecutive launches, which all
                    # rendezvous on the same fixed port.
                    time.sleep(0.5)
            else:
                raise TypeError('world_size must be an integer or a list of integers.')

        return run_func_decorator

    return dist_wrap