@@ -72,6 +72,7 @@ def __init__(
7272 batch_size : int | None = None ,
7373 threshold : float = 0.1 ,
7474 num_pre_loss : int = 10 ,
75+ max_ala_epochs : int | None = 20 ,
7576 loss_fn : Callable [[torch .Tensor , torch .Tensor ], torch .Tensor ] | None = None ,
7677 save_state : bool = True ,
7778 ) -> None :
@@ -89,6 +90,10 @@ def __init__(
8990 raise ValueError (
9091 f"num_pre_loss must be >= 1, got { num_pre_loss } "
9192 )
93+ if max_ala_epochs is not None and max_ala_epochs < 1 :
94+ raise ValueError (
95+ f"max_ala_epochs must be >= 1 or None, got { max_ala_epochs } "
96+ )
9297 if batch_size is not None and batch_size < 1 :
9398 raise ValueError (f"batch_size must be >= 1, got { batch_size } " )
9499
@@ -98,6 +103,9 @@ def __init__(
98103 self .batch_size = batch_size
99104 self .threshold = float (threshold )
100105 self .num_pre_loss = int (num_pre_loss )
106+ self .max_ala_epochs = (
107+ int (max_ala_epochs ) if max_ala_epochs is not None else None
108+ )
101109 self .loss_fn = loss_fn or nn .CrossEntropyLoss ()
102110 self .save_state = save_state
103111
@@ -469,6 +477,13 @@ def _adaptive_local_aggregation(
469477
470478 losses .append (float (loss_value .item ()))
471479 cnt += 1
480+ if self .max_ala_epochs is not None and cnt >= self .max_ala_epochs :
481+ LOGGER .info (
482+ "[Client #%d] FedALA reached max_ala_epochs=%d; stopping ALA." ,
483+ context .client_id ,
484+ self .max_ala_epochs ,
485+ )
486+ break
472487
473488 if not self .start_phase :
474489 break
@@ -525,6 +540,7 @@ class FedALAUpdateStrategyFromConfig(FedALAUpdateStrategy):
525540 - layer_idx (default: 0)
526541 - threshold (default: 0.1)
527542 - num_pre_loss (default: 10)
543+ - max_ala_epochs (default: 20)
528544 - ala_batch_size (optional)
529545 """
530546
@@ -545,6 +561,16 @@ def __init__(self) -> None:
545561 num_pre_loss = self ._get_config_value (
546562 algo , ["num_pre_loss" , "fedala_num_pre_loss" , "ala_num_pre_loss" ], 10
547563 )
564+ max_ala_epochs = self ._get_config_value (
565+ algo ,
566+ [
567+ "max_ala_epochs" ,
568+ "fedala_max_ala_epochs" ,
569+ "ala_max_epochs" ,
570+ "fedala_max_epochs" ,
571+ ],
572+ 20 ,
573+ )
548574 batch_size = self ._get_config_value (
549575 algo , ["ala_batch_size" , "fedala_batch_size" ], None
550576 )
@@ -559,6 +585,7 @@ def __init__(self) -> None:
559585 batch_size = batch_size ,
560586 threshold = float (threshold ),
561587 num_pre_loss = int (num_pre_loss ),
588+ max_ala_epochs = None if max_ala_epochs is None else int (max_ala_epochs ),
562589 )
563590
564591 @staticmethod
0 commit comments