Skip to content

Commit 2178c82

Browse files
committed
examples/secure_aggregation and examples/split_learning passed ty checks.
1 parent 14ce22a commit 2178c82

File tree

13 files changed

+102
-55
lines changed

13 files changed

+102
-55
lines changed

docs/docs/index.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ Welcome to *Plato*, a software framework to facilitate scalable, reproducible, a
99

1010
- **[Installation](install.md)** - Installing Plato and setting up your development environment
1111
- **[Quick Start](quickstart.md)** - Getting started with Plato
12-
- Plato supports both PyTorch and MLX backends (MLX for Apple Silicon devices)
1312

1413
## Examples
1514

docs/docs/quickstart.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ To fix all linter errors automatically, run:
157157
uvx ruff check --fix
158158
```
159159

160-
# Type Checking
160+
## Type Checking
161161

162162
It is also strongly recommended that new additions and revisions of the code base pass Astral's [ty](https://docs.astral.sh/ty/) type checker cleanly. To install `ty` globally using `uv`, run:
163163

examples/async/fedbuff/fedbuff_cifar10.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ speed_simulation = true
1717
max_sleep_time = 30
1818

1919
# Should clients really go to sleep, or should we just simulate the sleep times?
20-
sleep_simulation = false
20+
sleep_simulation = true
2121

2222
# If we are simulating client training times, what is the average training time?
2323
avg_training_time = 20

examples/model_pruning/fedscr/fedscr_server.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
import os
7-
from typing import Dict, Optional, TYPE_CHECKING, cast
7+
from typing import TYPE_CHECKING, Dict, Optional, cast
88

99
import numpy as np
1010

@@ -58,11 +58,7 @@ def __init__(
5858
def customize_server_response(self, server_response: dict, client_id) -> dict:
5959
"""Wraps up generating the server response with any additional information."""
6060
trainer = cast(Optional["FedSCRTrainer"], self.trainer)
61-
if (
62-
trainer is not None
63-
and trainer.use_adaptive
64-
and self.current_round > 1
65-
):
61+
if trainer is not None and trainer.use_adaptive and self.current_round > 1:
6662
self.calc_threshold()
6763
server_response["update_thresholds"] = self.update_thresholds
6864
return server_response
@@ -101,9 +97,7 @@ def weights_aggregated(self, updates):
10197
float(np.var([update.report.loss for update in updates]))
10298
)
10399
if self.current_round > 3:
104-
self.mean_variance = sum(self.loss_variances) / (
105-
self.current_round - 2
106-
)
100+
self.mean_variance = sum(self.loss_variances) / (self.current_round - 2)
107101
else:
108102
self.mean_variance = 0.0
109103

examples/secure_aggregation/maskcrypt/maskcrypt_client.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,18 @@ def setup(self, context: ClientContext) -> None:
5050
def load_payload(self, context: ClientContext, server_payload: Any) -> None:
5151
"""Store inbound payload or delegate weight loading based on the round."""
5252
state = self._state(context)
53+
owner = getattr(context, "owner", None)
5354

5455
if context.current_round % 2 != 0:
5556
state["final_mask"] = None
56-
context.owner.final_mask = None
57+
if owner is not None and hasattr(owner, "final_mask"):
58+
setattr(owner, "final_mask", None)
5759
super().load_payload(context, server_payload)
5860
return
5961

6062
state["final_mask"] = server_payload
61-
context.owner.final_mask = server_payload
63+
if owner is not None and hasattr(owner, "final_mask"):
64+
setattr(owner, "final_mask", server_payload)
6265

6366
async def train(self, context: ClientContext) -> tuple[Any, Any]:
6467
"""Alternate between mask proposal computation and weight submission."""
@@ -136,12 +139,17 @@ def _compute_mask(
136139
class MaskCryptClientProxy(simple.Client):
137140
"""Client variant exposing MaskCrypt state via a convenient property."""
138141

142+
encrypt_ratio: float
143+
random_mask: bool
144+
attack_prep_dir: str
145+
checkpoint_path: str
146+
139147
@property
140-
def final_mask(self):
148+
def final_mask(self) -> Any | None:
141149
return self._context.state.get("maskcrypt", {}).get("final_mask")
142150

143151
@final_mask.setter
144-
def final_mask(self, value):
152+
def final_mask(self, value: Any | None) -> None:
145153
self._context.state.setdefault("maskcrypt", {})["final_mask"] = value
146154

147155

examples/secure_aggregation/maskcrypt/maskcrypt_server.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
A MaskCrypt server with selective homomorphic encryption support.
33
"""
44

5+
from typing import cast
6+
57
from maskcrypt_algorithm import Algorithm as MaskCryptAlgorithm
68

79
from plato.servers import fedavg_he
@@ -50,14 +52,21 @@ async def aggregate_weights(self, updates, baseline_weights, weights_received):
5052

5153
return baseline_weights
5254
else:
53-
# Clients send model updates in even rounds, conduct aggregation
54-
aggregated_weights = await super().aggregate_weights(
55-
updates, baseline_weights, weights_received
55+
strategy = getattr(self, "aggregation_strategy", None)
56+
if strategy is None or not hasattr(strategy, "aggregate_weights"):
57+
raise AttributeError(
58+
"Aggregation strategy must expose an 'aggregate_weights' coroutine."
59+
)
60+
aggregated_weights = await strategy.aggregate_weights(
61+
updates, baseline_weights, weights_received, self.context
5662
)
63+
if aggregated_weights is None:
64+
raise RuntimeError("Aggregation strategy failed to produce weights.")
5765

5866
return aggregated_weights
5967

6068
def _mask_consensus(self, updates):
6169
"""Conduct mask consensus on the reported mask proposals."""
6270
proposals = [update.payload for update in updates]
63-
self.final_mask = self.algorithm.build_consensus_mask(proposals)
71+
algorithm = cast(MaskCryptAlgorithm, self.require_algorithm())
72+
self.final_mask = algorithm.build_consensus_mask(proposals)

examples/split_learning/llm_split_learning/split_learning_llm_model.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""
2-
Obtain LLM models from HuggingFace, specifically designed for split learning
2+
Obtain LLM models from HuggingFace, specifically designed for split learning.
33
"""
44

5+
from typing import Any, Union, cast
6+
57
import torch
68
from peft import LoraConfig, get_peft_model
79
from transformers import AutoConfig, AutoModelForCausalLM
@@ -17,7 +19,7 @@ def get_lora_model(model):
1719
return model
1820

1921

20-
def get_module(start_module: torch.nn.Module, module_names):
22+
def get_module(start_module: torch.nn.Module, module_names) -> torch.nn.Module:
2123
"""
2224
Recursively get a PyTorch module starting from the start module with
2325
a given list of module names.
@@ -28,6 +30,9 @@ def get_module(start_module: torch.nn.Module, module_names):
2830
return module
2931

3032

33+
TransformerSequence = Union[torch.nn.Sequential, torch.nn.ModuleList]
34+
35+
3136
class BaseModel(torch.nn.Module):
3237
"""
3338
The basic model loading HuggingFace model used for the server model and the client model
@@ -47,14 +52,16 @@ def __init__(self, *args, **kwargs) -> None:
4752

4853
self.config = AutoConfig.from_pretrained(self.model_name, **config_kwargs)
4954

50-
self.base_model = AutoModelForCausalLM.from_pretrained(
55+
base_model = AutoModelForCausalLM.from_pretrained(
5156
self.model_name,
5257
config=self.config,
5358
cache_dir=Config().params["model_path"] + "/huggingface",
5459
token=use_auth_token,
5560
)
56-
if hasattr(self.base_model, "loss_type"):
57-
self.base_model.loss_type = "ForCausalLM"
61+
base_model_for_loss = cast(Any, base_model)
62+
if hasattr(base_model_for_loss, "loss_type"):
63+
base_model_for_loss.loss_type = "ForCausalLM"
64+
self.base_model = base_model
5865
self.cut_layer = Config().parameters.model.cut_layer
5966

6067
def get_input_embeddings(self):
@@ -79,9 +86,11 @@ def __init__(self, *args, **kwargs) -> None:
7986
super().__init__(*args, **kwargs)
8087
# replace the layers in the base model
8188
# which should be on the cloud with Identity layers()
82-
transformer_module = self.base_model
83-
for module_name in Config().parameters.model.transformer_module_name.split("."):
84-
transformer_module = getattr(transformer_module, module_name)
89+
transformer_module_raw = get_module(
90+
self.base_model,
91+
Config().parameters.model.transformer_module_name.split("."),
92+
)
93+
transformer_module = cast(TransformerSequence, transformer_module_raw)
8594
client_layers = transformer_module[: self.cut_layer]
8695
client_module_names = Config().parameters.model.transformer_module_name.split(
8796
"."
@@ -126,18 +135,21 @@ def __init__(self, *args, **kwargs) -> None:
126135
# The first copy of the model is the whole model which is used for test.
127136
# The second copy of the model only contains the layers on the server
128137
# used for training.
129-
self.server_model = AutoModelForCausalLM.from_pretrained(
138+
server_model = AutoModelForCausalLM.from_pretrained(
130139
self.model_name,
131140
config=self.config,
132141
cache_dir=Config().params["model_path"] + "/huggingface",
133142
)
134-
if hasattr(self.server_model, "loss_type"):
135-
self.server_model.loss_type = "ForCausalLM"
143+
server_model_for_loss = cast(Any, server_model)
144+
if hasattr(server_model_for_loss, "loss_type"):
145+
server_model_for_loss.loss_type = "ForCausalLM"
146+
self.server_model = server_model
136147
transformer_module = get_module(
137148
self.base_model,
138149
Config().parameters.model.transformer_module_name.split("."),
139150
)
140-
server_layers = transformer_module[self.cut_layer :]
151+
transformer_sequence = cast(TransformerSequence, transformer_module)
152+
server_layers = transformer_sequence[self.cut_layer :]
141153
server_module_names = Config().parameters.model.transformer_module_name.split(
142154
"."
143155
)
@@ -159,9 +171,8 @@ def copy_weight(self):
159171
base_model_weights = self.base_model.state_dict()
160172
server_model_weights = self.server_model.state_dict()
161173

162-
transformer_module = self.base_model
163-
for module_name in basic_name.split("."):
164-
transformer_module = getattr(transformer_module, module_name)
174+
transformer_module_raw = get_module(self.base_model, basic_name.split("."))
175+
transformer_module = cast(TransformerSequence, transformer_module_raw)
165176
layer_names = [
166177
basic_name + "." + str(index)
167178
for index in range(

examples/split_learning/llm_split_learning/split_learning_lora_algorithm.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22
A split learning algorithm supporting LoRA fine-tuning LLMs.
33
"""
44

5+
from typing import Dict, cast
6+
57
from peft import (
68
get_peft_model_state_dict,
79
set_peft_model_state_dict,
810
)
11+
from torch import Tensor
12+
from torch.nn import Module
913

1014
from plato.algorithms import split_learning
1115

@@ -15,13 +19,22 @@ class Algorithm(split_learning.Algorithm):
1519
Extract and load only the LoRA weights.
1620
"""
1721

18-
def extract_weights(self, model=None):
19-
# Extract LoRA weights
20-
return {
21-
k: v.cpu()
22-
for k, v in get_peft_model_state_dict(self.model.base_model).items()
23-
}
22+
def _get_base_model(self, model: object | None = None) -> Module:
23+
"""Return the wrapped HuggingFace base model."""
24+
model_obj = model if model is not None else self.model
25+
if model_obj is None or not hasattr(model_obj, "base_model"):
26+
raise AttributeError(
27+
"LoRA split learning requires a model with a `base_model` attribute."
28+
)
29+
base_model = getattr(model_obj, "base_model")
30+
return cast(Module, base_model)
31+
32+
def extract_weights(self, model=None) -> Dict[str, Tensor]:
33+
"""Extract LoRA weights from the underlying base model."""
34+
base_model = self._get_base_model(model)
35+
return {k: v.cpu() for k, v in get_peft_model_state_dict(base_model).items()}
2436

25-
def load_weights(self, weights):
26-
# Load LoRA weights
27-
return set_peft_model_state_dict(self.model.base_model, weights)
37+
def load_weights(self, weights: Dict[str, Tensor]):
38+
"""Load LoRA weights into the underlying base model."""
39+
base_model = self._get_base_model()
40+
return set_peft_model_state_dict(base_model, weights)

examples/split_learning/llm_split_learning/split_learning_server_attack.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,15 @@ def attack(self, update):
4242
"""
4343
self.attack_started = True
4444
intermediate_features, labels = update[0]
45-
evaluation_metrics = self.trainer.attack(intermediate_features, labels)
45+
trainer = self.trainer
46+
if trainer is None or not hasattr(trainer, "attack"):
47+
raise AttributeError(
48+
"Trainer must define an `attack` method for curious server attacks."
49+
)
50+
attack_fn = getattr(trainer, "attack")
51+
if not callable(attack_fn):
52+
raise TypeError("Trainer attack must be callable.")
53+
evaluation_metrics = attack_fn(intermediate_features, labels)
4654
rouge_metrics = evaluation_metrics["ROUGE"]
4755
self.rouge["rouge1_fm"] = rouge_metrics["rouge1_fmeasure"].item()
4856
self.rouge["rouge1_p"] = rouge_metrics["rouge1_precision"].item()

examples/split_learning/llm_split_learning/split_learning_trainer.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
from collections import OrderedDict
9-
from typing import Optional
9+
from typing import Any, Optional, Sized, cast
1010

1111
import evaluate
1212
from torch import Tensor, reshape
@@ -40,13 +40,14 @@ def preprocess_logits_for_metrics(logits, labels):
4040

4141
def compute_metrics(eval_preds):
4242
"""Calculate the accuracy for evaluation stage."""
43-
metric = evaluate.load("accuracy")
43+
metric: Any = evaluate.load("accuracy")
4444
preds, labels = eval_preds
4545
# preds have the same shape as the labels, after the argmax(-1) has been calculated
4646
# by preprocess_logits_for_metrics but we need to shift the labels
4747
labels = labels.reshape(-1)
4848
preds = preds.reshape(-1)
49-
return metric.compute(predictions=preds, references=labels)
49+
compute_fn = getattr(metric, "compute")
50+
return compute_fn(predictions=preds, references=labels)
5051

5152

5253
# ============================================================================
@@ -88,7 +89,10 @@ def __init__(
8889
def _get_train_sampler(self) -> Sampler | None:
8990
"""Get training sampler."""
9091
if self.sampler is None:
91-
return RandomSampler(self.train_dataset)
92+
if self.train_dataset is None:
93+
raise ValueError("Training dataset is not initialized.")
94+
dataset = cast(Sized, self.train_dataset)
95+
return RandomSampler(dataset)
9296
return self.sampler
9397

9498
def _get_eval_sampler(self, eval_dataset) -> Sampler | None:
@@ -334,6 +338,8 @@ def server_forward_from(self, batch, config):
334338
inputs, labels = batch
335339
batch_size = inputs.size(0)
336340
inputs = inputs.detach().requires_grad_(True)
341+
if self.model is None or not hasattr(self.model, "forward_from"):
342+
raise AttributeError("Model must provide a `forward_from` method.")
337343
outputs = self.model.forward_from(inputs, labels)
338344
loss = outputs.loss
339345
loss.backward()

0 commit comments

Comments
 (0)