|
1 | 1 | """ |
2 | | -A federated learning server using Active Federated Learning, where in each round |
3 | | -clients are selected not uniformly at random, but with a probability conditioned |
4 | | -on the current model, as well as the data on the client, to maximize efficiency. |
| 2 | +A federated learning server using Active Federated Learning with the |
| 3 | +strategy-based server API. |
| 4 | +
|
| 5 | +Clients are sampled according to valuation metrics computed on the client. |
5 | 6 |
|
6 | 7 | Reference: |
7 | 8 |
|
|
10 | 11 | https://arxiv.org/pdf/1909.12641.pdf |
11 | 12 | """ |
12 | 13 |
|
13 | | -import logging |
14 | | -import math |
15 | | -import random |
16 | | - |
17 | | -import numpy as np |
| 14 | +from __future__ import annotations |
18 | 15 |
|
19 | 16 | from plato.config import Config |
20 | 17 | from plato.servers import fedavg |
21 | 18 |
|
| 19 | +from afl_selection_strategy import AFLSelectionStrategy |
| 20 | + |
22 | 21 |
|
23 | 22 | class Server(fedavg.Server): |
24 | | - """A federated learning server using the AFL algorithm.""" |
| 23 | + """An AFL server configured with the strategy-based client selection API.""" |
25 | 24 |
|
26 | 25 | def __init__( |
27 | 26 | self, model=None, datasource=None, algorithm=None, trainer=None, callbacks=None |
28 | 27 | ): |
| 28 | + algo_cfg = getattr(Config(), "algorithm", None) |
| 29 | + |
| 30 | + selection_strategy = AFLSelectionStrategy( |
| 31 | + alpha1=getattr(algo_cfg, "alpha1", 0.75) if algo_cfg else 0.75, |
| 32 | + alpha2=getattr(algo_cfg, "alpha2", 0.01) if algo_cfg else 0.01, |
| 33 | + alpha3=getattr(algo_cfg, "alpha3", 0.1) if algo_cfg else 0.1, |
| 34 | + ) |
| 35 | + |
29 | 36 | super().__init__( |
30 | 37 | model=model, |
31 | 38 | datasource=datasource, |
32 | 39 | algorithm=algorithm, |
33 | 40 | trainer=trainer, |
34 | 41 | callbacks=callbacks, |
| 42 | + client_selection_strategy=selection_strategy, |
35 | 43 | ) |
36 | | - |
37 | | - self.local_values = {} |
38 | | - |
39 | | - def weights_aggregated(self, updates): |
40 | | - """Extract required information from client reports after aggregating weights.""" |
41 | | - for update in updates: |
42 | | - self.local_values[update.client_id]["valuation"] = update.report.valuation |
43 | | - |
44 | | - def calc_sample_distribution(self, clients_pool): |
45 | | - """Calculate the sampling probability of each client for the next round.""" |
46 | | - # First, initialize valuations and probabilities when new clients are connected |
47 | | - for client_id in clients_pool: |
48 | | - if client_id not in self.local_values: |
49 | | - self.local_values[client_id] = {} |
50 | | - self.local_values[client_id]["valuation"] = -float("inf") |
51 | | - self.local_values[client_id]["prob"] = 0.0 |
52 | | - |
53 | | - # For a proportion of clients with the smallest valuations, reset these valuations |
54 | | - # to negative infinities |
55 | | - num_smallest = int(Config().algorithm.alpha1 * len(clients_pool)) |
56 | | - smallest_valuations = dict( |
57 | | - sorted(self.local_values.items(), key=lambda item: item[1]["valuation"])[ |
58 | | - :num_smallest |
59 | | - ] |
60 | | - ) |
61 | | - |
62 | | - for client_id in smallest_valuations.keys(): |
63 | | - self.local_values[client_id]["valuation"] = -float("inf") |
64 | | - |
65 | | - for client_id in clients_pool: |
66 | | - self.local_values[client_id]["prob"] = math.exp( |
67 | | - Config().algorithm.alpha2 * self.local_values[client_id]["valuation"] |
68 | | - ) |
69 | | - |
70 | | - def choose_clients(self, clients_pool, clients_count): |
71 | | - """Choose a subset of the clients to participate in each round.""" |
72 | | - assert clients_count <= len(clients_pool) |
73 | | - random.setstate(self.prng_state) |
74 | | - # Update the clients sampling distribution |
75 | | - self.calc_sample_distribution(clients_pool) |
76 | | - |
77 | | - # 1. Sample a subset of the clients according to the sampling distribution |
78 | | - num1 = int(math.floor((1 - Config().algorithm.alpha3) * clients_count)) |
79 | | - weighted_candidates = [ |
80 | | - client_id |
81 | | - for client_id in clients_pool |
82 | | - if self.local_values[client_id]["prob"] > 0.0 |
83 | | - ] |
84 | | - num1 = min(num1, len(weighted_candidates)) |
85 | | - |
86 | | - subset1 = [] |
87 | | - if num1 > 0: |
88 | | - probs = np.array( |
89 | | - [ |
90 | | - self.local_values[client_id]["prob"] |
91 | | - for client_id in weighted_candidates |
92 | | - ] |
93 | | - ) |
94 | | - total_prob = probs.sum() |
95 | | - if total_prob <= 0: |
96 | | - probs = np.ones(len(weighted_candidates), dtype=float) / len( |
97 | | - weighted_candidates |
98 | | - ) |
99 | | - else: |
100 | | - probs = probs / total_prob |
101 | | - subset1 = np.random.choice( |
102 | | - weighted_candidates, num1, p=probs, replace=False |
103 | | - ).tolist() |
104 | | - |
105 | | - # 2. Sample a subset of the remaining clients uniformly at random |
106 | | - num2 = clients_count - len(subset1) |
107 | | - remaining = clients_pool.copy() |
108 | | - |
109 | | - for client_id in subset1: |
110 | | - remaining.remove(client_id) |
111 | | - |
112 | | - subset2 = random.sample(remaining, num2) if num2 > 0 else [] |
113 | | - |
114 | | - # 3. Selected clients are the union of these two subsets |
115 | | - selected_clients = subset1 + subset2 |
116 | | - |
117 | | - self.prng_state = random.getstate() |
118 | | - logging.info("[%s] Selected clients: %s", self, selected_clients) |
119 | | - return selected_clients |
0 commit comments