-
Notifications
You must be signed in to change notification settings - Fork 252
Expand file tree
/
Copy pathclient.py
More file actions
537 lines (485 loc) · 23.4 KB
/
client.py
File metadata and controls
537 lines (485 loc) · 23.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
import copy
import logging
import sys
import pickle
from federatedscope.core.auxiliaries.enums import STAGE
from federatedscope.core.message import Message
from federatedscope.core.communication import StandaloneCommManager, \
gRPCCommManager
from federatedscope.core.monitors.early_stopper import EarlyStopper
from federatedscope.core.workers import Worker
from federatedscope.core.auxiliaries.trainer_builder import get_trainer
from federatedscope.core.secret_sharing import AdditiveSecretSharing
from federatedscope.core.auxiliaries.utils import merge_dict, \
calculate_time_cost
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class Client(Worker):
    """
    The Client class, which describes the behaviors of a client in an FL
    course. The behaviors are described by the handling functions (named
    as callback_funcs_for_xxx), which are bound to message types in
    ``_register_default_handlers``.

    Arguments:
        ID: The unique ID of the client, which is assigned by the server
            when joining the FL course
        server_id: The ID of the server (default: None)
        state: The training round
        config: The configuration
        data: The data owned by the client
        model: The model maintained locally
        device: The device to run local training and evaluation
        strategy: redundant attribute
        is_unseen_client: If True, the client does not train locally; when
            it receives model parameters it replies with its current model
            parameters and a sample size of 0 (default: False)
        **kwargs: Optional ``resource_info`` (a dict with 'computation' and
            'communication' entries); ``shared_comm_queue`` in standalone
            mode; ``host``, ``port``, ``server_host`` and ``server_port`` in
            distributed mode
    """
    def __init__(self,
                 ID=-1,
                 server_id=None,
                 state=-1,
                 config=None,
                 data=None,
                 model=None,
                 device='cpu',
                 strategy=None,
                 is_unseen_client=False,
                 *args,
                 **kwargs):
        super(Client, self).__init__(ID, state, config, model, strategy)

        # The unseen_client flag indicates whether this client contributes
        # to the FL process by training on its local data and uploading the
        # local model update, which is useful for checking the participation
        # generalization gap in
        # [ICLR'22, What Do We Mean by Generalization in Federated Learning?]
        self.is_unseen_client = is_unseen_client

        # Attacks are only supported in standalone mode. A client is an
        # attacker if its ID equals config.attack.attacker_id and
        # config.attack.attack_method is provided (non-empty).
        self.is_attacker = config.attack.attacker_id == ID and \
            config.attack.attack_method != '' and \
            config.federate.mode == 'standalone'

        # Build the trainer; it might need configurations other than those
        # of the trainer node, so it receives the full config.
        self.trainer = get_trainer(model=model,
                                   data=data,
                                   device=device,
                                   config=self._cfg,
                                   is_attacker=self.is_attacker,
                                   monitor=self._monitor)

        # For client-side evaluation
        self.best_results = dict()
        self.history_results = dict()
        # In "local" or "global" training mode we use the early stopper.
        # Otherwise, we set patience=0 to deactivate the local early-stopper.
        patience = self._cfg.early_stop.patience if \
            self._cfg.federate.method in ["local", "global"] else 0
        self.early_stopper = EarlyStopper(
            patience, self._cfg.early_stop.delta,
            self._cfg.early_stop.improve_indicator_mode,
            self._cfg.early_stop.the_smaller_the_better)

        # Secret sharing manager and message buffer
        self.ss_manager = AdditiveSecretSharing(
            shared_party_num=int(self._cfg.federate.sample_client_num
                                 )) if self._cfg.federate.use_ss else None
        self.msg_buffer = {'train': dict(), 'eval': dict()}

        # Register message handlers
        self.msg_handlers = dict()
        self._register_default_handlers()

        # Computation and communication ability
        if 'resource_info' in kwargs and kwargs['resource_info'] is not None:
            self.comp_speed = float(
                kwargs['resource_info']['computation']) / 1000.  # (s/sample)
            self.comm_bandwidth = float(
                kwargs['resource_info']['communication'])  # (kbit/s)
        else:
            self.comp_speed = None
            self.comm_bandwidth = None
        # Size of the pickled model, used to simulate communication cost.
        self.model_size = sys.getsizeof(pickle.dumps(
            self.model)) / 1024.0 * 8.  # kbits

        # Initialize communication manager
        self.server_id = server_id
        if self.mode == 'standalone':
            comm_queue = kwargs['shared_comm_queue']
            self.comm_manager = StandaloneCommManager(comm_queue=comm_queue,
                                                      monitor=self._monitor)
            self.local_address = None
        elif self.mode == 'distributed':
            host = kwargs['host']
            port = kwargs['port']
            server_host = kwargs['server_host']
            server_port = kwargs['server_port']
            self.comm_manager = gRPCCommManager(
                host=host, port=port, client_num=self._cfg.federate.client_num)
            logger.info('Client: Listen to {}:{}...'.format(host, port))
            self.comm_manager.add_neighbors(neighbor_id=server_id,
                                            address={
                                                'host': server_host,
                                                'port': server_port
                                            })
            self.local_address = {
                'host': self.comm_manager.host,
                'port': self.comm_manager.port
            }

    def _gen_timestamp(self, init_timestamp, instance_number):
        """
        Compute the simulated timestamp at which a reply would be sent.

        Arguments:
            init_timestamp: The timestamp of the triggering message, or None
            instance_number: The number of processed instances, passed to
                ``calculate_time_cost`` to estimate computation cost

        Returns:
            None if ``init_timestamp`` is None; otherwise ``init_timestamp``
            plus the computation and communication costs returned by
            ``calculate_time_cost``.
        """
        if init_timestamp is None:
            return None

        comp_cost, comm_cost = calculate_time_cost(
            instance_number=instance_number,
            comm_size=self.model_size,
            comp_speed=self.comp_speed,
            comm_bandwidth=self.comm_bandwidth)
        return init_timestamp + comp_cost + comm_cost

    def _calculate_model_delta(self, init_model, updated_model):
        """
        Compute ``updated_model - init_model`` key by key.

        Arguments:
            init_model: A state-dict-like mapping, or a list of them (one
                per sub-model)
            updated_model: The updated counterpart, with the same structure

        Returns:
            A list of deltas when more than one model is given; otherwise a
            single delta mapping (also for a one-element list input).
        """
        if not isinstance(init_model, list):
            init_model = [init_model]
            updated_model = [updated_model]

        model_deltas = list()
        for init_para, updated_para in zip(init_model, updated_model):
            # Deep-copy so the delta has the same container type as the
            # initial parameters without mutating them.
            model_delta = copy.deepcopy(init_para)
            for key in init_para.keys():
                model_delta[key] = updated_para[key] - init_para[key]
            model_deltas.append(model_delta)

        if len(model_deltas) > 1:
            return model_deltas
        else:
            return model_deltas[0]

    def register_handlers(self, msg_type, callback_func):
        """
        To bind a message type with a handling function.

        Arguments:
            msg_type (str): The defined message type
            callback_func: The handling functions to handle the received
                message
        """
        self.msg_handlers[msg_type] = callback_func

    def _register_default_handlers(self):
        """Bind the built-in message types to their handling functions."""
        self.register_handlers('assign_client_id',
                               self.callback_funcs_for_assign_id)
        self.register_handlers('ask_for_join_in_info',
                               self.callback_funcs_for_join_in_info)
        self.register_handlers('address', self.callback_funcs_for_address)
        self.register_handlers('model_para',
                               self.callback_funcs_for_model_para)
        # Secret-sharing fragments are handled by the same callback, which
        # dispatches on the message type.
        self.register_handlers('ss_model_para',
                               self.callback_funcs_for_model_para)
        self.register_handlers('evaluate', self.callback_funcs_for_evaluate)
        self.register_handlers('finish', self.callback_funcs_for_finish)
        self.register_handlers('converged', self.callback_funcs_for_converged)

    def join_in(self):
        """
        To send 'join_in' message to the server for joining in the FL course.
        """
        self.comm_manager.send(
            Message(msg_type='join_in',
                    sender=self.ID,
                    receiver=[self.server_id],
                    timestamp=0,
                    content=self.local_address))

    def run(self):
        """
        To listen to the message and handle them accordingly (used for
        distributed mode). Messages whose state is older than the client's
        current state are ignored; the loop ends after a 'finish' message.
        """
        while True:
            msg = self.comm_manager.receive()
            if self.state <= msg.state:
                self.msg_handlers[msg.msg_type](msg)

            if msg.msg_type == 'finish':
                break

    def callback_funcs_for_model_para(self, message: Message):
        """
        The handling function for receiving model parameters,
        which triggers the local training process.
        This handling function is widely used in various FL courses.

        For 'ss_model_para' messages, the content is a secret-sharing
        fragment: fragments are buffered per round and, once
        ``cfg.federate.client_num`` of them are collected, summed and sent
        to the server as a 'model_para' message.

        Arguments:
            message: The received message, which includes sender, receiver,
                state, and content.
                More detail can be found in federatedscope.core.message
        """
        if 'ss' in message.msg_type:
            # A fragment of the shared secret
            state, content, timestamp = message.state, message.content, \
                message.timestamp
            self.msg_buffer[STAGE.TRAIN][state].append(content)

            # Check whether the received fragments are enough
            if len(self.msg_buffer[STAGE.TRAIN]
                   [state]) == self._cfg.federate.client_num:
                model_list = self.msg_buffer[STAGE.TRAIN][state]
                # The first entry is this client's own
                # (sample_size, fragment); the rest are bare fragments.
                sample_size, first_aggregate_model_para = model_list[0]
                single_model_case = True
                if isinstance(first_aggregate_model_para, list):
                    assert isinstance(first_aggregate_model_para[0], dict), \
                        "aggregate_model_para should a list of multiple " \
                        "state_dict for multiple models"
                    single_model_case = False
                else:
                    assert isinstance(first_aggregate_model_para, dict), \
                        "aggregate_model_para should " \
                        "a state_dict for single model case"
                    # Wrap so both cases can be indexed by sub-model.
                    first_aggregate_model_para = [first_aggregate_model_para]
                    model_list = [[model] for model in model_list]

                for sub_model_idx, aggregate_single_model_para in enumerate(
                        first_aggregate_model_para):
                    for key in aggregate_single_model_para:
                        for i in range(1, len(model_list)):
                            aggregate_single_model_para[key] += model_list[i][
                                sub_model_idx][key]

                self.comm_manager.send(
                    Message(msg_type='model_para',
                            sender=self.ID,
                            receiver=[self.server_id],
                            state=self.state,
                            timestamp=timestamp,
                            content=(sample_size, first_aggregate_model_para[0]
                                     if single_model_case else
                                     first_aggregate_model_para)))

        else:
            round_idx = message.state
            sender = message.sender
            timestamp = message.timestamp
            content = message.content
            # When clients share the local model, we must set strict=True to
            # ensure all the model params (which might be updated by other
            # clients in the previous local training process) are overwritten
            # and synchronized with the received model
            self.trainer.update(content,
                                strict=self._cfg.federate.share_local_model)
            self.state = round_idx
            skip_train_isolated_or_global_mode = \
                self.early_stopper.early_stopped and \
                self._cfg.federate.method in ["local", "global"]
            if self.is_unseen_client or skip_train_isolated_or_global_mode:
                # For these cases, (1) unseen client and (2) early-stopped
                # isolated/global mode, we do not train locally and upload
                # the current model with a sample size of 0.
                sample_size, model_para_all, results = \
                    0, self.trainer.get_model_para(), {}
                if skip_train_isolated_or_global_mode:
                    logger.info(
                        f"[Local/Global mode] Client #{self.ID} has been "
                        f"early stopped, we will skip the local training")
                    self._monitor.local_converged()
            else:
                if self.early_stopper.early_stopped and \
                        self._monitor.local_convergence_round == 0:
                    logger.info(
                        f"[Normal FL Mode] Client #{self.ID} has been locally "
                        f"early stopped. "
                        f"The next FL update may result in negative effect")
                    self._monitor.local_converged()
                sample_size, model_para_all, results = self.trainer.train()
                if self._cfg.federate.share_local_model and not \
                        self._cfg.federate.online_aggr:
                    # Copy so later local training does not alter the
                    # parameters that are sent.
                    model_para_all = copy.deepcopy(model_para_all)
                train_log_res = self._monitor.format_eval_res(
                    results,
                    rnd=self.state,
                    role='Client #{}'.format(self.ID),
                    return_raw=True)
                logger.info(train_log_res)
                if self._cfg.wandb.use and self._cfg.wandb.client_train_info:
                    self._monitor.save_formatted_results(train_log_res,
                                                         save_file_name="")

            # Return the feedbacks to the server after local update
            if self._cfg.federate.use_ss:
                assert not self.is_unseen_client, \
                    "Un-support using secret sharing for unseen clients." \
                    "i.e., you set cfg.federate.use_ss=True and " \
                    "cfg.federate.unseen_clients_rate in (0, 1)"
                single_model_case = True
                if isinstance(model_para_all, list):
                    assert isinstance(model_para_all[0], dict), \
                        "model_para should a list of " \
                        "multiple state_dict for multiple models"
                    single_model_case = False
                else:
                    assert isinstance(model_para_all, dict), \
                        "model_para should a state_dict for single model case"
                    model_para_all = [model_para_all]

                # Weight each parameter by the local sample size, then split
                # it into secret-sharing fragments.
                model_para_list_all = []
                for model_para in model_para_all:
                    for key in model_para:
                        model_para[key] = model_para[key] * sample_size
                    model_para_list = self.ss_manager.secret_split(model_para)
                    model_para_list_all.append(model_para_list)

                # Send one fragment to every neighbor other than the server.
                frame_idx = 0
                for neighbor in self.comm_manager.neighbors:
                    if neighbor != self.server_id:
                        content_frame = model_para_list_all[0][frame_idx] if \
                            single_model_case else \
                            [model_para_list[frame_idx] for model_para_list
                             in model_para_list_all]

                        self.comm_manager.send(
                            Message(msg_type='ss_model_para',
                                    sender=self.ID,
                                    receiver=[neighbor],
                                    state=self.state,
                                    timestamp=self._gen_timestamp(
                                        init_timestamp=timestamp,
                                        instance_number=sample_size),
                                    content=content_frame))
                        frame_idx += 1

                # Keep the remaining fragment locally, together with the
                # sample size, as the first buffered entry for this round.
                content_frame = model_para_list_all[0][frame_idx] if \
                    single_model_case else \
                    [model_para_list[frame_idx] for model_para_list in
                     model_para_list_all]
                self.msg_buffer[STAGE.TRAIN][self.state] = [(sample_size,
                                                             content_frame)]
            else:
                if self._cfg.asyn.use:
                    # Return the model delta when using asynchronous training
                    # protocol, because the staled updated might be discounted
                    # and cause that the sum of the aggregated weights might
                    # not be equal to 1
                    shared_model_para = self._calculate_model_delta(
                        init_model=content, updated_model=model_para_all)
                else:
                    shared_model_para = model_para_all

                self.comm_manager.send(
                    Message(msg_type='model_para',
                            sender=self.ID,
                            receiver=[sender],
                            state=self.state,
                            timestamp=self._gen_timestamp(
                                init_timestamp=timestamp,
                                instance_number=sample_size),
                            content=(sample_size, shared_model_para)))

    def callback_funcs_for_assign_id(self, message: Message):
        """
        The handling function for receiving the client_ID assigned by the
        server (during the joining process),
        which is used in the distributed mode.

        Arguments:
            message: The received message
        """
        content = message.content
        self.ID = int(content)
        logger.info('Client (address {}:{}) is assigned with #{:d}.'.format(
            self.comm_manager.host, self.comm_manager.port, self.ID))

    def callback_funcs_for_join_in_info(self, message: Message):
        """
        The handling function for receiving the request of join in information
        (such as batch_size, num_of_samples) during the joining process.

        Arguments:
            message: The received message, whose content is an iterable of
                requirement names ('num_sample' or 'client_resource')

        Raises:
            ValueError: If an unknown requirement is requested
        """
        requirements = message.content
        timestamp = message.timestamp
        join_in_info = dict()
        for requirement in requirements:
            if requirement.lower() == 'num_sample':
                # Both branches count *samples*: in batch mode
                # local_update_steps counts batches, in epoch mode it counts
                # passes over the local training data.
                if self._cfg.train.batch_or_epoch == 'batch':
                    num_sample = self._cfg.train.local_update_steps * \
                        self._cfg.data.batch_size
                else:
                    num_sample = self._cfg.train.local_update_steps * \
                        self.trainer.ctx.num_train_data
                join_in_info['num_sample'] = num_sample
                if self._cfg.trainer.type == 'nodefullbatch_trainer':
                    join_in_info['num_sample'] = \
                        self.trainer.ctx.data.x.shape[0]
            elif requirement.lower() == 'client_resource':
                assert self.comm_bandwidth is not None and self.comp_speed \
                    is not None, "The requirement join_in_info " \
                                 "'client_resource' does not exist."
                join_in_info['client_resource'] = self.model_size / \
                    self.comm_bandwidth + self.comp_speed
            else:
                raise ValueError(
                    'Fail to get the join in information with type {}'.format(
                        requirement))
        self.comm_manager.send(
            Message(msg_type='join_in_info',
                    sender=self.ID,
                    receiver=[self.server_id],
                    state=self.state,
                    timestamp=timestamp,
                    content=join_in_info))

    def callback_funcs_for_address(self, message: Message):
        """
        The handling function for receiving other clients' IP addresses,
        which is used for constructing a complex topology

        Arguments:
            message: The received message, whose content maps neighbor IDs
                to addresses; the client's own entry is skipped
        """
        content = message.content
        for neighbor_id, address in content.items():
            if int(neighbor_id) != self.ID:
                self.comm_manager.add_neighbors(neighbor_id, address)

    def callback_funcs_for_evaluate(self, message: Message):
        """
        The handling function for receiving the request of evaluating.
        Updates the local model with the received content (if any),
        evaluates it on every split in ``cfg.eval.split`` and sends the
        metrics back to the sender. If the client has been early stopped in
        "local"/"global" mode, the first stored best result is sent instead.

        Arguments:
            message: The received message
        """
        sender, timestamp = message.sender, message.timestamp
        self.state = message.state
        if message.content is not None:
            self.trainer.update(message.content,
                                strict=self._cfg.federate.share_local_model)
        if self.early_stopper.early_stopped and self._cfg.federate.method in [
                "local", "global"
        ]:
            metrics = list(self.best_results.values())[0]
        else:
            metrics = {}
            if self._cfg.finetune.before_eval:
                self.trainer.finetune()
            for split in self._cfg.eval.split:
                # TODO: The time cost of evaluation is not considered here
                eval_metrics = self.trainer.evaluate(
                    target_data_split_name=split)

                if self._cfg.federate.mode == 'distributed':
                    logger.info(
                        self._monitor.format_eval_res(eval_metrics,
                                                      rnd=self.state,
                                                      role='Client #{}'.format(
                                                          self.ID),
                                                      return_raw=True))

                metrics.update(**eval_metrics)

            formatted_eval_res = self._monitor.format_eval_res(
                metrics,
                rnd=self.state,
                role='Client #{}'.format(self.ID),
                forms='raw',
                return_raw=True)
            self._monitor.update_best_result(
                self.best_results,
                formatted_eval_res['Results_raw'],
                results_type=f"client #{self.ID}",
                round_wise_update_key=self._cfg.eval.
                best_res_update_round_wise_key)
            self.history_results = merge_dict(
                self.history_results, formatted_eval_res['Results_raw'])
            self.early_stopper.track_and_check(self.history_results[
                self._cfg.eval.best_res_update_round_wise_key])

        self.comm_manager.send(
            Message(msg_type='metrics',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    timestamp=timestamp,
                    content=metrics))

    def callback_funcs_for_finish(self, message: Message):
        """
        The handling function for receiving the signal of finishing the FL
        course. Loads the final model (if provided) and notifies the monitor.

        Arguments:
            message: The received message
        """
        logger.info(
            f"================= client {self.ID} received finish message "
            f"=================")

        if message.content is not None:
            self.trainer.update(message.content,
                                strict=self._cfg.federate.share_local_model)

        self._monitor.finish_fl()

    def callback_funcs_for_converged(self, message: Message):
        """
        The handling function for receiving the signal that the FL course
        converged

        Arguments:
            message: The received message
        """
        self._monitor.global_converged()