Skip to content

Commit 15923a7

Browse files
committed
Fixed an issue where FedALA's adaptive local aggregation loop could run indefinitely in some rounds; added a configurable max_ala_epochs cap (default 20) to bound it.
1 parent 03fd994 commit 15923a7

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

plato/trainers/strategies/algorithms/fedala_strategy.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def __init__(
7272
batch_size: int | None = None,
7373
threshold: float = 0.1,
7474
num_pre_loss: int = 10,
75+
max_ala_epochs: int | None = 20,
7576
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
7677
save_state: bool = True,
7778
) -> None:
@@ -89,6 +90,10 @@ def __init__(
8990
raise ValueError(
9091
f"num_pre_loss must be >= 1, got {num_pre_loss}"
9192
)
93+
if max_ala_epochs is not None and max_ala_epochs < 1:
94+
raise ValueError(
95+
f"max_ala_epochs must be >= 1 or None, got {max_ala_epochs}"
96+
)
9297
if batch_size is not None and batch_size < 1:
9398
raise ValueError(f"batch_size must be >= 1, got {batch_size}")
9499

@@ -98,6 +103,9 @@ def __init__(
98103
self.batch_size = batch_size
99104
self.threshold = float(threshold)
100105
self.num_pre_loss = int(num_pre_loss)
106+
self.max_ala_epochs = (
107+
int(max_ala_epochs) if max_ala_epochs is not None else None
108+
)
101109
self.loss_fn = loss_fn or nn.CrossEntropyLoss()
102110
self.save_state = save_state
103111

@@ -469,6 +477,13 @@ def _adaptive_local_aggregation(
469477

470478
losses.append(float(loss_value.item()))
471479
cnt += 1
480+
if self.max_ala_epochs is not None and cnt >= self.max_ala_epochs:
481+
LOGGER.info(
482+
"[Client #%d] FedALA reached max_ala_epochs=%d; stopping ALA.",
483+
context.client_id,
484+
self.max_ala_epochs,
485+
)
486+
break
472487

473488
if not self.start_phase:
474489
break
@@ -525,6 +540,7 @@ class FedALAUpdateStrategyFromConfig(FedALAUpdateStrategy):
525540
- layer_idx (default: 0)
526541
- threshold (default: 0.1)
527542
- num_pre_loss (default: 10)
543+
- max_ala_epochs (default: 20)
528544
- ala_batch_size (optional)
529545
"""
530546

@@ -545,6 +561,16 @@ def __init__(self) -> None:
545561
num_pre_loss = self._get_config_value(
546562
algo, ["num_pre_loss", "fedala_num_pre_loss", "ala_num_pre_loss"], 10
547563
)
564+
max_ala_epochs = self._get_config_value(
565+
algo,
566+
[
567+
"max_ala_epochs",
568+
"fedala_max_ala_epochs",
569+
"ala_max_epochs",
570+
"fedala_max_epochs",
571+
],
572+
20,
573+
)
548574
batch_size = self._get_config_value(
549575
algo, ["ala_batch_size", "fedala_batch_size"], None
550576
)
@@ -559,6 +585,7 @@ def __init__(self) -> None:
559585
batch_size=batch_size,
560586
threshold=float(threshold),
561587
num_pre_loss=int(num_pre_loss),
588+
max_ala_epochs=None if max_ala_epochs is None else int(max_ala_epochs),
562589
)
563590

564591
@staticmethod

0 commit comments

Comments
 (0)