Skip to content

Commit d872c0a

Browse files
committed
Ty checks passed on examples/client_selection; migrated AFL to new server API.
1 parent 2178c82 commit d872c0a

File tree

11 files changed

+117
-176
lines changed

11 files changed

+117
-176
lines changed

docs/mkdocs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ theme:
88
repo: fontawesome/brands/github
99
font:
1010
text: Inter
11-
code: IBM Plex Mono
11+
code: Google Sans Code
1212
palette:
1313
# Light mode
1414
- scheme: astral-light

examples/client_selection/afl/afl_callbacks.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from __future__ import annotations
66

77
import logging
8-
from collections.abc import Iterable
8+
from collections.abc import Iterable, Sized
9+
from typing import Any
910

1011
import torch
1112

@@ -86,12 +87,11 @@ def on_train_epoch_start(self, trainer, config, **kwargs):
8687
self._recorded = True
8788

8889
@staticmethod
89-
def _has_batches(loader: Iterable) -> bool:
90+
def _has_batches(loader: Iterable[Any] | Sized) -> bool:
9091
"""Best-effort check that the data loader yields at least one batch."""
91-
length = None
92-
if hasattr(loader, "__len__"):
92+
if isinstance(loader, Sized):
9393
try:
94-
length = len(loader)
94+
return len(loader) > 0
9595
except TypeError:
96-
length = None
97-
return bool(length) if length is not None else True
96+
return True
97+
return True
Lines changed: 17 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""
2-
A federated learning server using Active Federated Learning, where in each round
3-
clients are selected not uniformly at random, but with a probability conditioned
4-
on the current model, as well as the data on the client, to maximize efficiency.
2+
A federated learning server using Active Federated Learning with the
3+
strategy-based server API.
4+
5+
Clients are sampled according to valuation metrics computed on the client.
56
67
Reference:
78
@@ -10,110 +11,33 @@
1011
https://arxiv.org/pdf/1909.12641.pdf
1112
"""
1213

13-
import logging
14-
import math
15-
import random
16-
17-
import numpy as np
14+
from __future__ import annotations
1815

1916
from plato.config import Config
2017
from plato.servers import fedavg
2118

19+
from afl_selection_strategy import AFLSelectionStrategy
20+
2221

2322
class Server(fedavg.Server):
24-
"""A federated learning server using the AFL algorithm."""
23+
"""An AFL server configured with the strategy-based client selection API."""
2524

2625
def __init__(
2726
self, model=None, datasource=None, algorithm=None, trainer=None, callbacks=None
2827
):
28+
algo_cfg = getattr(Config(), "algorithm", None)
29+
30+
selection_strategy = AFLSelectionStrategy(
31+
alpha1=getattr(algo_cfg, "alpha1", 0.75) if algo_cfg else 0.75,
32+
alpha2=getattr(algo_cfg, "alpha2", 0.01) if algo_cfg else 0.01,
33+
alpha3=getattr(algo_cfg, "alpha3", 0.1) if algo_cfg else 0.1,
34+
)
35+
2936
super().__init__(
3037
model=model,
3138
datasource=datasource,
3239
algorithm=algorithm,
3340
trainer=trainer,
3441
callbacks=callbacks,
42+
client_selection_strategy=selection_strategy,
3543
)
36-
37-
self.local_values = {}
38-
39-
def weights_aggregated(self, updates):
40-
"""Extract required information from client reports after aggregating weights."""
41-
for update in updates:
42-
self.local_values[update.client_id]["valuation"] = update.report.valuation
43-
44-
def calc_sample_distribution(self, clients_pool):
45-
"""Calculate the sampling probability of each client for the next round."""
46-
# First, initialize valuations and probabilities when new clients are connected
47-
for client_id in clients_pool:
48-
if client_id not in self.local_values:
49-
self.local_values[client_id] = {}
50-
self.local_values[client_id]["valuation"] = -float("inf")
51-
self.local_values[client_id]["prob"] = 0.0
52-
53-
# For a proportion of clients with the smallest valuations, reset these valuations
54-
# to negative infinities
55-
num_smallest = int(Config().algorithm.alpha1 * len(clients_pool))
56-
smallest_valuations = dict(
57-
sorted(self.local_values.items(), key=lambda item: item[1]["valuation"])[
58-
:num_smallest
59-
]
60-
)
61-
62-
for client_id in smallest_valuations.keys():
63-
self.local_values[client_id]["valuation"] = -float("inf")
64-
65-
for client_id in clients_pool:
66-
self.local_values[client_id]["prob"] = math.exp(
67-
Config().algorithm.alpha2 * self.local_values[client_id]["valuation"]
68-
)
69-
70-
def choose_clients(self, clients_pool, clients_count):
71-
"""Choose a subset of the clients to participate in each round."""
72-
assert clients_count <= len(clients_pool)
73-
random.setstate(self.prng_state)
74-
# Update the clients sampling distribution
75-
self.calc_sample_distribution(clients_pool)
76-
77-
# 1. Sample a subset of the clients according to the sampling distribution
78-
num1 = int(math.floor((1 - Config().algorithm.alpha3) * clients_count))
79-
weighted_candidates = [
80-
client_id
81-
for client_id in clients_pool
82-
if self.local_values[client_id]["prob"] > 0.0
83-
]
84-
num1 = min(num1, len(weighted_candidates))
85-
86-
subset1 = []
87-
if num1 > 0:
88-
probs = np.array(
89-
[
90-
self.local_values[client_id]["prob"]
91-
for client_id in weighted_candidates
92-
]
93-
)
94-
total_prob = probs.sum()
95-
if total_prob <= 0:
96-
probs = np.ones(len(weighted_candidates), dtype=float) / len(
97-
weighted_candidates
98-
)
99-
else:
100-
probs = probs / total_prob
101-
subset1 = np.random.choice(
102-
weighted_candidates, num1, p=probs, replace=False
103-
).tolist()
104-
105-
# 2. Sample a subset of the remaining clients uniformly at random
106-
num2 = clients_count - len(subset1)
107-
remaining = clients_pool.copy()
108-
109-
for client_id in subset1:
110-
remaining.remove(client_id)
111-
112-
subset2 = random.sample(remaining, num2) if num2 > 0 else []
113-
114-
# 3. Selected clients are the union of these two subsets
115-
selected_clients = subset1 + subset2
116-
117-
self.prng_state = random.getstate()
118-
logging.info("[%s] Selected clients: %s", self, selected_clients)
119-
return selected_clients
Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,10 @@
11
"""
2-
A federated learning server using Active Federated Learning with strategy pattern.
2+
Legacy entry point for the AFL server using the strategy-based API.
33
4-
This is the updated version using the strategy-based API instead of inheritance.
5-
6-
Reference:
7-
8-
Goetz et al., "Active Federated Learning", 2019.
9-
10-
https://arxiv.org/pdf/1909.12641.pdf
4+
This module simply re-exports the server defined in ``afl_server`` to avoid
5+
breaking older entry points.
116
"""
127

13-
from plato.config import Config
14-
from plato.servers import fedavg
15-
from plato.servers.strategies import AFLSelectionStrategy
16-
17-
18-
class Server(fedavg.Server):
19-
"""A federated learning server using the AFL client selection strategy."""
20-
21-
def __init__(
22-
self, model=None, datasource=None, algorithm=None, trainer=None, callbacks=None
23-
):
24-
# Load AFL parameters from config
25-
alpha1 = Config().algorithm.alpha1
26-
alpha2 = Config().algorithm.alpha2
27-
alpha3 = Config().algorithm.alpha3
8+
from afl_server import Server
289

29-
super().__init__(
30-
model=model,
31-
datasource=datasource,
32-
algorithm=algorithm,
33-
trainer=trainer,
34-
callbacks=callbacks,
35-
client_selection_strategy=AFLSelectionStrategy(
36-
alpha1=alpha1,
37-
alpha2=alpha2,
38-
alpha3=alpha3,
39-
),
40-
)
10+
__all__ = ["Server"]

examples/client_selection/afl/afl_strategy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
"""
1212

1313
import afl_client
14-
import afl_server_strategy
14+
import afl_server
1515

1616

1717
def main():
1818
"""A Plato federated learning training session using AFL strategy."""
1919
client = afl_client.create_client()
20-
server = afl_server_strategy.Server()
20+
server = afl_server.Server()
2121
server.run(client)
2222

2323

examples/client_selection/oort/oort_trainer.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,29 @@
88
(OSDI 2021), July 2021.
99
"""
1010

11+
from __future__ import annotations
12+
1113
import numpy as np
1214
import torch
1315
from torch import nn
1416

1517
from plato.trainers.composable import ComposableTrainer
1618
from plato.trainers.strategies.base import LossCriterionStrategy
19+
from plato.trainers.tracking import RunHistory
1720

1821

1922
class OortLossStrategy(LossCriterionStrategy):
2023
"""Loss strategy for Oort that tracks sum of squared per-sample losses."""
2124

2225
def __init__(self):
23-
self._criterion = None
24-
self._run_history = None
26+
self._criterion: nn.CrossEntropyLoss | None = None
27+
self._run_history: RunHistory | None = None
2528

2629
def setup(self, context):
2730
"""Initialize the loss criterion."""
2831
self._criterion = nn.CrossEntropyLoss(reduction="none")
2932

30-
def attach_run_history(self, run_history):
33+
def attach_run_history(self, run_history: RunHistory) -> None:
3134
"""Attach run history for metric tracking."""
3235
self._run_history = run_history
3336

@@ -38,6 +41,11 @@ def compute_loss(self, outputs, labels, context):
3841
This computes per-sample losses, tracks the sum of squares
3942
(used by Oort for client selection), and returns the mean loss.
4043
"""
44+
if self._criterion is None:
45+
raise RuntimeError(
46+
"OortLossStrategy has not been initialised. Did you call setup()?"
47+
)
48+
4149
per_sample_loss = self._criterion(outputs, labels)
4250

4351
if self._run_history is not None:
@@ -71,5 +79,6 @@ def __init__(self, model=None, callbacks=None):
7179
loss_strategy=loss_strategy,
7280
)
7381

74-
if hasattr(self.loss_strategy, "attach_run_history"):
75-
self.loss_strategy.attach_run_history(self.run_history)
82+
attach_run_history = getattr(self.loss_strategy, "attach_run_history", None)
83+
if callable(attach_run_history):
84+
attach_run_history(self.run_history)

examples/client_selection/pisces/pisces_aggregation_strategy.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
import asyncio
1010
import logging
1111
from types import SimpleNamespace
12-
from typing import Dict, List
13-
1412
import numpy as np
1513

1614
from plato.config import Config
@@ -52,10 +50,22 @@ async def aggregate_deltas(
5250

5351
total_samples = sum(update.report.num_samples for update in updates)
5452

55-
avg_update = {
56-
name: context.trainer.zeros(delta.shape)
57-
for name, delta in deltas_received[0].items()
58-
}
53+
trainer = getattr(context, "trainer", None)
54+
zeros_fn = getattr(trainer, "zeros", None) if trainer is not None else None
55+
56+
avg_update = {}
57+
for name, delta in deltas_received[0].items():
58+
if callable(zeros_fn):
59+
avg_update[name] = zeros_fn(delta.shape)
60+
elif hasattr(delta, "clone") and callable(getattr(delta, "clone")):
61+
cloned = delta.clone()
62+
if hasattr(cloned, "zero_") and callable(getattr(cloned, "zero_")):
63+
cloned.zero_()
64+
avg_update[name] = cloned
65+
else:
66+
avg_update[name] = delta * 0
67+
else:
68+
avg_update[name] = np.zeros_like(delta)
5969

6070
for i, delta in enumerate(deltas_received):
6171
client_id = updates[i].client_id

examples/client_selection/pisces/pisces_trainer.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,31 @@
99
URL: https://arxiv.org/abs/2206.09264
1010
"""
1111

12+
from __future__ import annotations
13+
14+
from collections.abc import Callable
15+
from typing import Any
16+
17+
import torch
18+
1219
from plato.trainers import loss_criterion
1320
from plato.trainers.composable import ComposableTrainer
1421
from plato.trainers.strategies.base import LossCriterionStrategy
22+
from plato.trainers.tracking import RunHistory
1523

1624

1725
class PiscesLossStrategy(LossCriterionStrategy):
1826
"""Loss strategy for Pisces that tracks per-batch loss values."""
1927

2028
def __init__(self):
21-
self._criterion = None
22-
self._run_history = None
29+
self._criterion: Callable[[Any, Any], torch.Tensor] | None = None
30+
self._run_history: RunHistory | None = None
2331

2432
def setup(self, context):
2533
"""Initialize the loss criterion."""
2634
self._criterion = loss_criterion.get()
2735

28-
def attach_run_history(self, run_history):
36+
def attach_run_history(self, run_history: RunHistory) -> None:
2937
"""Attach run history for metric tracking."""
3038
self._run_history = run_history
3139

@@ -36,6 +44,11 @@ def compute_loss(self, outputs, labels, context):
3644
This computes the batch loss and stores it in run_history
3745
for Pisces client selection algorithm.
3846
"""
47+
if self._criterion is None:
48+
raise RuntimeError(
49+
"PiscesLossStrategy has not been initialised. Did you call setup()?"
50+
)
51+
3952
per_batch_loss = self._criterion(outputs, labels)
4053

4154
current_epoch = getattr(context, "current_epoch", 1)
@@ -67,5 +80,6 @@ def __init__(self, model=None, callbacks=None):
6780
loss_strategy=loss_strategy,
6881
)
6982

70-
if hasattr(self.loss_strategy, "attach_run_history"):
71-
self.loss_strategy.attach_run_history(self.run_history)
83+
attach_run_history = getattr(self.loss_strategy, "attach_run_history", None)
84+
if callable(attach_run_history):
85+
attach_run_history(self.run_history)

0 commit comments

Comments (0)