|
12 | 12 |
|
13 | 13 | import logging |
14 | 14 | import math |
15 | | -from types import SimpleNamespace |
16 | | -from typing import Iterable, Optional |
| 15 | +from typing import Iterable, List, Optional |
17 | 16 |
|
18 | 17 | import torch |
19 | 18 |
|
@@ -164,38 +163,51 @@ def _get_pre_training_loss(context: ClientContext) -> Optional[float]: |
164 | 163 | return 0.0 |
165 | 164 |
|
166 | 165 |
|
167 | | -class Client(simple.Client): |
168 | | - """A federated learning client for AFL.""" |
169 | | - |
170 | | - def __init__( |
171 | | - self, |
172 | | - model=None, |
173 | | - datasource=None, |
174 | | - algorithm=None, |
175 | | - trainer=None, |
176 | | - callbacks=None, |
177 | | - trainer_callbacks: Optional[Iterable] = None, |
| 166 | +def _ensure_pretraining_callback( |
| 167 | + trainer_callbacks: Optional[Iterable], |
| 168 | +) -> List: |
| 169 | + """Ensure AFL's pre-training loss callback is present once.""" |
| 170 | + callbacks_list = list(trainer_callbacks) if trainer_callbacks else [] |
| 171 | + if not any( |
| 172 | + cb == AFLPreTrainingLossCallback |
| 173 | + or getattr(cb, "__class__", None) == AFLPreTrainingLossCallback |
| 174 | + for cb in callbacks_list |
178 | 175 | ): |
179 | | - callbacks_list = list(trainer_callbacks) if trainer_callbacks else [] |
180 | | - if not any( |
181 | | - cb == AFLPreTrainingLossCallback |
182 | | - or getattr(cb, "__class__", None) == AFLPreTrainingLossCallback |
183 | | - for cb in callbacks_list |
184 | | - ): |
185 | | - callbacks_list.append(AFLPreTrainingLossCallback) |
186 | | - |
187 | | - super().__init__( |
188 | | - model=model, |
189 | | - datasource=datasource, |
190 | | - algorithm=algorithm, |
191 | | - trainer=trainer, |
192 | | - callbacks=callbacks, |
193 | | - trainer_callbacks=callbacks_list, |
194 | | - ) |
195 | | - self._configure_composable( |
196 | | - lifecycle_strategy=self.lifecycle_strategy, |
197 | | - payload_strategy=self.payload_strategy, |
198 | | - training_strategy=self.training_strategy, |
199 | | - reporting_strategy=AFLReportingStrategy(), |
200 | | - communication_strategy=self.communication_strategy, |
201 | | - ) |
| 176 | + callbacks_list.append(AFLPreTrainingLossCallback) |
| 177 | + return callbacks_list |
| 178 | + |
| 179 | + |
| 180 | +def create_client( |
| 181 | + *, |
| 182 | + model=None, |
| 183 | + datasource=None, |
| 184 | + algorithm=None, |
| 185 | + trainer=None, |
| 186 | + callbacks=None, |
| 187 | + trainer_callbacks: Optional[Iterable] = None, |
| 188 | +): |
| 189 | + """Build an AFL client configured with valuation hooks.""" |
| 190 | + callbacks_list = _ensure_pretraining_callback(trainer_callbacks) |
| 191 | + |
| 192 | + client = simple.Client( |
| 193 | + model=model, |
| 194 | + datasource=datasource, |
| 195 | + algorithm=algorithm, |
| 196 | + trainer=trainer, |
| 197 | + callbacks=callbacks, |
| 198 | + trainer_callbacks=callbacks_list, |
| 199 | + ) |
| 200 | + |
| 201 | + client._configure_composable( |
| 202 | + lifecycle_strategy=client.lifecycle_strategy, |
| 203 | + payload_strategy=client.payload_strategy, |
| 204 | + training_strategy=client.training_strategy, |
| 205 | + reporting_strategy=AFLReportingStrategy(), |
| 206 | + communication_strategy=client.communication_strategy, |
| 207 | + ) |
| 208 | + |
| 209 | + return client |
| 210 | + |
| 211 | + |
| 212 | +# Maintain compatibility for previous imports that expected a Client callable. |
| 213 | +Client = create_client |
0 commit comments