Skip to content

Commit 6cec854

Browse files
committed
ty checks passed for examples/callbacks and examples/gradient_leakage_attacks.
1 parent 63dfb86 commit 6cec854

File tree

13 files changed

+549
-306
lines changed

13 files changed

+549
-306
lines changed

examples/callbacks/callback_examples.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,14 @@
66
from plato.callbacks.server import ServerCallback
77
from plato.callbacks.trainer import TrainerCallback
88

9+
__all__ = [
10+
"argumentClientCallback",
11+
"dynamicClientCallback",
12+
"argumentServerCallback",
13+
"dynamicServerCallback",
14+
"customTrainerCallback",
15+
]
16+
917

1018
class argumentClientCallback(ClientCallback):
1119
def on_inbound_received(self, client, inbound_processor):

examples/callbacks/callbacks.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,13 @@
22
This example shows how to use callbacks to customize server, client, and trainer.
33
"""
44

5-
from callback_examples import *
5+
from callback_examples import (
6+
argumentClientCallback,
7+
argumentServerCallback,
8+
customTrainerCallback,
9+
dynamicClientCallback,
10+
dynamicServerCallback,
11+
)
612

713
from plato.clients import simple
814
from plato.servers import fedavg

examples/client_selection/afl/afl_server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313

1414
from __future__ import annotations
1515

16+
from afl_selection_strategy import AFLSelectionStrategy
17+
1618
from plato.config import Config
1719
from plato.servers import fedavg
1820

19-
from afl_selection_strategy import AFLSelectionStrategy
20-
2121

2222
class Server(fedavg.Server):
2323
"""An AFL server configured with the strategy-based client selection API."""

examples/client_selection/pisces/pisces_aggregation_strategy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import asyncio
1010
import logging
1111
from types import SimpleNamespace
12+
1213
import numpy as np
1314

1415
from plato.config import Config

examples/composable_trainer/composable_trainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,9 @@ def setup(self, context: TrainingContext):
216216
def compute_loss(self, outputs, labels, context):
217217
"""Compute weighted cross-entropy loss + L2 regularization."""
218218
if self._criterion is None:
219-
raise RuntimeError("Loss criterion is not initialized. Call setup first.")
219+
raise RuntimeError(
220+
"Loss criterion is not initialized. Call setup first."
221+
)
220222
# Base cross-entropy loss with class weights
221223
ce_loss = self._criterion(outputs, labels)
222224

examples/gradient_leakage_attacks/defense/GradDefense/compensate.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212
import torch
13+
from numpy.typing import ArrayLike, NDArray
1314

1415

1516
def get_factor(num):
@@ -45,9 +46,10 @@ def get_matrix_size(total_params_num: int, q: float):
4546
return gradients_matrix_v, gradients_matrix_w, real_q
4647

4748

48-
def get_covariance_matrix(matrix):
49+
def get_covariance_matrix(matrix: ArrayLike) -> NDArray[np.float64]:
4950
"""Calculate covariance matrix."""
50-
return np.cov(matrix, rowvar=0)
51+
matrix_array: NDArray[np.float64] = np.asarray(matrix, dtype=np.float64)
52+
return np.cov(matrix_array, rowvar=False)
5153

5254

5355
def denoise(gradients: list, scale: float, q: float):

examples/gradient_leakage_attacks/defense/GradDefense/dataloader.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,41 @@
11
"""
2-
Dataloader of GradDefense
2+
Dataloader of GradDefense.
33
44
Reference:
55
Wang et al., "Protect Privacy from Gradient Leakage Attack in Federated Learning," INFOCOM 2022.
66
https://github.com/wangjunxiao/GradDefense
77
"""
88

9+
from __future__ import annotations
10+
11+
from typing import Any, Protocol, Sequence, cast
12+
913
import numpy as np
1014
from torch.utils.data import Subset
1115
from torch.utils.data.dataloader import DataLoader
1216
from torch.utils.data.dataset import Dataset
1317

18+
19+
class LabeledDataset(Protocol):
20+
"""Minimal protocol for datasets with labeled samples."""
21+
22+
classes: Sequence[Any]
23+
24+
def __len__(self) -> int: ...
25+
26+
def __getitem__(self, index: int) -> tuple[Any, Any]: ...
27+
28+
1429
DEFAULT_NUM_WORKERS = 8
1530
ROOTSET_PER_CLASS = 5
1631
ROOTSET_SIZE = 50
1732

1833

1934
def extract_root_set(
20-
dataset: Dataset,
35+
dataset: LabeledDataset,
2136
sample_per_class: int = ROOTSET_PER_CLASS,
22-
seed: int = None,
23-
):
37+
seed: int | None = None,
38+
) -> tuple[list[int], dict[int, list[int]]]:
2439
"""Extract root dataset."""
2540
num_classes = len(dataset.classes)
2641
class2sample = {i: [] for i in range(num_classes)}
@@ -39,10 +54,10 @@ def extract_root_set(
3954
return select_indices, class2sample
4055

4156

42-
def get_root_set_loader(trainset):
57+
def get_root_set_loader(trainset: LabeledDataset) -> DataLoader[Any]:
4358
"""Obtain root dataset loader."""
4459
rootset_indices, __ = extract_root_set(trainset)
45-
root_set = Subset(trainset, rootset_indices)
60+
root_set = Subset(cast(Dataset[Any], trainset), rootset_indices)
4661
root_dataloader = DataLoader(
4762
root_set, batch_size=len(root_set), num_workers=DEFAULT_NUM_WORKERS
4863
)

examples/gradient_leakage_attacks/defense/GradDefense/sensitivity.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,38 @@
11
"""
2-
Sensitivity computation of GradDefense
2+
Sensitivity computation of GradDefense.
33
44
Reference:
55
Wang et al., "Protect Privacy from Gradient Leakage Attack in Federated Learning," INFOCOM 2022.
66
https://github.com/wangjunxiao/GradDefense
77
"""
88

9+
from __future__ import annotations
10+
911
import torch
1012
import torch.nn as nn
1113
from torch.utils.data.dataloader import DataLoader
1214

1315

1416
def compute_sens(
1517
model: nn.Module,
16-
rootset_loader: DataLoader,
18+
rootset_loader: DataLoader[tuple[torch.Tensor, torch.Tensor]],
1719
device: torch.device,
18-
loss_fn=nn.CrossEntropyLoss(),
19-
):
20+
loss_fn: nn.Module | None = None,
21+
) -> list[float]:
2022
"""Compute sensitivity."""
2123
x, y = next(iter(rootset_loader))
2224

2325
x = x.to(device).requires_grad_()
2426
y = y.to(device)
2527
model = model.to(device)
2628

29+
if loss_fn is None:
30+
loss_fn = nn.CrossEntropyLoss()
31+
2732
# Compute prediction and loss
28-
try:
29-
pred, _ = model(x)
30-
except:
31-
pred = model(x)
33+
pred = model(x)
34+
if isinstance(pred, tuple):
35+
pred = pred[0]
3236

3337
loss = loss_fn(pred, y)
3438
# Backward propagation
@@ -49,11 +53,11 @@ def compute_sens(
4953

5054
sensitivity = []
5155
for layer_vjp in vector_jacobian_products:
52-
f_norm_sum = 0
56+
sum_norms = torch.zeros((), device=layer_vjp.device)
5357
for sample_vjp in layer_vjp:
5458
# Sample-wise Frobenius norm
55-
f_norm_sum += torch.norm(sample_vjp)
56-
f_norm = f_norm_sum / len(layer_vjp)
57-
sensitivity.append(f_norm.cpu().numpy())
59+
sum_norms = sum_norms + torch.norm(sample_vjp)
60+
f_norm = sum_norms / layer_vjp.shape[0]
61+
sensitivity.append(float(f_norm.detach().cpu()))
5862

5963
return sensitivity
Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,32 @@
1-
"""Obtaining models adapted from existing work's implementations.
1+
"""
2+
Obtaining models adapted from existing work's implementations.
23
34
An extra return object named `feature` is added in each model's forward function,
45
which will be used in the defense Soteria.
56
"""
67

7-
from typing import Union
8+
from __future__ import annotations
9+
10+
from typing import Callable, Optional
811

9-
from nn import (
10-
lenet,
11-
resnet,
12-
)
12+
import torch.nn as nn
13+
from nn import lenet, resnet
1314

1415
from plato.config import Config
1516

1617

17-
def get(**kwargs: str | dict):
18-
"""Get the model with the provided name."""
19-
model_name = (
20-
kwargs["model_name"] if "model_name" in kwargs else Config().trainer.model_name
21-
)
18+
def get(model_name: str | None = None) -> Optional[Callable[[], nn.Module]]:
19+
"""Get the model constructor with the provided name."""
20+
resolved_name = model_name or Config().trainer.model_name
2221

23-
if model_name == "lenet":
22+
if resolved_name == "lenet":
2423
return lenet.Model
2524

26-
if model_name.split("_")[0] == "resnet":
27-
return resnet.get(model_name=model_name)
25+
if resolved_name.startswith("resnet_"):
26+
return resnet.get(model_name=resolved_name)
2827

2928
# Set up model through plato's model library
3029
if Config().trainer.model_type == "vit":
3130
return None
3231

33-
raise ValueError(f"No such model: {model_name}")
32+
raise ValueError(f"No such model: {resolved_name}")

0 commit comments

Comments
 (0)