|
| 1 | +""" |
| 2 | +pFedGraph server implementation. |
| 3 | +""" |
| 4 | + |
| 5 | +from __future__ import annotations |
| 6 | + |
| 7 | +from typing import Any, Sequence |
| 8 | + |
| 9 | +from plato.config import Config |
| 10 | +from plato.servers import fedavg |
| 11 | +from plato.servers.strategies.aggregation.pfedgraph import ( |
| 12 | + PFedGraphAggregationStrategy, |
| 13 | +) |
| 14 | + |
| 15 | + |
| 16 | +class Server(fedavg.Server): |
| 17 | + """Federated learning server implementing pFedGraph.""" |
| 18 | + |
| 19 | + def __init__( |
| 20 | + self, |
| 21 | + model=None, |
| 22 | + datasource=None, |
| 23 | + algorithm=None, |
| 24 | + trainer=None, |
| 25 | + callbacks=None, |
| 26 | + aggregation_strategy=None, |
| 27 | + client_selection_strategy=None, |
| 28 | + ): |
| 29 | + if aggregation_strategy is None: |
| 30 | + similarity_layers = None |
| 31 | + similarity_metric = "all" |
| 32 | + alpha = 0.8 |
| 33 | + |
| 34 | + if hasattr(Config(), "algorithm"): |
| 35 | + if hasattr(Config().algorithm, "pfedgraph_similarity_metric"): |
| 36 | + similarity_metric = Config().algorithm.pfedgraph_similarity_metric |
| 37 | + elif hasattr(Config().algorithm, "pfedgraph_similarity"): |
| 38 | + similarity_metric = Config().algorithm.pfedgraph_similarity |
| 39 | + |
| 40 | + if hasattr(Config().algorithm, "pfedgraph_similarity_layers"): |
| 41 | + similarity_layers = Config().algorithm.pfedgraph_similarity_layers |
| 42 | + |
| 43 | + if hasattr(Config().algorithm, "pfedgraph_alpha"): |
| 44 | + alpha = Config().algorithm.pfedgraph_alpha |
| 45 | + |
| 46 | + aggregation_strategy = PFedGraphAggregationStrategy( |
| 47 | + alpha=alpha, |
| 48 | + similarity_metric=similarity_metric, |
| 49 | + similarity_layers=similarity_layers, |
| 50 | + ) |
| 51 | + |
| 52 | + super().__init__( |
| 53 | + model=model, |
| 54 | + datasource=datasource, |
| 55 | + algorithm=algorithm, |
| 56 | + trainer=trainer, |
| 57 | + callbacks=callbacks, |
| 58 | + aggregation_strategy=aggregation_strategy, |
| 59 | + client_selection_strategy=client_selection_strategy, |
| 60 | + ) |
| 61 | + |
| 62 | + self.client_models: dict[int, dict[str, Any]] = {} |
| 63 | + |
| 64 | + def update_client_model( |
| 65 | + self, |
| 66 | + aggregated_clients_models: Sequence[dict[str, Any]], |
| 67 | + updates: Sequence[Any], |
| 68 | + ) -> None: |
| 69 | + """Update the stored model for each client.""" |
| 70 | + for client_model, update in zip(aggregated_clients_models, updates): |
| 71 | + client_id = getattr(update, "client_id", None) |
| 72 | + if client_id is None: |
| 73 | + continue |
| 74 | + self.client_models[client_id] = client_model |
| 75 | + |
| 76 | + def customize_server_payload(self, payload: Any) -> Any: |
| 77 | + """Send per-client aggregated weights when available.""" |
| 78 | + client_id = self.selected_client_id |
| 79 | + if client_id in self.client_models: |
| 80 | + return self.client_models[client_id] |
| 81 | + return payload |
0 commit comments