Skip to content

Commit aa4094a

Browse files
committedNov 2, 2020
PowerGraph => ExponentialGraph
1 parent cb619f7 commit aa4094a

16 files changed

+108
-132
lines changed
 

‎bluefog/common/basics.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def init(self, topology_fn: Optional[Callable[[int], networkx.DiGraph]] = None,
5151
Args:
5252
topology_fn: A callable function that takes size as input and return
5353
networkx.DiGraph object to decide the topology. If not provided
54-
a default power graph (base 2) structure is called.
54+
a default exponential graph (base 2) structure is called.
5555
is_weighted: If set to true, the neighbor ops like (win_update, neighbor_allreduce) will
5656
execute the weighted average instead, where the weight is the value used in
5757
topology matrix (including self).
@@ -60,7 +60,7 @@ def init(self, topology_fn: Optional[Callable[[int], networkx.DiGraph]] = None,
6060
if topology_fn:
6161
topo = topology_fn(self.size())
6262
else:
63-
topo = topology_util.PowerGraph(self.size())
63+
topo = topology_util.ExponentialGraph(self.size())
6464
self.set_topology(topo, is_weighted)
6565
atexit.register(self.shutdown)
6666

@@ -191,7 +191,7 @@ def set_topology(self, topology: Optional[networkx.DiGraph] = None,
191191
192192
Args:
193193
Topo: A networkx.DiGraph object to decide the topology. If not provided
194-
a default power graph (base 2) structure is used.
194+
a default exponential graph (base 2) structure is used.
195195
is_weighted: If set to true, the win_update and neighbor_allreduce will execute the
196196
weighted average instead, where the weights are the value used in topology matrix
197197
(including self weight). Note win_get/win_put/win_accumulate do not use this weight
@@ -207,10 +207,10 @@ def set_topology(self, topology: Optional[networkx.DiGraph] = None,
207207
>>> bf.set_topology(topology_util.RingGraph(bf.size()))
208208
"""
209209
if topology is None:
210-
topology = topology_util.PowerGraph(size=self.size())
210+
topology = topology_util.ExponentialGraph(size=self.size())
211211
if self.local_rank() == 0:
212212
logger.info(
213-
"Topology is not specified. Default Power Two topology is used.")
213+
"Topology is not specified. Default Exponential Two topology is used.")
214214

215215
if not isinstance(topology, networkx.DiGraph):
216216
raise TypeError("topology must be a networkx.DiGraph obejct.")

‎bluefog/common/mpi_controller.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,7 @@ void MPIController::WinCreate(TensorTableEntry& entry) {
825825
win_manager->SetGlobalWin(global_mpi_win_ptr);
826826

827827
// Build extra buffers for win_put.
828-
// For example: size=4 power two ring topology
828+
// For example: size=4 exponential two ring topology
829829
// r\s 0 1 2 3
830830
// 0 g x x
831831
// 1 x g x

‎bluefog/common/topology_util.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,18 @@ def GetSendWeights(topo: nx.DiGraph, rank: int) -> Tuple[float, Dict[int, float]
6363
return self_weight, neighbor_weights
6464

6565

66-
def PowerTwoRingGraph(size: int) -> nx.DiGraph:
66+
def ExponentialTwoGraph(size: int) -> nx.DiGraph:
6767
"""Generate graph topology such that each points only
68-
connected to a point such that the index difference is power of 2.
68+
connected to a point such that the index difference is the power of 2.
6969
70-
Example: A PowerTwoRingGraph with 12 nodes:
70+
Example: A ExponentialTwoGraph with 12 nodes:
7171
7272
.. plot::
7373
:context: close-figs
7474
7575
>>> import networkx as nx
7676
>>> from bluefog.common import topology_util
77-
>>> G = topology_util.PowerTwoRingGraph(12)
77+
>>> G = topology_util.ExponentialTwoGraph(12)
7878
>>> nx.draw_circular(G)
7979
"""
8080
assert size > 0
@@ -96,18 +96,18 @@ def isPowerOf(x, base):
9696
return False
9797

9898

99-
def PowerGraph(size: int, base: int = 2) -> nx.DiGraph:
99+
def ExponentialGraph(size: int, base: int = 2) -> nx.DiGraph:
100100
"""Generate graph topology such that each points only
101101
connected to a point such that the index difference is power of base. (Default is 2)
102102
103-
Example: A PowerGraph with 12 nodes:
103+
Example: A ExponentialGraph with 12 nodes:
104104
105105
.. plot::
106106
:context: close-figs
107107
108108
>>> import networkx as nx
109109
>>> from bluefog.common import topology_util
110-
>>> G = topology_util.PowerGraph(12)
110+
>>> G = topology_util.ExponentialGraph(12)
111111
>>> nx.draw_circular(G)
112112
"""
113113
x = [1.0]
@@ -125,20 +125,20 @@ def PowerGraph(size: int, base: int = 2) -> nx.DiGraph:
125125
return G
126126

127127

128-
def SymmetricPowerGraph(size: int, base: int = 4) -> nx.DiGraph:
128+
def SymmetricExponentialGraph(size: int, base: int = 4) -> nx.DiGraph:
129129
"""
130130
Generate symmeteric graph topology such that for the first half of nodes
131131
only connected to a point such that the index difference is power of base (Default is 4)
132132
and the connectivity for the second half of nodes just mirrored to the first half.
133133
134-
Example: A SymmetricPowerGraph with 12 nodes
134+
Example: A SymmetricExponentialGraph with 12 nodes
135135
136136
.. plot::
137137
:context: close-figs
138138
139139
>>> import networkx as nx
140140
>>> from bluefog.common import topology_util
141-
>>> G = topology_util.SymmetricPowerGraph(12)
141+
>>> G = topology_util.SymmetricExponentialGraph(12)
142142
>>> nx.draw_circular(G)
143143
"""
144144
x = [1.0]
@@ -339,7 +339,7 @@ def InnerOuterRingGraph(world_size: int, local_size: int) -> nx.DiGraph:
339339
return G
340340

341341

342-
def InnerOuterExp2Graph(world_size: int, local_size: int) -> nx.DiGraph:
342+
def InnerOuterExpo2Graph(world_size: int, local_size: int) -> nx.DiGraph:
343343
"""Generate Inner Ring and Outer Exponential-2 Graph.
344344
345345
Within one machine all inner rank/processes is fully-connected and all
@@ -349,7 +349,7 @@ def InnerOuterExp2Graph(world_size: int, local_size: int) -> nx.DiGraph:
349349
350350
>>> import networkx as nx
351351
>>> from bluefog.common import topology_util
352-
>>> G = topology_util.InnerOuterExp2Graph(12, 3)
352+
>>> G = topology_util.InnerOuterExpo2Graph(12, 3)
353353
>>> nx.draw_circular(G)
354354
"""
355355
total_nodes = world_size
@@ -541,7 +541,7 @@ def GetInnerOuterRingDynamicSendRecvRanks(
541541
index += 1
542542

543543

544-
def GetInnerOuterExp2DynamicSendRecvRanks(
544+
def GetInnerOuterExpo2DynamicSendRecvRanks(
545545
world_size: int, local_size: int, self_rank: int
546546
) -> Iterator[Tuple[List[int], List[int]]]:
547547
"""
@@ -560,7 +560,7 @@ def GetInnerOuterExp2DynamicSendRecvRanks(
560560
561561
>>> from bluefog.common import topology_util
562562
>>> world_size, local_size = bf.size(), bf.local_size()
563-
>>> gen = topology_util.GetInnerOuterExp2DynamicSendRecvRanks(world_size, local_size, 0)
563+
>>> gen = topology_util.GetInnerOuterExpo2DynamicSendRecvRanks(world_size, local_size, 0)
564564
>>> for _ in range(10):
565565
>>> print(next(gen))
566566
"""

‎docs/timeline.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Example I: Logistic regression with neighbor_allreduce
3636
------------------------------------------------------
3737
In the first example, we show the timeline when running decentralized SGD for
3838
logistic regression, see the figure below. In this example, each rank is connected
39-
via an undirected power-2 topology. We exploit the
39+
via an undirected exponential-2 topology. We exploit the
4040
primitive ``neighbor_allreduce`` to perform the neighbor averaging.
4141

4242
.. image:: ./_static/bf_timeline_example1a.png

‎examples/pytorch_ImageNet_Resnet50.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@
7272
parser.add_argument("--disable-dynamic-topology", action="store_true",
7373
default=False, help=("Disable each iteration to transmit one neighbor " +
7474
"per iteration dynamically."))
75-
parser.add_argument('--virtual-topology', type=str, default="power2",
75+
parser.add_argument('--virtual-topology', type=str, default="expo2",
7676
help='The underlying virtual topology. Supporting options are ' +
77-
'[power2(Default), ring, mesh, star].')
77+
'[expo2(Default), ring, mesh, star].')
7878

7979
args = parser.parse_args()
8080
args.cuda = not args.no_cuda and torch.cuda.is_available()
@@ -88,13 +88,13 @@
8888
torch.manual_seed(args.seed)
8989

9090
if args.dist_optimizer != 'horovod':
91-
if args.virtual_topology == "power2":
91+
if args.virtual_topology == "expo2":
9292
pass
9393
elif args.virtual_topology == "ring":
9494
bf.set_topology(topology_util.RingGraph(bf.size(), connect_style=1))
9595
else:
9696
raise ValueError("Unknown args.virtual_topology, supporting options are " +
97-
"[power2(Default), ring, mesh, star].")
97+
"[expo2(Default), ring, mesh, star].")
9898

9999
if args.cuda:
100100
# Bluefog: pin GPU to local rank.

‎examples/pytorch_average_consensus.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
help='maximum iterations')
2828
parser.add_argument('--local-size', type=int, default=4,
2929
help='number of nodes per machine')
30-
parser.add_argument('--virtual-topology', type=str, default="power2",
30+
parser.add_argument('--virtual-topology', type=str, default="expo2",
3131
help='The underlying virtual topology. Supporting options are ' +
32-
'[power2(Default), ring, mesh, star, InnerOuterRing].')
32+
'[expo2(Default), ring, mesh, star, InnerOuterRing].')
3333
parser.add_argument('--asynchronous-mode', action='store_true', default=False,
3434
help='Use one-sided ops to run asynchronous push sum algorithm')
3535
parser.add_argument('--no-cuda', action='store_true', default=False,
@@ -58,12 +58,12 @@
5858
else:
5959
x = torch.randn(args.data_size, dtype=torch.double)
6060

61-
if args.virtual_topology == "power2":
61+
if args.virtual_topology == "expo2":
6262
pass
63-
elif args.virtual_topology == "power3":
64-
bf.set_topology(topology_util.PowerGraph(bf.size(), base=3))
65-
elif args.virtual_topology == "power4":
66-
bf.set_topology(topology_util.PowerGraph(bf.size(), base=4))
63+
elif args.virtual_topology == "expo3":
64+
bf.set_topology(topology_util.ExponentialGraph(bf.size(), base=3))
65+
elif args.virtual_topology == "expo4":
66+
bf.set_topology(topology_util.ExponentialGraph(bf.size(), base=4))
6767
elif args.virtual_topology == "ring":
6868
bf.set_topology(topology_util.RingGraph(bf.size(), connect_style=1))
6969
elif args.virtual_topology == "mesh":
@@ -77,7 +77,7 @@
7777
bf.set_topology(topology_util.FullyConnectedGraph(bf.size()))
7878
else:
7979
raise ValueError("Unknown args.virtual_topology, supporting options are " +
80-
"[power2(Default), ring, mesh, star].")
80+
"[expo2(Default), ring, mesh, star].")
8181

8282
x_bar = bf.allreduce(x, average=True)
8383
mse = [torch.norm(x-x_bar, p=2) / torch.norm(x_bar, p=2)]

‎examples/pytorch_benchmark.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@
6060
parser.add_argument('--disable-dynamic-topology', action='store_true',
6161
default=False, help=('Disable each iteration to transmit one neighbor ' +
6262
'per iteration dynamically.'))
63-
parser.add_argument('--virtual-topology', type=str, default="power2",
63+
parser.add_argument('--virtual-topology', type=str, default="expo2",
6464
help='The underlying virtual topology. Supporting options are ' +
65-
'[power2(Default), ring, mesh, star, InnerOuterRing, InnerOuterExp2].')
65+
'[expo2(Default), ring, mesh, star, InnerOuterRing, InnerOuterExpo2].')
6666

6767

6868
args = parser.parse_args()
@@ -73,21 +73,21 @@
7373

7474
bf.init()
7575
if args.dist_optimizer != 'horovod':
76-
if args.virtual_topology == "power2":
76+
if args.virtual_topology == "expo2":
7777
pass
7878
elif args.virtual_topology == "ring":
7979
bf.set_topology(topology_util.RingGraph(bf.size(), connect_style=1))
8080
elif args.virtual_topology == "InnerOuterRing":
8181
assert bf.is_homogeneous, "InnerOuterRing topo should be used only homogeneous environment"
8282
bf.set_topology(topology_util.InnerOuterRingGraph(
8383
bf.size(), local_size=bf.local_size() if args.local_size == -1 else args.local_size))
84-
elif args.virtual_topology == "InnerOuterExp2":
85-
assert bf.is_homogeneous, "InnerOuterExp2 topo should be used under homogeneous environment"
86-
bf.set_topology(topology_util.InnerOuterExp2Graph(
84+
elif args.virtual_topology == "InnerOuterExpo2":
85+
assert bf.is_homogeneous, "InnerOuterExpo2 topo should be used under homogeneous environment"
86+
bf.set_topology(topology_util.InnerOuterExpo2Graph(
8787
bf.size(), local_size=bf.local_size() if args.local_size == -1 else args.local_size))
8888
else:
8989
raise ValueError("Unknown args.virtual_topology, supporting options are " +
90-
"[power2(Default), ring, mesh, star,InnerOuterRing, InnerOuterExp2].")
90+
"[expo2(Default), ring, mesh, star,InnerOuterRing, InnerOuterExpo2].")
9191

9292
if args.cuda:
9393
torch.cuda.set_device(bf.local_rank())
@@ -176,8 +176,8 @@ def forward(self, x):
176176
bf.size(),
177177
local_size=bf.local_size() if args.local_size == -1 else args.local_size,
178178
self_rank=bf.rank())
179-
elif args.virtual_topology == 'InnerOuterExp2':
180-
dynamic_neighbor_allreduce_gen = topology_util.GetInnerOuterExp2DynamicSendRecvRanks(
179+
elif args.virtual_topology == 'InnerOuterExpo2':
180+
dynamic_neighbor_allreduce_gen = topology_util.GetInnerOuterExpo2DynamicSendRecvRanks(
181181
bf.size(),
182182
local_size=bf.local_size() if args.local_size == -1 else args.local_size,
183183
self_rank=bf.rank())

‎examples/pytorch_cifar10_resnet.py

+10-34
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@
8484
parser.add_argument('--disable-dynamic-topology', action='store_true',
8585
default=False, help=('Disable each iteration to transmit one neighbor ' +
8686
'per iteration dynamically.'))
87-
parser.add_argument('--virtual-topology', type=str, default="power2",
87+
parser.add_argument('--virtual-topology', type=str, default="expo2",
8888
help='The underlying virtual topology. Supporting options are ' +
89-
'[power2(Default), ring, mesh, star].')
89+
'[expo2(Default), ring, mesh, star].')
9090

9191
args = parser.parse_args()
9292
args.cuda = (not args.no_cuda) and (torch.cuda.is_available())
@@ -100,21 +100,21 @@
100100
bf.init()
101101
torch.manual_seed(args.seed)
102102
if args.dist_optimizer != 'horovod':
103-
if args.virtual_topology == "power2":
103+
if args.virtual_topology == "expo2":
104104
pass
105105
elif args.virtual_topology == "ring":
106106
bf.set_topology(topology_util.RingGraph(bf.size(), connect_style=1))
107107
elif args.virtual_topology == "InnerOuterRing":
108108
assert bf.is_homogeneous, "InnerOuterRing topo should be used only homogeneous environment"
109109
bf.set_topology(topology_util.InnerOuterRingGraph(
110110
bf.size(), local_size=bf.local_size()))
111-
elif args.virtual_topology == "InnerOuterExp2":
112-
assert bf.is_homogeneous, "InnerOuterExp2 topo should be used under homogeneous environment"
113-
bf.set_topology(topology_util.InnerOuterExp2Graph(
111+
elif args.virtual_topology == "InnerOuterExpo2":
112+
assert bf.is_homogeneous, "InnerOuterExpo2 topo should be used under homogeneous environment"
113+
bf.set_topology(topology_util.InnerOuterExpo2Graph(
114114
bf.size(), local_size=bf.local_size()))
115115
else:
116116
raise ValueError("Unknown args.virtual_topology, supporting options are " +
117-
"[power2(Default), ring, mesh, star,InnerOuterRing, InnerOuterExp2].")
117+
"[expo2(Default), ring, mesh, star,InnerOuterRing, InnerOuterExpo2].")
118118

119119
if args.cuda:
120120
print("using cuda.")
@@ -126,21 +126,6 @@
126126

127127
cudnn.benchmark = True
128128

129-
# If set > 0, will resume training from a given checkpoint.
130-
resume_from_epoch = 0
131-
# for try_epoch in range(args.epochs, 0, -1):
132-
# if os.path.exists(args.checkpoint_format.format(epoch=try_epoch)):
133-
# resume_from_epoch = try_epoch
134-
# break
135-
136-
# Bluefog: broadcast resume_from_epoch from rank 0 (which will have
137-
# checkpoints) to other ranks.
138-
resume_from_epoch = bf.broadcast(
139-
torch.tensor(resume_from_epoch), # pylint: disable=not-callable
140-
root_rank=0,
141-
name="resume_from_epoch",
142-
).item()
143-
144129
# Bluefog: print logs on the first worker.
145130
verbose = 1 if bf.rank() == 0 else 0
146131

@@ -234,15 +219,6 @@
234219
'[neighbor_allreduce, gradient_allreduce, allreduce, ' +
235220
'win_put, horovod]')
236221

237-
print("resume_from_epoch: ", resume_from_epoch)
238-
# Restore from a previous checkpoint, if initial_epoch is specified.
239-
# Bluefog: restore on the first worker which will broadcast weights to other workers.
240-
# if resume_from_epoch > 0 and bf.rank() == 0:
241-
# filepath = args.checkpoint_format.format(epoch=resume_from_epoch)
242-
# checkpoint = torch.load(filepath)
243-
# model.load_state_dict(checkpoint["model"])
244-
# optimizer.load_state_dict(checkpoint["optimizer"])
245-
246222
# Bluefog: broadcast parameters & optimizer state.
247223
bf.broadcast_parameters(model.state_dict(), root_rank=0)
248224
bf.broadcast_optimizer_state(optimizer, root_rank=0)
@@ -345,8 +321,8 @@ def adjust_learning_rate(epoch, batch_idx):
345321
bf.size(),
346322
local_size=bf.local_size(),
347323
self_rank=bf.rank())
348-
elif args.virtual_topology == 'InnerOuterExp2':
349-
dynamic_neighbor_allreduce_gen = topology_util.GetInnerOuterExp2DynamicSendRecvRanks(
324+
elif args.virtual_topology == 'InnerOuterExpo2':
325+
dynamic_neighbor_allreduce_gen = topology_util.GetInnerOuterExpo2DynamicSendRecvRanks(
350326
bf.size(),
351327
local_size=bf.local_size(),
352328
self_rank=bf.rank())
@@ -421,7 +397,7 @@ def avg(self):
421397
return self.sum / self.n
422398

423399

424-
for epoch in range(resume_from_epoch, args.epochs):
400+
for epoch in range(args.epochs):
425401
train(epoch)
426402
validate(epoch)
427403
# save_checkpoint(epoch)

0 commit comments

Comments
 (0)
Please sign in to comment.