Skip to content

Commit 8f29fd3

Browse files
authored
Completed migrating clients to the new strategy API. (#395)
* Migrated examples/customized/custom_client.py to the new composable client API. * ruff format & ruff check --fix. * Migrated some examples to use the latest client API using strategies. * Expanded the test coverage for registry-style components across models, datasources, and samplers. * Migrated all remaining examples to the new client API.
1 parent d66614e commit 8f29fd3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+1248
-929
lines changed

.github/workflows/delete_workflow_runs.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
# This workflow will delete old workflow runs, on a schedule or manually triggered.
22
name: Delete old workflow runs
33
permissions:
4-
contents: read
5-
pull-requests: read
4+
contents: read/write
65

76
on:
87
workflow_dispatch:

examples/basic/basic.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from torchvision.transforms import ToTensor
1515

1616
from plato.clients import simple
17-
from plato.datasources import base
1817
from plato.config import Config
18+
from plato.datasources import base
1919
from plato.servers import fedavg
2020
from plato.trainers.composable import ComposableTrainer
2121
from plato.trainers.strategies.base import (
@@ -38,8 +38,12 @@ def __init__(self):
3838
Config()
3939
base_path = Path(Config.params.get("base_path", "./runtime"))
4040
data_dir = Path(Config.params.get("data_path", base_path / "data"))
41-
self.trainset = MNIST(str(data_dir), train=True, download=True, transform=ToTensor())
42-
self.testset = MNIST(str(data_dir), train=False, download=True, transform=ToTensor())
41+
self.trainset = MNIST(
42+
str(data_dir), train=True, download=True, transform=ToTensor()
43+
)
44+
self.testset = MNIST(
45+
str(data_dir), train=False, download=True, transform=ToTensor()
46+
)
4347

4448

4549
class MNISTTrainingStepStrategy(TrainingStepStrategy):

examples/basic/server_strategies.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,12 @@ def __init__(self):
6262
Config()
6363
base_path = Path(Config.params.get("base_path", "./runtime"))
6464
data_dir = Path(Config.params.get("data_path", base_path / "data"))
65-
self.trainset = MNIST(str(data_dir), train=True, download=True, transform=ToTensor())
66-
self.testset = MNIST(str(data_dir), train=False, download=True, transform=ToTensor())
65+
self.trainset = MNIST(
66+
str(data_dir), train=True, download=True, transform=ToTensor()
67+
)
68+
self.testset = MNIST(
69+
str(data_dir), train=False, download=True, transform=ToTensor()
70+
)
6771

6872

6973
class MNISTTrainingStepStrategy(TrainingStepStrategy):

examples/client_selection/afl/afl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
def main():
1818
"""A Plato federated learning training session using the AFL algorithm."""
19-
client = afl_client.Client()
19+
client = afl_client.create_client()
2020
server = afl_server.Server()
2121
server.run(client)
2222

examples/client_selection/afl/afl_client.py

Lines changed: 48 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212

1313
import logging
1414
import math
15-
from types import SimpleNamespace
16-
from typing import Iterable, Optional
15+
from typing import Iterable, List, Optional
1716

1817
import torch
1918

@@ -164,38 +163,51 @@ def _get_pre_training_loss(context: ClientContext) -> Optional[float]:
164163
return 0.0
165164

166165

167-
class Client(simple.Client):
168-
"""A federated learning client for AFL."""
169-
170-
def __init__(
171-
self,
172-
model=None,
173-
datasource=None,
174-
algorithm=None,
175-
trainer=None,
176-
callbacks=None,
177-
trainer_callbacks: Optional[Iterable] = None,
166+
def _ensure_pretraining_callback(
167+
trainer_callbacks: Optional[Iterable],
168+
) -> List:
169+
"""Ensure AFL's pre-training loss callback is present once."""
170+
callbacks_list = list(trainer_callbacks) if trainer_callbacks else []
171+
if not any(
172+
cb == AFLPreTrainingLossCallback
173+
or getattr(cb, "__class__", None) == AFLPreTrainingLossCallback
174+
for cb in callbacks_list
178175
):
179-
callbacks_list = list(trainer_callbacks) if trainer_callbacks else []
180-
if not any(
181-
cb == AFLPreTrainingLossCallback
182-
or getattr(cb, "__class__", None) == AFLPreTrainingLossCallback
183-
for cb in callbacks_list
184-
):
185-
callbacks_list.append(AFLPreTrainingLossCallback)
186-
187-
super().__init__(
188-
model=model,
189-
datasource=datasource,
190-
algorithm=algorithm,
191-
trainer=trainer,
192-
callbacks=callbacks,
193-
trainer_callbacks=callbacks_list,
194-
)
195-
self._configure_composable(
196-
lifecycle_strategy=self.lifecycle_strategy,
197-
payload_strategy=self.payload_strategy,
198-
training_strategy=self.training_strategy,
199-
reporting_strategy=AFLReportingStrategy(),
200-
communication_strategy=self.communication_strategy,
201-
)
176+
callbacks_list.append(AFLPreTrainingLossCallback)
177+
return callbacks_list
178+
179+
180+
def create_client(
181+
*,
182+
model=None,
183+
datasource=None,
184+
algorithm=None,
185+
trainer=None,
186+
callbacks=None,
187+
trainer_callbacks: Optional[Iterable] = None,
188+
):
189+
"""Build an AFL client configured with valuation hooks."""
190+
callbacks_list = _ensure_pretraining_callback(trainer_callbacks)
191+
192+
client = simple.Client(
193+
model=model,
194+
datasource=datasource,
195+
algorithm=algorithm,
196+
trainer=trainer,
197+
callbacks=callbacks,
198+
trainer_callbacks=callbacks_list,
199+
)
200+
201+
client._configure_composable(
202+
lifecycle_strategy=client.lifecycle_strategy,
203+
payload_strategy=client.payload_strategy,
204+
training_strategy=client.training_strategy,
205+
reporting_strategy=AFLReportingStrategy(),
206+
communication_strategy=client.communication_strategy,
207+
)
208+
209+
return client
210+
211+
212+
# Maintain compatibility for previous imports that expected a Client callable.
213+
Client = create_client

examples/client_selection/afl/afl_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
def main():
1818
"""A Plato federated learning training session using AFL strategy."""
19-
client = afl_client.Client()
19+
client = afl_client.create_client()
2020
server = afl_server_strategy.Server()
2121
server.run(client)
2222

examples/client_selection/oort/oort.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
def main():
1919
"""A Plato federated learning training session using Oort strategy."""
2020
trainer = oort_trainer.Trainer
21-
client = oort_client.Client(trainer=trainer)
21+
client = oort_client.create_client(trainer=trainer)
2222
server = oort_server.Server(trainer=trainer)
2323
server.run(client)
2424

examples/client_selection/oort/oort_client.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -39,31 +39,35 @@ def build_report(self, context: ClientContext, report):
3939
return report
4040

4141

42-
class Client(simple.Client):
43-
"""A federated learning client that calculates its statistical utility."""
44-
45-
def __init__(
46-
self,
47-
model=None,
48-
datasource=None,
49-
algorithm=None,
50-
trainer=None,
51-
callbacks=None,
52-
trainer_callbacks=None,
53-
):
54-
super().__init__(
55-
model=model,
56-
datasource=datasource,
57-
algorithm=algorithm,
58-
trainer=trainer,
59-
callbacks=callbacks,
60-
trainer_callbacks=trainer_callbacks,
61-
)
42+
def create_client(
43+
*,
44+
model=None,
45+
datasource=None,
46+
algorithm=None,
47+
trainer=None,
48+
callbacks=None,
49+
trainer_callbacks=None,
50+
):
51+
"""Build an Oort client configured with statistical utility reporting."""
52+
client = simple.Client(
53+
model=model,
54+
datasource=datasource,
55+
algorithm=algorithm,
56+
trainer=trainer,
57+
callbacks=callbacks,
58+
trainer_callbacks=trainer_callbacks,
59+
)
6260

63-
self._configure_composable(
64-
lifecycle_strategy=self.lifecycle_strategy,
65-
payload_strategy=self.payload_strategy,
66-
training_strategy=self.training_strategy,
67-
reporting_strategy=OortReportingStrategy(),
68-
communication_strategy=self.communication_strategy,
69-
)
61+
client._configure_composable(
62+
lifecycle_strategy=client.lifecycle_strategy,
63+
payload_strategy=client.payload_strategy,
64+
training_strategy=client.training_strategy,
65+
reporting_strategy=OortReportingStrategy(),
66+
communication_strategy=client.communication_strategy,
67+
)
68+
69+
return client
70+
71+
72+
# Maintain compatibility for imports expecting a Client callable.
73+
Client = create_client

examples/client_selection/pisces/pisces.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
def main():
1818
"""Pisces: an asynchronous client selection and server aggregation algorithm."""
1919
trainer = pisces_trainer.Trainer
20-
client = pisces_client.Client(trainer=trainer)
20+
client = pisces_client.create_client(trainer=trainer)
2121
server = pisces_server.Server(trainer=trainer)
2222

2323
server.run(client)

examples/client_selection/pisces/pisces_client.py

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -59,35 +59,38 @@ def build_report(self, context: ClientContext, report):
5959
return report
6060

6161

62-
class Client(simple.Client):
62+
def create_client(
63+
*,
64+
model=None,
65+
datasource=None,
66+
algorithm=None,
67+
trainer=None,
68+
callbacks=None,
69+
trainer_callbacks=None,
70+
loss_decay: float = 1e-2,
71+
):
6372
"""
64-
A Pisces federated learning client who sends weight updates and client statistical utility.
73+
Build a Pisces client that reports statistical utility with an EMA of squared loss.
6574
"""
66-
67-
def __init__(
68-
self,
69-
model=None,
70-
datasource=None,
71-
algorithm=None,
72-
trainer=None,
73-
callbacks=None,
74-
trainer_callbacks=None,
75-
*,
76-
loss_decay: float = 1e-2,
77-
):
78-
super().__init__(
79-
model=model,
80-
datasource=datasource,
81-
algorithm=algorithm,
82-
trainer=trainer,
83-
callbacks=callbacks,
84-
trainer_callbacks=trainer_callbacks,
85-
)
86-
87-
self._configure_composable(
88-
lifecycle_strategy=self.lifecycle_strategy,
89-
payload_strategy=self.payload_strategy,
90-
training_strategy=self.training_strategy,
91-
reporting_strategy=PiscesReportingStrategy(loss_decay),
92-
communication_strategy=self.communication_strategy,
93-
)
75+
client = simple.Client(
76+
model=model,
77+
datasource=datasource,
78+
algorithm=algorithm,
79+
trainer=trainer,
80+
callbacks=callbacks,
81+
trainer_callbacks=trainer_callbacks,
82+
)
83+
84+
client._configure_composable(
85+
lifecycle_strategy=client.lifecycle_strategy,
86+
payload_strategy=client.payload_strategy,
87+
training_strategy=client.training_strategy,
88+
reporting_strategy=PiscesReportingStrategy(loss_decay),
89+
communication_strategy=client.communication_strategy,
90+
)
91+
92+
return client
93+
94+
95+
# Maintain compatibility for imports expecting a Client callable.
96+
Client = create_client

0 commit comments

Comments
 (0)