# test_dist.py
import torch
import torch.distributed as dist
from common import distributed_test
import pytest


@distributed_test(world_size=3)
def test_init():
    assert dist.is_initialized()
    assert dist.get_world_size() == 3
    assert dist.get_rank() < 3
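

# The tests in this file lean on the `distributed_test` decorator imported from
# common.py. As a rough mental model only (a minimal sketch under assumed details,
# not the real common.py implementation, which also accepts a list of world sizes,
# pins ranks to GPUs, and propagates failures), such a decorator forks one process
# per rank, initializes torch.distributed in each, runs the test body on every
# rank, and waits for all of them. The name `_distributed_test_sketch`, the port,
# and the backend below are illustrative assumptions.
def _distributed_test_sketch(world_size=2, backend='gloo'):
    import os
    from functools import wraps
    from multiprocessing import Process  # assumes the default 'fork' start method

    def decorator(run_func):
        def _init_and_run(rank, *args, **kwargs):
            # Every rank joins the same process group via the default env:// rendezvous.
            os.environ['MASTER_ADDR'] = '127.0.0.1'
            os.environ['MASTER_PORT'] = '29503'
            dist.init_process_group(backend, rank=rank, world_size=world_size)
            run_func(*args, **kwargs)
            dist.destroy_process_group()

        @wraps(run_func)
        def wrapper(*args, **kwargs):
            # Fork one worker per rank, wait for all of them, and fail if any rank failed.
            procs = [
                Process(target=_init_and_run, args=(rank, *args), kwargs=kwargs)
                for rank in range(world_size)
            ]
            for p in procs:
                p.start()
            for p in procs:
                p.join()
            assert all(p.exitcode == 0 for p in procs)

        return wrapper

    return decorator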


# Demonstration of pytest's parametrization
@pytest.mark.parametrize('number,color', [(1138, 'purple')])
def test_dist_args(number, color):
    """Outer test function with inputs from pytest.mark.parametrize(). Uses a
    distributed helper function.
    """
    @distributed_test(world_size=2)
    def _test_dist_args_helper(x, color='red'):
        assert dist.get_world_size() == 2
        assert x == 1138
        assert color == 'purple'

    # Ensure that we can parse args to distributed_test decorated functions.
    _test_dist_args_helper(number, color=color)


@distributed_test(world_size=[1, 2, 4])
def test_dist_allreduce():
    x = torch.ones(1, 3).cuda() * (dist.get_rank() + 1)
    sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
    result = torch.ones(1, 3).cuda() * sum_of_ranks
    dist.all_reduce(x)
    assert torch.all(x == result)
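
# Worked example of the arithmetic above: with world_size=4 the ranks hold 1x3
# tensors filled with 1, 2, 3 and 4; sum_of_ranks == 4 * 5 // 2 == 10, and
# dist.all_reduce defaults to ReduceOp.SUM, so every rank should end up holding
# a 1x3 tensor of tens.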