Skip to content

Commit aa1fccf

Browse files
committed
Added pFedGraph to Plato.
1 parent 223da9a commit aa1fccf

File tree

16 files changed

+700
-0
lines changed

16 files changed

+700
-0
lines changed

docs/docs/configurations/algorithm.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
- `fedavg` the federated averaging algorithm
77
- `split_learning` the Split Learning algorithm
88
- `fedavg_personalized` the personalized federated learning algorithm
9+
- `pfedgraph` the Personalized Federated Learning with Inferred Collaboration Graphs algorithm
910

1011
!!! example "cross_silo"
1112
Whether or not cross-silo training should be used.
@@ -26,3 +27,27 @@
2627
A float to show the proportion of clients participating in the federated training process. It is under `personalization`, which is a sub-config path that contains other personalized training parameters.
2728

2829
Default value: `1.0`
30+
31+
!!! example "pfedgraph"
32+
Configuration for pFedGraph.
33+
34+
!!! example "pfedgraph_alpha"
35+
Hyper-parameter controlling the collaboration graph update.
36+
37+
Default value: `0.8`
38+
39+
!!! example "pfedgraph_lambda"
40+
Regularization strength for cosine similarity in the local objective.
41+
42+
Default value: `0.01`
43+
44+
!!! example "pfedgraph_similarity_metric"
45+
Similarity metric scope for graph inference. Use `all` for all parameters
46+
or `fc` to focus on the final fully-connected layers.
47+
48+
Default value: `all`
49+
50+
!!! example "pfedgraph_similarity_layers"
51+
Optional list of layer name substrings to use when computing model
52+
similarity for graph inference. Overrides `pfedgraph_similarity_metric`
53+
when provided.

docs/docs/configurations/server.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
- `fedavg_mpc_shamir` a Federated Averaging server that reconstructs Shamir MPC shares before aggregation. Requires clients of type `mpc` with the `mpc_model_encrypt_shamir` processor.
1111
- `split_learning` a Split Learning server that supports training different kinds of models in split learning framework. When this server is used, the `clients.per_round` in the configuration should be set to 1. Users should define the rules for updating models weights before cut from the clients to the server in the callback function `on_update_weights_before_cut`, depending on the specific model they use.
1212
- `fedavg_personalized` a personalized federated learning server that starts from a number of regular rounds of federated learning. In these regular rounds, only a subset of the total clients can be selected to perform the local update (the ratio of which is a configuration setting). After all regular rounds are completed, it starts a final round of personalization, where a selected subset of clients perform local training using their local dataset.
13+
- `pfedgraph` a personalized federated learning server that aggregates models using an inferred collaboration graph and sends per-client aggregated weights.
1314

1415
!!! example "address"
1516
The address of the central server, such as `127.0.0.1`.

plato/algorithms/pfedgraph.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""
2+
pFedGraph algorithm wrapper.
3+
4+
pFedGraph uses standard FedAvg weight handling with a personalized aggregation
5+
strategy on the server side.
6+
"""
7+
8+
from plato.algorithms import fedavg
9+
10+
11+
class Algorithm(fedavg.Algorithm):
12+
"""pFedGraph reuses the FedAvg algorithm primitives."""

plato/algorithms/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
fedavg_personalized,
1717
lora,
1818
mlx_fedavg,
19+
pfedgraph,
1920
split_learning,
2021
)
2122
from plato.algorithms.base import Algorithm as AlgorithmBase
@@ -27,6 +28,7 @@
2728
"fedavg_personalized": fedavg_personalized.Algorithm,
2829
"fedavg_lora": lora.Algorithm,
2930
"mlx_fedavg": mlx_fedavg.Algorithm,
31+
"pfedgraph": pfedgraph.Algorithm,
3032
"split_learning": split_learning.Algorithm,
3133
}
3234

plato/clients/pfedgraph.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""
2+
pFedGraph client implementation.
3+
4+
Uses the standard simple client pipeline.
5+
"""
6+
7+
from plato.clients import simple
8+
9+
10+
class Client(simple.Client):
11+
"""A pFedGraph client using default composable strategies."""

plato/clients/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
edge,
2020
fedavg_personalized,
2121
mpc,
22+
pfedgraph,
2223
self_supervised_learning,
2324
simple,
2425
split_learning,
@@ -107,6 +108,7 @@ def factory(**kwargs) -> Client:
107108
"simple": _simple_like_factory(simple.Client),
108109
"fedavg_personalized": _simple_like_factory(fedavg_personalized.Client),
109110
"mpc": _simple_like_factory(mpc.Client),
111+
"pfedgraph": _simple_like_factory(pfedgraph.Client),
110112
"self_supervised_learning": _simple_like_factory(self_supervised_learning.Client),
111113
"split_learning": _simple_like_factory(split_learning.Client),
112114
"edge": _edge_factory(),

plato/servers/pfedgraph.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""
2+
pFedGraph server implementation.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
from typing import Any, Sequence
8+
9+
from plato.config import Config
10+
from plato.servers import fedavg
11+
from plato.servers.strategies.aggregation.pfedgraph import (
12+
PFedGraphAggregationStrategy,
13+
)
14+
15+
16+
class Server(fedavg.Server):
17+
"""Federated learning server implementing pFedGraph."""
18+
19+
def __init__(
20+
self,
21+
model=None,
22+
datasource=None,
23+
algorithm=None,
24+
trainer=None,
25+
callbacks=None,
26+
aggregation_strategy=None,
27+
client_selection_strategy=None,
28+
):
29+
if aggregation_strategy is None:
30+
similarity_layers = None
31+
similarity_metric = "all"
32+
alpha = 0.8
33+
34+
if hasattr(Config(), "algorithm"):
35+
if hasattr(Config().algorithm, "pfedgraph_similarity_metric"):
36+
similarity_metric = Config().algorithm.pfedgraph_similarity_metric
37+
elif hasattr(Config().algorithm, "pfedgraph_similarity"):
38+
similarity_metric = Config().algorithm.pfedgraph_similarity
39+
40+
if hasattr(Config().algorithm, "pfedgraph_similarity_layers"):
41+
similarity_layers = Config().algorithm.pfedgraph_similarity_layers
42+
43+
if hasattr(Config().algorithm, "pfedgraph_alpha"):
44+
alpha = Config().algorithm.pfedgraph_alpha
45+
46+
aggregation_strategy = PFedGraphAggregationStrategy(
47+
alpha=alpha,
48+
similarity_metric=similarity_metric,
49+
similarity_layers=similarity_layers,
50+
)
51+
52+
super().__init__(
53+
model=model,
54+
datasource=datasource,
55+
algorithm=algorithm,
56+
trainer=trainer,
57+
callbacks=callbacks,
58+
aggregation_strategy=aggregation_strategy,
59+
client_selection_strategy=client_selection_strategy,
60+
)
61+
62+
self.client_models: dict[int, dict[str, Any]] = {}
63+
64+
def update_client_model(
65+
self,
66+
aggregated_clients_models: Sequence[dict[str, Any]],
67+
updates: Sequence[Any],
68+
) -> None:
69+
"""Update the stored model for each client."""
70+
for client_model, update in zip(aggregated_clients_models, updates):
71+
client_id = getattr(update, "client_id", None)
72+
if client_id is None:
73+
continue
74+
self.client_models[client_id] = client_model
75+
76+
def customize_server_payload(self, payload: Any) -> Any:
77+
"""Send per-client aggregated weights when available."""
78+
client_id = self.selected_client_id
79+
if client_id in self.client_models:
80+
return self.client_models[client_id]
81+
return payload

plato/servers/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
fedavg_mpc_additive,
1717
fedavg_mpc_shamir,
1818
fedavg_personalized,
19+
pfedgraph,
1920
split_learning,
2021
)
2122

@@ -34,6 +35,7 @@
3435
"fedavg_personalized": fedavg_personalized.Server,
3536
"fedavg_mpc_additive": fedavg_mpc_additive.Server,
3637
"fedavg_mpc_shamir": fedavg_mpc_shamir.Server,
38+
"pfedgraph": pfedgraph.Server,
3739
"split_learning": split_learning.Server,
3840
}
3941

plato/servers/strategies/aggregation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from plato.servers.strategies.aggregation.gan import FedAvgGanAggregationStrategy
1212
from plato.servers.strategies.aggregation.he import FedAvgHEAggregationStrategy
1313
from plato.servers.strategies.aggregation.hermes import HermesAggregationStrategy
14+
from plato.servers.strategies.aggregation.pfedgraph import PFedGraphAggregationStrategy
1415
from plato.servers.strategies.aggregation.port import PortAggregationStrategy
1516

1617
__all__ = [
@@ -20,6 +21,7 @@
2021
"FedAsyncAggregationStrategy",
2122
"PortAggregationStrategy",
2223
"HermesAggregationStrategy",
24+
"PFedGraphAggregationStrategy",
2325
"FedAvgGanAggregationStrategy",
2426
"FedAvgHEAggregationStrategy",
2527
]

0 commit comments

Comments
 (0)