Skip to content

Commit 14ce22a

Browse files
committed
examples/async and model_pruing passed ty checks; server aggregation migrated to the new strategy API in examples/async.
1 parent 29fa2b3 commit 14ce22a

File tree

13 files changed

+361
-748
lines changed

13 files changed

+361
-748
lines changed

examples/async/fedasync/fedasync_server.py

Lines changed: 15 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from plato.config import Config
1515
from plato.servers import fedavg
16+
from plato.servers.strategies import FedAsyncAggregationStrategy
1617

1718

1819
class Server(fedavg.Server):
@@ -26,99 +27,45 @@ def __init__(
2627
trainer=None,
2728
callbacks=None,
2829
):
30+
aggregation_strategy = FedAsyncAggregationStrategy()
31+
2932
super().__init__(
3033
model=model,
3134
datasource=datasource,
3235
algorithm=algorithm,
3336
trainer=trainer,
3437
callbacks=callbacks,
38+
aggregation_strategy=aggregation_strategy,
3539
)
3640

37-
# The hyperparameter of FedAsync with a range of (0, 1)
38-
self.mixing_hyperparam = 1
39-
40-
# Whether adjust mixing hyperparameter after each round
41-
self.adaptive_mixing = False
42-
4341
def configure(self) -> None:
4442
"""Configure the mixing hyperparameter for the server, as well as
4543
other parameters from the configuration file.
4644
"""
4745
super().configure()
4846

49-
# Configuring the mixing hyperparameter for FedAsync
50-
self.adaptive_mixing = (
51-
hasattr(Config().server, "adaptive_mixing")
52-
and Config().server.adaptive_mixing
53-
)
54-
5547
if not hasattr(Config().server, "mixing_hyperparameter"):
5648
logging.warning(
5749
"FedAsync: Variable mixing hyperparameter is required for the FedAsync server."
5850
)
5951
else:
60-
self.mixing_hyperparam = Config().server.mixing_hyperparameter
52+
try:
53+
mixing_hyperparam = float(Config().server.mixing_hyperparameter)
54+
except (TypeError, ValueError):
55+
logging.warning(
56+
"FedAsync: Invalid mixing hyperparameter. "
57+
"Unable to cast %s to float.",
58+
Config().server.mixing_hyperparameter,
59+
)
60+
return
6161

62-
if 0 < self.mixing_hyperparam < 1:
62+
if 0 < mixing_hyperparam < 1:
6363
logging.info(
6464
"FedAsync: Mixing hyperparameter is set to %s.",
65-
self.mixing_hyperparam,
65+
mixing_hyperparam,
6666
)
6767
else:
6868
logging.warning(
6969
"FedAsync: Invalid mixing hyperparameter. "
7070
"The hyperparameter needs to be between 0 and 1 (exclusive)."
7171
)
72-
73-
async def aggregate_weights(self, updates, baseline_weights, weights_received):
74-
"""Process the client reports by aggregating their weights."""
75-
# Calculate the new mixing hyperparameter with client's staleness
76-
client_staleness = updates[0].staleness
77-
78-
if self.adaptive_mixing:
79-
self.mixing_hyperparam *= self._staleness_function(client_staleness)
80-
81-
return await self.algorithm.aggregate_weights(
82-
baseline_weights, weights_received, mixing=self.mixing_hyperparam
83-
)
84-
85-
@staticmethod
86-
def _staleness_function(staleness) -> float:
87-
"""Staleness function used to adjust the mixing hyperparameter"""
88-
if hasattr(Config().server, "staleness_weighting_function"):
89-
staleness_func_param = Config().server.staleness_weighting_function
90-
func_type = staleness_func_param.type.lower()
91-
if func_type == "constant":
92-
return Server._constant_function()
93-
elif func_type == "polynomial":
94-
a = staleness_func_param.a
95-
return Server._polynomial_function(staleness, a)
96-
elif func_type == "hinge":
97-
a = staleness_func_param.a
98-
b = staleness_func_param.b
99-
return Server._hinge_function(staleness, a, b)
100-
else:
101-
logging.warning(
102-
"FedAsync: Unknown staleness weighting function type. "
103-
"Type needs to be constant, polynomial, or hinge."
104-
)
105-
else:
106-
return Server.constant_function()
107-
108-
@staticmethod
109-
def _constant_function() -> float:
110-
"""Constant staleness function as proposed in Sec. 5.2, Evaluation Setup."""
111-
return 1
112-
113-
@staticmethod
114-
def _polynomial_function(staleness, a) -> float:
115-
"""Polynomial staleness function as proposed in Sec. 5.2, Evaluation Setup."""
116-
return (staleness + 1) ** -a
117-
118-
@staticmethod
119-
def _hinge_function(staleness, a, b) -> float:
120-
"""Hinge staleness function as proposed in Sec. 5.2, Evaluation Setup."""
121-
if staleness <= b:
122-
return 1
123-
else:
124-
return 1 / (a * (staleness - b) + 1)

examples/async/port/port_server.py

Lines changed: 20 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -7,139 +7,35 @@
77
88
"""
99

10-
import asyncio
11-
import copy
12-
import logging
13-
import os
14-
15-
import torch
16-
import torch.nn.functional as F
17-
18-
from plato.config import Config
1910
from plato.servers import fedavg
11+
from plato.servers.strategies import PortAggregationStrategy
2012

2113

2214
class Server(fedavg.Server):
23-
"""A federated learning server using the FedAsync algorithm."""
24-
25-
async def cosine_similarity(self, update, staleness):
26-
"""Compute the cosine similarity of the received updates and the difference
27-
between the current and a previous model according to client staleness."""
28-
# Loading the global model from a previous round according to staleness
29-
filename = f"model_{self.current_round - 2}.pth"
30-
model_path = Config().params["model_path"]
31-
model_path = f"{model_path}/{filename}"
32-
33-
similarity = 1.0
34-
35-
if staleness > 1 and os.path.exists(model_path):
36-
previous_model = copy.deepcopy(self.trainer.model)
37-
previous_model.load_state_dict(torch.load(model_path))
38-
39-
previous = torch.zeros(0)
40-
for __, weight in previous_model.cpu().state_dict().items():
41-
previous = torch.cat((previous, weight.view(-1)))
42-
43-
current = torch.zeros(0)
44-
for __, weight in self.trainer.model.cpu().state_dict().items():
45-
current = torch.cat((current, weight.view(-1)))
46-
47-
deltas = torch.zeros(0)
48-
for __, delta in update.items():
49-
deltas = torch.cat((deltas, delta.view(-1)))
50-
51-
similarity = F.cosine_similarity(current - previous, deltas, dim=0)
52-
53-
return similarity
54-
55-
async def aggregate_deltas(self, updates, deltas_received):
56-
"""Aggregate weight updates from the clients using federated averaging."""
57-
# Extract the total number of samples
58-
self.total_samples = sum(update.report.num_samples for update in updates)
59-
60-
# Constructing the aggregation weights to be used
61-
aggregation_weights = []
62-
63-
for i, update in enumerate(deltas_received):
64-
report = updates[i].report
65-
staleness = updates[i].staleness
66-
num_samples = report.num_samples
67-
68-
similarity = await self.cosine_similarity(update, staleness)
69-
staleness_factor = Server.staleness_function(staleness)
70-
71-
similarity_weight = (
72-
Config().server.similarity_weight
73-
if hasattr(Config().server, "similarity_weight")
74-
else 1
75-
)
76-
staleness_weight = (
77-
Config().server.staleness_weight
78-
if hasattr(Config().server, "staleness_weight")
79-
else 1
80-
)
81-
82-
logging.info("[Client %s] similarity: %s", i, (similarity + 1) / 2)
83-
logging.info(
84-
"[Client %s] staleness: %s, staleness factor: %s",
85-
i,
86-
staleness,
87-
staleness_factor,
88-
)
89-
raw_weight = (
90-
num_samples
91-
/ self.total_samples
92-
* (
93-
(similarity + 1) / 2 * similarity_weight
94-
+ staleness_factor * staleness_weight
95-
)
96-
)
97-
logging.info("[Client %s] raw weight = %s", i, raw_weight)
98-
99-
aggregation_weights.append(raw_weight)
100-
101-
# Normalize so that the sum of aggregation weights equals 1
102-
aggregation_weights = [
103-
i / sum(aggregation_weights) for i in aggregation_weights
104-
]
105-
106-
logging.info(
107-
"[Server #%s] normalized aggregation weights: %s",
108-
os.getpid(),
109-
aggregation_weights,
15+
"""A federated learning server using the Port aggregation strategy."""
16+
17+
def __init__(
18+
self,
19+
model=None,
20+
datasource=None,
21+
algorithm=None,
22+
trainer=None,
23+
callbacks=None,
24+
):
25+
super().__init__(
26+
model=model,
27+
datasource=datasource,
28+
algorithm=algorithm,
29+
trainer=trainer,
30+
callbacks=callbacks,
31+
aggregation_strategy=PortAggregationStrategy(),
11032
)
11133

112-
# Perform weighted averaging
113-
avg_update = {
114-
name: self.trainer.zeros(weights.shape)
115-
for name, weights in deltas_received[0].items()
116-
}
117-
118-
for i, update in enumerate(deltas_received):
119-
for name, delta in update.items():
120-
avg_update[name] += delta * aggregation_weights[i]
121-
122-
# Yield to other tasks in the server
123-
await asyncio.sleep(0)
124-
125-
return avg_update
126-
12734
def weights_aggregated(self, updates):
12835
"""
12936
Method called at the end of aggregating received weights.
13037
"""
13138
# Save the current model for later retrieval when cosine similarity needs to be computed
13239
filename = f"model_{self.current_round}.pth"
133-
self.trainer.save_model(filename)
134-
135-
@staticmethod
136-
def staleness_function(staleness):
137-
"""The staleness_function."""
138-
staleness_bound = 10
139-
140-
if hasattr(Config().server, "staleness_bound"):
141-
staleness_bound = Config().server.staleness_bound
142-
143-
staleness_factor = staleness_bound / (staleness + staleness_bound)
144-
145-
return staleness_factor
40+
trainer = self.require_trainer()
41+
trainer.save_model(filename)

0 commit comments

Comments
 (0)