Skip to content

Commit ef91051

Browse files
committed
Fixed type checking issues in examples/three_layer_fl and
examples/unlearning.
1 parent 1880cff commit ef91051

File tree

6 files changed

+100
-37
lines changed

6 files changed

+100
-37
lines changed

examples/three_layer_fl/fedsaw/fedsaw_algorithm.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,11 @@ def prune_weight_updates(
4747
)
4848

4949
# Clone the reference model to host update tensors for pruning.
50-
delta_model = copy.deepcopy(self.model.cpu())
50+
model = self.model
51+
if model is None:
52+
raise RuntimeError("Model must be initialised before pruning updates.")
53+
54+
delta_model = copy.deepcopy(model).cpu()
5155
cpu_updates: MutableMapping[str, torch.Tensor] = OrderedDict(
5256
(name, tensor.detach().cpu()) for name, tensor in updates.items()
5357
)

examples/three_layer_fl/fedsaw/fedsaw_client.py

Lines changed: 63 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22
A federated learning client using pruning.
33
"""
44

5+
from __future__ import annotations
6+
57
import copy
68
import logging
9+
from collections.abc import Mapping
10+
from typing import Any
711

812
from fedsaw_algorithm import Algorithm as FedSawAlgorithm
913

@@ -20,10 +24,12 @@ class FedSawClientLifecycleStrategy(DefaultLifecycleStrategy):
2024
_STATE_KEY = "fedsaw_client"
2125

2226
@staticmethod
23-
def _state(context):
27+
def _state(context: ClientContext) -> dict[str, Any]:
2428
return context.state.setdefault(FedSawClientLifecycleStrategy._STATE_KEY, {})
2529

26-
def process_server_response(self, context, server_response):
30+
def process_server_response(
31+
self, context: ClientContext, server_response: dict[str, Any]
32+
) -> None:
2733
super().process_server_response(context, server_response)
2834
amount = server_response.get("pruning_amount")
2935
if amount is None:
@@ -33,17 +39,17 @@ def process_server_response(self, context, server_response):
3339
state["pruning_amount"] = amount
3440

3541
owner = context.owner
36-
if owner is not None:
37-
owner.pruning_amount = amount
42+
if isinstance(owner, FedSawClient) and isinstance(amount, (int, float)):
43+
owner.pruning_amount = float(amount)
3844

3945

4046
class FedSawTrainingStrategy(DefaultTrainingStrategy):
4147
"""Training strategy that prunes local updates before transmission."""
4248

4349
async def train(self, context: ClientContext):
4450
algorithm = context.algorithm
45-
if algorithm is None:
46-
raise RuntimeError("Algorithm is required for FedSaw training.")
51+
if not isinstance(algorithm, FedSawAlgorithm):
52+
raise RuntimeError("FedSaw training requires a FedSaw algorithm instance.")
4753

4854
previous_weights = copy.deepcopy(algorithm.extract_weights())
4955
report, new_weights = await super().train(context)
@@ -53,25 +59,69 @@ async def train(self, context: ClientContext):
5359

5460
return report, weight_updates
5561

56-
def _prune_updates(self, context, previous_weights, new_weights):
62+
def _prune_updates(
63+
self,
64+
context: ClientContext,
65+
previous_weights: Mapping[str, Any],
66+
new_weights: Mapping[str, Any],
67+
):
5768
algorithm = context.algorithm
69+
if not isinstance(algorithm, FedSawAlgorithm):
70+
raise RuntimeError("FedSaw algorithm required to prune weight updates.")
71+
5872
updates = algorithm.compute_weight_updates(previous_weights, new_weights)
5973

6074
pruning_method = (
6175
"random"
6276
if getattr(Config().clients, "pruning_method", None) == "random"
6377
else "l1"
6478
)
65-
pruning_amount = getattr(context.owner, "pruning_amount", None)
79+
owner = context.owner
80+
pruning_amount: float | int | None = None
81+
if isinstance(owner, FedSawClient):
82+
pruning_amount = owner.pruning_amount
83+
6684
if pruning_amount is None:
6785
state = FedSawClientLifecycleStrategy._state(context)
68-
pruning_amount = state.get("pruning_amount", 0)
86+
stored_amount = state.get("pruning_amount", 0)
87+
pruning_amount = stored_amount if isinstance(stored_amount, (int, float)) else 0
6988

7089
return algorithm.prune_weight_updates(
7190
updates, amount=pruning_amount, method=pruning_method
7291
)
7392

7493

94+
class FedSawClient(simple.Client):
95+
"""Client implementation that tracks pruning metadata for FedSaw."""
96+
97+
def __init__(
98+
self,
99+
model=None,
100+
datasource=None,
101+
algorithm=None,
102+
trainer=None,
103+
callbacks=None,
104+
trainer_callbacks=None,
105+
):
106+
super().__init__(
107+
model=model,
108+
datasource=datasource,
109+
algorithm=algorithm or FedSawAlgorithm,
110+
trainer=trainer,
111+
callbacks=callbacks,
112+
trainer_callbacks=trainer_callbacks,
113+
)
114+
self.pruning_amount: float = 0.0
115+
116+
self._configure_composable(
117+
lifecycle_strategy=FedSawClientLifecycleStrategy(),
118+
payload_strategy=self.payload_strategy,
119+
training_strategy=FedSawTrainingStrategy(),
120+
reporting_strategy=self.reporting_strategy,
121+
communication_strategy=self.communication_strategy,
122+
)
123+
124+
75125
def create_client(
76126
*,
77127
model=None,
@@ -80,28 +130,17 @@ def create_client(
80130
trainer=None,
81131
callbacks=None,
82132
trainer_callbacks=None,
83-
):
133+
) -> FedSawClient:
84134
"""Build a FedSaw client that prunes its updates before reporting."""
85-
client = simple.Client(
135+
return FedSawClient(
86136
model=model,
87137
datasource=datasource,
88-
algorithm=algorithm or FedSawAlgorithm,
138+
algorithm=algorithm,
89139
trainer=trainer,
90140
callbacks=callbacks,
91141
trainer_callbacks=trainer_callbacks,
92142
)
93-
client.pruning_amount = 0
94-
95-
client._configure_composable(
96-
lifecycle_strategy=FedSawClientLifecycleStrategy(),
97-
payload_strategy=client.payload_strategy,
98-
training_strategy=FedSawTrainingStrategy(),
99-
reporting_strategy=client.reporting_strategy,
100-
communication_strategy=client.communication_strategy,
101-
)
102-
103-
return client
104143

105144

106145
# Maintain compatibility for imports expecting a Client callable/class.
107-
Client = create_client
146+
Client = FedSawClient

examples/three_layer_fl/fedsaw/fedsaw_server.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
as either central or edge servers.
44
"""
55

6+
from __future__ import annotations
7+
68
import math
79
import statistics
810

@@ -20,7 +22,7 @@ def __init__(self, model=None, algorithm=None, trainer=None):
2022
super().__init__(model=model, algorithm=selected_algorithm, trainer=trainer)
2123

2224
# The central server uses a list to store each edge server's clients' pruning amount
23-
self.pruning_amount_list = None
25+
self.pruning_amount_list: dict[int, float] | None = None
2426

2527
if Config().is_central_server():
2628
init_pruning_amount = (
@@ -36,8 +38,9 @@ def __init__(self, model=None, algorithm=None, trainer=None):
3638
)
3739
}
3840

41+
self.edge_pruning_amount: float = 0.0
3942
if Config().is_edge_server():
40-
self.edge_pruning_amount = 0
43+
self.edge_pruning_amount = 0.0
4144

4245
def customize_server_response(self, server_response: dict, client_id) -> dict:
4346
"""Wraps up generating the server response with any additional information."""
@@ -55,7 +58,8 @@ def customize_server_response(self, server_response: dict, client_id) -> dict:
5558
async def aggregate_weights(self, updates, baseline_weights, weights_received):
5659
"""Aggregates the reported weight updates from the selected clients."""
5760
deltas = await self.aggregate_deltas(updates, weights_received)
58-
updated_weights = self.algorithm.update_weights(deltas)
61+
algorithm = self.require_algorithm()
62+
updated_weights = algorithm.update_weights(deltas)
5963
return updated_weights
6064

6165
def update_pruning_amount_list(self):
@@ -64,6 +68,11 @@ def update_pruning_amount_list(self):
6468

6569
median = statistics.median(weights_diff_list)
6670

71+
if self.pruning_amount_list is None:
72+
raise RuntimeError(
73+
"Pruning amount list is unavailable on this server instance."
74+
)
75+
6776
for client_id in weights_diff_dict:
6877
if weights_diff_dict[client_id]:
6978
self.pruning_amount_list[client_id] = 1 / (
@@ -84,13 +93,14 @@ def get_weights_differences(self):
8493
}
8594

8695
weights_diff_list = []
96+
algorithm = self.require_algorithm()
8797

8898
for update in self.updates:
8999
client_id = update.report.client_id
90100
num_samples = update.report.num_samples
91101
received_updates = update.payload
92102

93-
weights_diff = self.algorithm.compute_weight_difference(
103+
weights_diff = algorithm.compute_weight_difference(
94104
received_updates,
95105
num_samples=num_samples,
96106
total_samples=self.total_samples,

examples/three_layer_fl/tempo/tempo_server.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
clients' local epoch numbers of each edge server (institution).
44
"""
55

6+
from __future__ import annotations
7+
68
import math
79

810
import torch
@@ -21,7 +23,7 @@ def __init__(self):
2123
super().__init__()
2224

2325
# The central server uses a list to store each edge server's clients' local epoch numbers
24-
self.local_epoch_list = None
26+
self.local_epoch_list: list[int] | None = None
2527
if Config().is_central_server():
2628
self.local_epoch_list = [
2729
Config().trainer.epochs for i in range(Config().algorithm.total_silos)
@@ -111,10 +113,11 @@ def compute_weights_difference(self, local_weights, num_samples):
111113
Computes the weight difference of an edge server's aggregated model
112114
and the global model.
113115
"""
114-
weights_diff = 0
116+
weights_diff = 0.0
115117

116118
# Extract global model weights
117-
global_weights = self.algorithm.extract_weights()
119+
algorithm = self.require_algorithm()
120+
global_weights = algorithm.extract_weights()
118121

119122
for name, local_weight in local_weights.items():
120123
global_weight = global_weights[name]

examples/unlearning/knot/solver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
the number of papers
77
"""
88

9+
from typing import Any, cast
10+
911
try:
1012
import mosek # type: ignore
1113
except ImportError: # pragma: no cover - optional dependency
12-
mosek = None
14+
mosek = cast(Any, None)
1315

1416
from cvxopt import matrix, solvers, sparse, spmatrix
1517

plato/config.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -473,14 +473,19 @@ def simulate_client_speed() -> np.ndarray:
473473
@staticmethod
474474
def is_edge_server() -> bool:
475475
"""Returns whether the current instance is an edge server in cross-silo FL."""
476-
return Config().args.port is not None and bool(
477-
getattr(Config().algorithm, "cross_silo", False)
478-
)
476+
if not bool(getattr(Config().algorithm, "cross_silo", False)):
477+
return False
478+
479+
args = Config().args
480+
return args.port is not None and args.id is not None
479481

480482
@staticmethod
481483
def is_central_server() -> bool:
482484
"""Returns whether the current instance is a central server in cross-silo FL."""
483-
return hasattr(Config().algorithm, "cross_silo") and Config().args.port is None
485+
if not bool(getattr(Config().algorithm, "cross_silo", False)):
486+
return False
487+
488+
return Config().args.id is None
484489

485490
@staticmethod
486491
def gpu_count() -> int:

0 commit comments

Comments
 (0)