22A federated learning client using pruning.
33"""
44
5+ from __future__ import annotations
6+
57import copy
68import logging
9+ from collections .abc import Mapping
10+ from typing import Any
711
812from fedsaw_algorithm import Algorithm as FedSawAlgorithm
913
@@ -20,10 +24,12 @@ class FedSawClientLifecycleStrategy(DefaultLifecycleStrategy):
2024 _STATE_KEY = "fedsaw_client"
2125
2226 @staticmethod
23- def _state (context ) :
27+ def _state (context : ClientContext ) -> dict [ str , Any ] :
2428 return context .state .setdefault (FedSawClientLifecycleStrategy ._STATE_KEY , {})
2529
26- def process_server_response (self , context , server_response ):
30+ def process_server_response (
31+ self , context : ClientContext , server_response : dict [str , Any ]
32+ ) -> None :
2733 super ().process_server_response (context , server_response )
2834 amount = server_response .get ("pruning_amount" )
2935 if amount is None :
@@ -33,17 +39,17 @@ def process_server_response(self, context, server_response):
3339 state ["pruning_amount" ] = amount
3440
3541 owner = context .owner
36- if owner is not None :
37- owner .pruning_amount = amount
42+ if isinstance ( owner , FedSawClient ) and isinstance ( amount , ( int , float )) :
43+ owner .pruning_amount = float ( amount )
3844
3945
4046class FedSawTrainingStrategy (DefaultTrainingStrategy ):
4147 """Training strategy that prunes local updates before transmission."""
4248
4349 async def train (self , context : ClientContext ):
4450 algorithm = context .algorithm
45- if algorithm is None :
46- raise RuntimeError ("Algorithm is required for FedSaw training ." )
51+ if not isinstance ( algorithm , FedSawAlgorithm ) :
52+ raise RuntimeError ("FedSaw training requires a FedSaw algorithm instance ." )
4753
4854 previous_weights = copy .deepcopy (algorithm .extract_weights ())
4955 report , new_weights = await super ().train (context )
@@ -53,25 +59,69 @@ async def train(self, context: ClientContext):
5359
5460 return report , weight_updates
5561
56- def _prune_updates (self , context , previous_weights , new_weights ):
62+ def _prune_updates (
63+ self ,
64+ context : ClientContext ,
65+ previous_weights : Mapping [str , Any ],
66+ new_weights : Mapping [str , Any ],
67+ ):
5768 algorithm = context .algorithm
69+ if not isinstance (algorithm , FedSawAlgorithm ):
70+ raise RuntimeError ("FedSaw algorithm required to prune weight updates." )
71+
5872 updates = algorithm .compute_weight_updates (previous_weights , new_weights )
5973
6074 pruning_method = (
6175 "random"
6276 if getattr (Config ().clients , "pruning_method" , None ) == "random"
6377 else "l1"
6478 )
65- pruning_amount = getattr (context .owner , "pruning_amount" , None )
79+ owner = context .owner
80+ pruning_amount : float | int | None = None
81+ if isinstance (owner , FedSawClient ):
82+ pruning_amount = owner .pruning_amount
83+
6684 if pruning_amount is None :
6785 state = FedSawClientLifecycleStrategy ._state (context )
68- pruning_amount = state .get ("pruning_amount" , 0 )
86+ stored_amount = state .get ("pruning_amount" , 0 )
87+ pruning_amount = stored_amount if isinstance (stored_amount , (int , float )) else 0
6988
7089 return algorithm .prune_weight_updates (
7190 updates , amount = pruning_amount , method = pruning_method
7291 )
7392
7493
94+ class FedSawClient (simple .Client ):
95+ """Client implementation that tracks pruning metadata for FedSaw."""
96+
97+ def __init__ (
98+ self ,
99+ model = None ,
100+ datasource = None ,
101+ algorithm = None ,
102+ trainer = None ,
103+ callbacks = None ,
104+ trainer_callbacks = None ,
105+ ):
106+ super ().__init__ (
107+ model = model ,
108+ datasource = datasource ,
109+ algorithm = algorithm or FedSawAlgorithm ,
110+ trainer = trainer ,
111+ callbacks = callbacks ,
112+ trainer_callbacks = trainer_callbacks ,
113+ )
114+ self .pruning_amount : float = 0.0
115+
116+ self ._configure_composable (
117+ lifecycle_strategy = FedSawClientLifecycleStrategy (),
118+ payload_strategy = self .payload_strategy ,
119+ training_strategy = FedSawTrainingStrategy (),
120+ reporting_strategy = self .reporting_strategy ,
121+ communication_strategy = self .communication_strategy ,
122+ )
123+
124+
75125def create_client (
76126 * ,
77127 model = None ,
@@ -80,28 +130,17 @@ def create_client(
80130 trainer = None ,
81131 callbacks = None ,
82132 trainer_callbacks = None ,
83- ):
133+ ) -> FedSawClient :
84134 """Build a FedSaw client that prunes its updates before reporting."""
85- client = simple . Client (
135+ return FedSawClient (
86136 model = model ,
87137 datasource = datasource ,
88- algorithm = algorithm or FedSawAlgorithm ,
138+ algorithm = algorithm ,
89139 trainer = trainer ,
90140 callbacks = callbacks ,
91141 trainer_callbacks = trainer_callbacks ,
92142 )
93- client .pruning_amount = 0
94-
95- client ._configure_composable (
96- lifecycle_strategy = FedSawClientLifecycleStrategy (),
97- payload_strategy = client .payload_strategy ,
98- training_strategy = FedSawTrainingStrategy (),
99- reporting_strategy = client .reporting_strategy ,
100- communication_strategy = client .communication_strategy ,
101- )
102-
103- return client
104143
105144
106145# Maintain compatibility for imports expecting a Client callable/class.
107- Client = create_client
146+ Client = FedSawClient
0 commit comments