Commit 8bde896

Topo service (#86)
* Add construct_topology function
* Add doc and test for infer_destination_source_ranks
* Address comments
* Split infer_topo into two functions
* Delete construct_topo.py
* Fix the none case in _infer_topo func
1 parent dcf56ba commit 8bde896

File tree

4 files changed: +229 -27 lines changed


bluefog/torch/__init__.py (+37 -7)
@@ -14,10 +14,6 @@
 # limitations under the License.
 # ==============================================================================
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 import collections
 import os
 import torch
@@ -30,10 +26,10 @@
     DistributedWinPutOptimizer,
     DistributedAllreduceOptimizer,
     DistributedNeighborAllreduceOptimizer,
-    DistributedHierarchicalNeighborAllreduceOptimizer
+    DistributedHierarchicalNeighborAllreduceOptimizer,
 )
 
-check_extension('bluefog.torch', __file__, 'mpi_lib')
+check_extension("bluefog.torch", __file__, "mpi_lib")
 
 from bluefog.torch.mpi_ops import init, shutdown
 from bluefog.torch.mpi_ops import size, local_size, rank, local_rank
@@ -74,4 +70,38 @@
 
 from bluefog.torch.mpi_ops import timeline_start_activity, timeline_end_activity
 from bluefog.torch.mpi_ops import timeline_context
-from bluefog.torch.utility import broadcast_optimizer_state, broadcast_parameters, allreduce_parameters
+
+from bluefog.torch.utility import (
+    broadcast_optimizer_state,
+    broadcast_parameters,
+    allreduce_parameters,
+)
+
+from bluefog.common.topology_util import (
+    GetRecvWeights,
+    GetSendWeights,
+    IsRegularGraph,
+    IsTopologyEquivalent,
+)
+
+from bluefog.common.topology_util import (
+    ExponentialTwoGraph,
+    ExponentialGraph,
+    FullyConnectedGraph,
+    MeshGrid2DGraph,
+    RingGraph,
+    StarGraph,
+    SymmetricExponentialGraph,
+)
+
+from bluefog.common.topology_util import (
+    GetDynamicOnePeerSendRecvRanks,
+    GetExp2DynamicSendRecvMachineRanks,
+    GetInnerOuterRingDynamicSendRecvRanks,
+    GetInnerOuterExpo2DynamicSendRecvRanks,
+)
+
+from bluefog.torch.topology_util import (
+    InferSourceFromDestinationRanks,
+    InferDestinationFromSourceRanks,
+)
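
A quick sketch of what these re-exports mean for callers (my own illustration, not part of the commit): the topology constructors and the new inference helpers are now reachable directly through bluefog.torch, which is what the updated test imports below switch to. It assumes the script is launched across multiple MPI processes (e.g. via bfrun).

import bluefog.torch as bf

bf.init()
bf.set_topology(bf.ExponentialGraph(bf.size()))

# Before this commit these helpers had to be imported from
# bluefog.common.topology_util; now they are re-exported by bluefog.torch.
src_ranks = bf.InferSourceFromDestinationRanks(dst_ranks=bf.out_neighbor_ranks())
print(bf.rank(), src_ranks)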

bluefog/torch/topology_util.py (new file, +108)

@@ -0,0 +1,108 @@
+from typing import Any, List, Optional, Tuple, Union
+import collections
+
+import numpy as np
+import torch
+import bluefog.torch as bf
+
+
+def _check_ranks(rank_list: List[Any], self_rank: int, size: int) -> Tuple[bool, str]:
+    for rank in rank_list:
+        if not isinstance(rank, int):
+            return False, "contains an element that is not an integer."
+        if (rank < 0) or (rank >= size):
+            return False, "contains an element that is not between 0 and size-1."
+    if len(set(rank_list)) != len(rank_list):
+        return False, "contains duplicated elements."
+    if self_rank in rank_list:
+        return False, "contains the self rank."
+    return True, ""
+
+
+def InferSourceFromDestinationRanks(
+    dst_ranks: List[int], construct_adjacency_matrix: bool = False,
+) -> Union[List[int], Tuple[List[int], np.array]]:
+    """Infer the source ranks from the destination ranks. This is a collective communication call.
+
+    Args:
+        dst_ranks: A list of destination ranks.
+        construct_adjacency_matrix: If true, the adjacency matrix is returned as well.
+            Element w_{ij} represents the weight sent from node i to node j.
+            We use a column-normalized style, i.e. the receiving weights of each node sum to 1.
+
+    Raises:
+        ValueError: If dst_ranks contains values that are not integers between 0 and size-1.
+
+    Returns:
+        If construct_adjacency_matrix is false, returns the source ranks list.
+        If construct_adjacency_matrix is true, returns the source ranks list
+        and a 2-D numpy array.
+    """
+    is_valid, error_msg = _check_ranks(dst_ranks, bf.rank(), bf.size())
+    assert is_valid, f"The format of dst_ranks is wrong: {error_msg}"
+    return _infer_topo(
+        dst_ranks,
+        transpose=False,
+        construct_adjacency_matrix=construct_adjacency_matrix,
+    )
+
+
+def InferDestinationFromSourceRanks(
+    src_ranks: List[int], construct_adjacency_matrix: bool = False,
+) -> Union[List[int], Tuple[List[int], np.array]]:
+    """Infer the destination ranks from the source ranks. This is a collective communication call.
+
+    Args:
+        src_ranks: A list of source ranks.
+        construct_adjacency_matrix: If true, the adjacency matrix is returned as well.
+            Element w_{ij} represents the weight sent from node i to node j.
+            We use a column-normalized style, i.e. the receiving weights of each node sum to 1.
+
+    Raises:
+        ValueError: If src_ranks contains values that are not integers between 0 and size-1.
+
+    Returns:
+        If construct_adjacency_matrix is false, returns the destination ranks list.
+        If construct_adjacency_matrix is true, returns the destination ranks
+        list and a 2-D numpy array.
+    """
+    is_valid, error_msg = _check_ranks(src_ranks, bf.rank(), bf.size())
+    assert is_valid, f"The format of src_ranks is wrong: {error_msg}"
+    return _infer_topo(
+        src_ranks,
+        transpose=True,
+        construct_adjacency_matrix=construct_adjacency_matrix,
+    )
+
+
+def _infer_topo(
+    rank_list: List[int], transpose: bool, construct_adjacency_matrix: bool
+):
+    degree = len(rank_list)
+    all_degree_list = bf.allgather(torch.tensor([degree], dtype=torch.int32)).numpy()
+    all_rank_list = bf.allgather(torch.tensor(rank_list, dtype=torch.int32)).numpy()
+    adjacency_dict = dict()
+    displacement = 0
+    for i, degree in enumerate(all_degree_list):
+        adjacency_dict[i] = sorted(all_rank_list[displacement : displacement + degree])
+        displacement += degree
+
+    inv_adjacency_dict = collections.defaultdict(list)
+    for k, adj in adjacency_dict.items():
+        for v in adj:
+            inv_adjacency_dict[v].append(k)
+    return_list = inv_adjacency_dict.get(bf.rank())
+    if return_list is None:
+        return_list = []
+
+    if not construct_adjacency_matrix:
+        return return_list
+
+    # construct_adjacency_matrix
+    W = np.eye(bf.size())
+    for k, adj in adjacency_dict.items():
+        W[k, adj] = 1
+    if transpose:
+        W = W.T
+
+    return return_list, W / W.sum(axis=1)
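
The inversion that _infer_topo performs can be illustrated without MPI. The snippet below is only a sketch: dst_ranks_per_rank stands in for what the two allgather calls would collect in a hypothetical 4-process run (a directed ring 0 -> 1 -> 2 -> 3 -> 0), and the variable names are made up for illustration.

import collections
import numpy as np

# Stand-in for the gathered dst_ranks of each rank (directed ring on 4 ranks).
dst_ranks_per_rank = {0: [1], 1: [2], 2: [3], 3: [0]}

# Invert the adjacency: if rank k lists v as a destination, then k is a source of v.
inv_adjacency = collections.defaultdict(list)
for k, adj in dst_ranks_per_rank.items():
    for v in adj:
        inv_adjacency[v].append(k)

print(inv_adjacency[1])  # [0]: on rank 1, InferSourceFromDestinationRanks([2]) would report [0]

# Optional adjacency matrix: identity (self-loops) plus the gathered edges,
# normalized the same way as in _infer_topo above.
W = np.eye(4)
for k, adj in dst_ranks_per_rank.items():
    W[k, adj] = 1
W = W / W.sum(axis=1)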

bluefog/torch/utility.py (+2)

@@ -14,8 +14,10 @@
 # limitations under the License.
 # ==============================================================================
 
+from typing import Any, List, Optional
 import collections
 
+import numpy as np
 import torch
 import bluefog.torch as bf
 
test/torch_basics_test.py (+82 -20)
@@ -28,8 +28,18 @@
 
 from common import mpi_env_rank_and_size
 import bluefog.torch as bf
-from bluefog.common.topology_util import ExponentialGraph, RingGraph, RingGraph
-from bluefog.common.topology_util import IsTopologyEquivalent
+from bluefog.torch import (
+    ExponentialGraph,
+    RingGraph,
+    StarGraph,
+    MeshGrid2DGraph,
+    FullyConnectedGraph,
+)
+from bluefog.torch import (
+    IsTopologyEquivalent,
+    InferDestinationFromSourceRanks,
+    InferSourceFromDestinationRanks,
+)
 
 warnings.filterwarnings("ignore", message="numpy.dtype size changed")
 warnings.filterwarnings("ignore", message="numpy.ufunc size changed")
@@ -75,10 +85,12 @@ def test_set_topology_fail_with_win_create(self):
 
         if size == 1:
             expected_topology = nx.from_numpy_array(
-                np.array([[0.5]]), create_using=nx.DiGraph)
+                np.array([[0.5]]), create_using=nx.DiGraph
+            )
         elif size == 2:
             expected_topology = nx.from_numpy_array(
-                np.array([[0, 0.2], [0.2, 0]]), create_using=nx.DiGraph)
+                np.array([[0, 0.2], [0.2, 0]]), create_using=nx.DiGraph
+            )
         else:
             expected_topology = RingGraph(size)
 
@@ -96,10 +108,16 @@ def test_set_and_load_topology(self):
         bf.init()
         size = bf.size()
         if size == 4:
-            expected_topology = nx.DiGraph(np.array(
-                [[1/3., 1/3., 1/3., 0.], [0., 1/3., 1/3., 1/3.],
-                 [1/3., 0., 1/3., 1/3.], [1/3., 1/3., 0., 1/3.]]
-            ))
+            expected_topology = nx.DiGraph(
+                np.array(
+                    [
+                        [1 / 3.0, 1 / 3.0, 1 / 3.0, 0.0],
+                        [0.0, 1 / 3.0, 1 / 3.0, 1 / 3.0],
+                        [1 / 3.0, 0.0, 1 / 3.0, 1 / 3.0],
+                        [1 / 3.0, 1 / 3.0, 0.0, 1 / 3.0],
+                    ]
+                )
+            )
         elif size == 1:
             expected_topology = nx.DiGraph(np.array([[1.0]]))
         else:
@@ -113,37 +131,81 @@ def test_in_out_neighbors_expo2(self):
         rank = bf.rank()
         size = bf.size()
         assert bf.set_topology(ExponentialGraph(size))
-        in_neighobrs = bf.in_neighbor_ranks()
+        in_neighbors = bf.in_neighbor_ranks()
         out_neighbors = bf.out_neighbor_ranks()
 
         degree = int(np.ceil(np.log2(size)))
-        expected_in_neighbors = sorted([(rank - 2**i) %
-                                        size for i in range(degree)])
-        expected_out_neighbors = sorted([(rank + 2**i) %
-                                         size for i in range(degree)])
-        assert sorted(in_neighobrs) == expected_in_neighbors
+        expected_in_neighbors = sorted([(rank - 2 ** i) % size for i in range(degree)])
+        expected_out_neighbors = sorted([(rank + 2 ** i) % size for i in range(degree)])
+        assert sorted(in_neighbors) == expected_in_neighbors
         assert sorted(out_neighbors) == expected_out_neighbors
 
     def test_in_out_neighbors_biring(self):
         bf.init()
         rank = bf.rank()
         size = bf.size()
         assert bf.set_topology(RingGraph(size))
-        in_neighobrs = bf.in_neighbor_ranks()
+        in_neighbors = bf.in_neighbor_ranks()
         out_neighbors = bf.out_neighbor_ranks()
 
-        expected_in_neighbors = list(set(
-            map(lambda x: x % size, [rank - 1, rank + 1])))
-        expected_out_neighbors = list(set(
-            map(lambda x: x % size, [rank - 1, rank + 1])))
+        expected_in_neighbors = list(set(map(lambda x: x % size, [rank - 1, rank + 1])))
+        expected_out_neighbors = list(
+            set(map(lambda x: x % size, [rank - 1, rank + 1]))
+        )
 
         if size <= 1:
             expected_in_neighbors = []
             expected_out_neighbors = []
 
-        assert sorted(in_neighobrs) == expected_in_neighbors
+        assert sorted(in_neighbors) == expected_in_neighbors
         assert sorted(out_neighbors) == expected_out_neighbors
 
 
+@pytest.mark.parametrize(
+    "topo_func",
+    [ExponentialGraph, RingGraph, StarGraph, MeshGrid2DGraph, FullyConnectedGraph],
+)
+def test_infer_destination_from_source_ranks(topo_func):
+    bf.init()
+    size = bf.size()
+    bf.set_topology(topo_func(size))
+    topo = bf.load_topology()
+    in_neighbors = bf.in_neighbor_ranks()
+    out_neighbors = bf.out_neighbor_ranks()
+
+    # Make the W into average rule.
+    expected_W = (nx.to_numpy_array(topo) > 0).astype(float)
+    expected_W /= expected_W.sum(axis=0)
+
+    src_ranks, W = InferDestinationFromSourceRanks(
+        src_ranks=in_neighbors, construct_adjacency_matrix=True
+    )
+    assert sorted(src_ranks) == out_neighbors
+    np.testing.assert_allclose(W, expected_W)
+
+
+@pytest.mark.parametrize(
+    "topo_func",
+    [ExponentialGraph, RingGraph, StarGraph, MeshGrid2DGraph, FullyConnectedGraph],
+)
+def test_infer_source_from_destination_ranks(topo_func):
+    bf.init()
+    size = bf.size()
+    bf.set_topology(topo_func(size))
+    topo = bf.load_topology()
+    in_neighbors = bf.in_neighbor_ranks()
+    out_neighbors = bf.out_neighbor_ranks()
+
+    # Make the W into average rule.
+    expected_W = (nx.to_numpy_array(topo) > 0).astype(float)
+    expected_W /= expected_W.sum(axis=0)
+
+    dst_ranks, W = InferSourceFromDestinationRanks(
+        dst_ranks=out_neighbors, construct_adjacency_matrix=True
+    )
+    assert sorted(dst_ranks) == in_neighbors
+    np.testing.assert_allclose(W, expected_W)
+
+
 if __name__ == "__main__":
     unittest.main()
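
The expected_W used in both new tests is just the topology's adjacency matrix (including the self-loop that BlueFog's weighted topologies carry) normalized so each column sums to 1. A single-process illustration, assuming a 4-node ring with self-loops; plain numpy only, no MPI needed:

import numpy as np

# Adjacency of a 4-node ring with self-loops (node i links to itself and i±1 mod 4).
A = np.array(
    [
        [1, 1, 0, 1],
        [1, 1, 1, 0],
        [0, 1, 1, 1],
        [1, 0, 1, 1],
    ],
    dtype=float,
)

# "Average rule": every column sums to 1, matching expected_W in the tests above.
expected_W = A / A.sum(axis=0)
print(expected_W)  # every nonzero entry is 1/3 for this regular graph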
