diff --git a/federatedscope/core/communication.py b/federatedscope/core/communication.py
index 80dbfe078..7c71823b4 100644
--- a/federatedscope/core/communication.py
+++ b/federatedscope/core/communication.py
@@ -5,6 +5,7 @@
 
 from collections import deque
 
+from federatedscope.core.monitors.monitor import Monitor
 from federatedscope.core.proto import gRPC_comm_manager_pb2, \
     gRPC_comm_manager_pb2_grpc
 from federatedscope.core.gRPC_server import gRPCComServeFunc
@@ -44,6 +45,7 @@ def get_neighbors(self, neighbor_id=None):
             # Get all neighbors
             return self.neighbors
 
+    @Monitor.efficiency_comp_message_send_time
     def send(self, message):
         # All the workers share one comm_queue
         self.comm_queue.append(message)
@@ -105,7 +107,12 @@ class gRPCCommManager(object):
     The implementation of gRPCCommManager is referred to the tutorial on
     https://grpc.io/docs/languages/python/
     """
-    def __init__(self, host='0.0.0.0', port='50050', client_num=2, cfg=None):
+    def __init__(self,
+                 host='0.0.0.0',
+                 port='50050',
+                 client_num=2,
+                 cfg=None,
+                 monitor=None):
         self.host = host
         self.port = port
         options = [
@@ -128,7 +135,8 @@ def __init__(self, host='0.0.0.0', port='50050', client_num=2, cfg=None):
             port=port,
             options=options)
         self.neighbors = dict()
-        self.monitor = None  # used to track the communication related metrics
+        self.monitor = monitor
+        # used to track the communication-related metrics
 
     def serve(self, max_workers, host, port, options):
         """
@@ -169,6 +177,7 @@ def get_neighbors(self, neighbor_id=None):
             # Get all neighbors
             return self.neighbors
 
+    @Monitor.efficiency_comp_message_send_time
     def _send(self, receiver_address, message):
         def _create_stub(receiver_address):
             """
diff --git a/federatedscope/core/configs/cfg_evaluation.py b/federatedscope/core/configs/cfg_evaluation.py
index df1a1ed97..3a77f3ca8 100644
--- a/federatedscope/core/configs/cfg_evaluation.py
+++ b/federatedscope/core/configs/cfg_evaluation.py
@@ -31,6 +31,14 @@ def extend_evaluation_cfg(cfg):
     cfg.wandb.online_track = True
     cfg.wandb.client_train_info = False
 
+    # ---------------------------------------------------------------------- #
+    # efficiency-related options (currently only supported in FS-LLM)        #
+    # ---------------------------------------------------------------------- #
+
+    cfg.eval.efficiency = CN()
+    cfg.eval.efficiency.use = False
+    cfg.eval.efficiency.freq = 1
+
     # --------------- register corresponding check function ----------
     cfg.register_cfg_check_fun(assert_evaluation_cfg)
 
diff --git a/federatedscope/core/message.py b/federatedscope/core/message.py
index 93f4bc54b..66a0bd1fe 100644
--- a/federatedscope/core/message.py
+++ b/federatedscope/core/message.py
@@ -3,6 +3,9 @@
 
 from federatedscope.core.auxiliaries.utils import b64serializer
 from federatedscope.core.proto import gRPC_comm_manager_pb2
+from federatedscope.core.monitors.monitor import Monitor
+from federatedscope.core.compression import (symmetric_uniform_quantization,
+                                             symmetric_uniform_dequantization)
 
 
 class Message(object):
@@ -263,3 +266,57 @@ def count_bytes(self):
                                 list) else 1
         upload_bytes = download_bytes * upload_cnt
         return download_bytes, upload_bytes
+
+    @Monitor.efficiency_comp_message_compression_time
+    def quantization(content,
+                     role=None,
+                     model_num=None,
+                     msg_type=None,
+                     flag=None,
+                     method=None,
+                     nbits=None,
+                     monitor=None):
+        """Quantize message content before sending if quantization is on."""
+        if role == 'server':
+            if msg_type == 'model_para' and flag and method == 'uniform':
+                if model_num > 1:
+                    content = [
+                        symmetric_uniform_quantization(x, nbits)
+                        for x in content
+                    ]
+                else:
+                    content = symmetric_uniform_quantization(content, nbits)
+        elif role == 'client':
+            if method == 'uniform':
+                if isinstance(content, list):
+                    content = [
+                        symmetric_uniform_quantization(x, nbits)
+                        for x in content
+                    ]
+                else:
+                    content = symmetric_uniform_quantization(content, nbits)
+        return content
+
+    @Monitor.efficiency_comp_message_compression_time
+    def dequantization(content, role=None, method=None, monitor=None):
+        """Dequantize received message content if quantization is on."""
+        if role == 'server':
+            if method == 'uniform':
+                if isinstance(content[1], list):  # multiple models
+                    sample_size = content[0]
+                    quant_model = [
+                        symmetric_uniform_dequantization(x) for x in content[1]
+                    ]
+                else:
+                    sample_size = content[0]
+                    quant_model = symmetric_uniform_dequantization(content[1])
+                content = (sample_size, quant_model)
+        elif role == 'client':
+            if method == 'uniform':
+                if isinstance(content, list):  # multiple models
+                    content = [
+                        symmetric_uniform_dequantization(x) for x in content
+                    ]
+                else:
+                    content = symmetric_uniform_dequantization(content)
+        return content
diff --git a/federatedscope/core/monitors/monitor.py b/federatedscope/core/monitors/monitor.py
index 671541bac..9b1685174 100644
--- a/federatedscope/core/monitors/monitor.py
+++ b/federatedscope/core/monitors/monitor.py
@@ -5,6 +5,9 @@
 import gzip
 import shutil
 import datetime
+import sys
+import time
+import psutil
 
 from collections import defaultdict
 from importlib import import_module
@@ -110,6 +113,131 @@ def __init__(self, cfg, monitored_object=None):
                 "cfg.wandb.use=True but not install the wandb package")
             exit()
 
+        self.efficiency_memory = 0
+        self.efficiency_gpu = 0
+
+        self.efficiency_round_start_time = 0
+        self.efficiency_round_training_time = 0
+        self.efficiency_total_training_time = 0
+
+        self.efficiency_message_compression_time = 0
+        self.efficiency_message_send_time = 0
+        self.efficiency_total_message_compression_time = 0
+        self.efficiency_total_message_send_time = 0
+
+    def efficiency_compare(func):
+        """
+        Decorate trainer hooks to record peak memory and GPU usage
+        """
+        def wrapper(*args, **kwargs):
+            if args[-1].monitor.cfg.eval.efficiency.use:
+                res = func(*args, **kwargs)
+                efficiency_memory = round(
+                    psutil.Process(os.getpid()).memory_info().rss / 1024 /
+                    1024 / 1024, 2)
+                efficiency_gpu = torch.cuda.max_memory_allocated(
+                    0) / 1024 / 1024 / 1024
+
+                args[-1].monitor.efficiency_memory = max(
+                    args[-1].monitor.efficiency_memory, efficiency_memory)
+                args[-1].monitor.efficiency_gpu = max(
+                    args[-1].monitor.efficiency_gpu, efficiency_gpu)
+                return res
+            else:
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    def efficiency_training_start_time(func):
+        """
+        Decorate the start hook in the trainer
+        to record the starting time of training
+        """
+        def wrapper(*args, **kwargs):
+            if args[-1].monitor.cfg.eval.efficiency.use:
+                args[-1].monitor.efficiency_round_start_time = time.time()
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    def efficiency_training_end_time(func):
+        """
+        Decorate the end hook in the trainer to record the end time of
+        training and to accumulate the total training time
+        """
+        def wrapper(*args, **kwargs):
+            if args[-1].monitor.cfg.eval.efficiency.use:
+                res = func(*args, **kwargs)
+                args[-1].monitor.efficiency_round_training_time = time.time(
+                ) - args[-1].monitor.efficiency_round_start_time
+                args[-1].monitor.efficiency_total_training_time += args[
+                    -1].monitor.efficiency_round_training_time
+                return res
+            else:
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    def efficiency_comp_message_compression_time(func):
+        """
+        Decorate the message-compression functions in message.py
+        to record the time spent on message compression
+        """
+        def wrapper(*args, **kwargs):
+            monitor = kwargs.get('monitor', None)
+            if monitor is not None and monitor.cfg.eval.efficiency.use:
+                start = time.time()
+                res = func(*args, **kwargs)
+                monitor.efficiency_total_message_compression_time += \
+                    time.time() - start
+                return res
+            else:
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    def efficiency_comp_message_send_time(func):
+        """
+        Decorate message-sending functions in communication.py
+        to record the time spent on message sending.
+        Note: in standalone mode, sending is simulated under a fixed
+        assumed bandwidth, so the sending time is estimated as
+        sys.getsizeof(message) / 1024 / 100 seconds.
+        """
+        def wrapper(*args, **kwargs):
+            monitor = args[0].monitor
+            if monitor is None or not monitor.cfg.eval.efficiency.use:
+                return func(*args, **kwargs)
+            if monitor.cfg.federate.mode == 'standalone':
+                # estimated sending time, see the note above
+                monitor.efficiency_total_message_send_time += \
+                    sys.getsizeof(args[1]) / 1024 / 100
+                return func(*args, **kwargs)
+            else:  # distributed mode: measure the real sending time
+                start_time = time.time()
+                res = func(*args, **kwargs)
+                monitor.efficiency_total_message_send_time += \
+                    time.time() - start_time
+                return res
+
+        return wrapper
+
+    def format_efficiency_result(self, rnd, role=-1):
+        res_dict = dict()
+        res_dict['Role'] = role
+        res_dict['Round'] = rnd
+        res_dict['Round_training_time'] = str(
+            self.efficiency_round_training_time) + ' S'
+        res_dict['Total_training_time'] = str(
+            self.efficiency_total_training_time) + ' S'
+        res_dict['Total_message_compression_time'] = str(
+            self.efficiency_total_message_compression_time) + ' S'
+        res_dict['Total_message_send_time'] = str(
+            self.efficiency_total_message_send_time) + ' S'
+        res_dict['Max memory usage'] = str(self.efficiency_memory) + ' GB'
+        res_dict['Max GPU usage'] = str(self.efficiency_gpu) + ' GB'
+        return res_dict
+
     def eval(self, ctx):
         """
         Evaluates the given context with ``metric_calculator``.
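For intuition about the standalone-mode estimate in `efficiency_comp_message_send_time`: the simulated sending time is a pure function of the object size reported by `sys.getsizeof`. A small worked example (the payload is a hypothetical stand-in for a real `Message`; note that `sys.getsizeof` reports only the shallow size of an object, so nested model weights would not be counted):

```python
import sys

# Hypothetical ~1 MB payload standing in for a serialized message.
message = b'x' * 1_000_000

# The simulated sending time used in standalone mode, in seconds.
send_time = sys.getsizeof(message) / 1024 / 100
print(round(send_time, 2))  # ~9.77
```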
diff --git a/federatedscope/core/workers/client.py b/federatedscope/core/workers/client.py
index 1be53984b..53882257c 100644
--- a/federatedscope/core/workers/client.py
+++ b/federatedscope/core/workers/client.py
@@ -12,6 +12,7 @@
 from federatedscope.core.auxiliaries.utils import merge_dict_of_results, \
     calculate_time_cost, add_prefix_to_path, get_ds_rank
 from federatedscope.core.workers.base_client import BaseClient
+from federatedscope.core.monitors.monitor import Monitor
 
 logger = logging.getLogger(__name__)
 if get_ds_rank() == 0:
@@ -175,7 +176,8 @@ def __init__(self,
                 host=host,
                 port=port,
                 client_num=self._cfg.federate.client_num,
-                cfg=self._cfg.distribute)
+                cfg=self._cfg.distribute,
+                monitor=self._monitor)
             logger.info('Client: Listen to {}:{}...'.format(host, port))
             self.comm_manager.add_neighbors(neighbor_id=server_id,
                                             address={
@@ -303,16 +305,12 @@ def callback_funcs_for_model_para(self, message: Message):
             timestamp = message.timestamp
             content = message.content
 
             # dequantization
-            if self._cfg.quantization.method == 'uniform':
-                from federatedscope.core.compression import \
-                    symmetric_uniform_dequantization
-                if isinstance(content, list):  # multiple models
-                    content = [
-                        symmetric_uniform_dequantization(x) for x in content
-                    ]
-                else:
-                    content = symmetric_uniform_dequantization(content)
+            content = Message.dequantization(
+                content=content,
+                role='client',
+                method=self._cfg.quantization.method,
+                monitor=self._monitor)
 
             # When clients share the local model, we must set strict=True to
             # ensure all the model params (which might be updated by other
@@ -417,19 +415,13 @@ def callback_funcs_for_model_para(self, message: Message):
             else:
                 shared_model_para = model_para_all
 
             # quantization
-            if self._cfg.quantization.method == 'uniform':
-                from federatedscope.core.compression import \
-                    symmetric_uniform_quantization
-                nbits = self._cfg.quantization.nbits
-                if isinstance(shared_model_para, list):
-                    shared_model_para = [
-                        symmetric_uniform_quantization(x, nbits)
-                        for x in shared_model_para
-                    ]
-                else:
-                    shared_model_para = symmetric_uniform_quantization(
-                        shared_model_para, nbits)
+            shared_model_para = Message.quantization(
+                content=shared_model_para,
+                role='client',
+                method=self._cfg.quantization.method,
+                nbits=self._cfg.quantization.nbits,
+                monitor=self._monitor)
 
             self.comm_manager.send(
                 Message(msg_type='model_para',
@@ -440,6 +432,13 @@
                             init_timestamp=timestamp,
                             instance_number=sample_size),
                         content=(sample_size, shared_model_para)))
+            if (self._cfg.eval.efficiency.use
+                    and ((self._cfg.eval.efficiency.freq > 0 and self.state %
+                          self._cfg.eval.efficiency.freq == 0)
+                         or self.state == self._cfg.federate.total_round_num)):
+                logger.info(
+                    self._monitor.format_efficiency_result(
+                        rnd=self.state, role='Client #{}'.format(self.ID)))
 
     def callback_funcs_for_assign_id(self, message: Message):
         """
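As a quick sanity check of the reporting condition added in `callback_funcs_for_model_para`: with `eval.efficiency.freq > 0`, a client logs its efficiency results on every `freq`-th round and on the final round. A sketch with illustrative values (not taken from any shipped config):

```python
# Illustrative values for the two config options involved.
freq, total_round_num = 10, 25

# Mirrors the condition above: log when the round index is a multiple
# of freq, or when the final round is reached.
logged_rounds = [
    state for state in range(total_round_num + 1)
    if (freq > 0 and state % freq == 0) or state == total_round_num
]
print(logged_rounds)  # [0, 10, 20, 25]
```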
diff --git a/federatedscope/core/workers/server.py b/federatedscope/core/workers/server.py
index c8cefb9cc..ca42918d5 100644
--- a/federatedscope/core/workers/server.py
+++ b/federatedscope/core/workers/server.py
@@ -218,7 +218,8 @@ def __init__(self,
             self.comm_manager = gRPCCommManager(host=host,
                                                 port=port,
                                                 client_num=client_num,
-                                                cfg=self._cfg.distribute)
+                                                cfg=self._cfg.distribute,
+                                                monitor=self._monitor)
             logger.info('Server: Listen to {}:{}...'.format(host, port))
 
         # inject noise before broadcast
@@ -713,19 +714,15 @@ def broadcast_model_para(self,
             model_para = {} if skip_broadcast else self.models[
                 0].state_dict()
 
         # quantization
-        if msg_type == 'model_para' and not skip_broadcast and \
-                self._cfg.quantization.method == 'uniform':
-            from federatedscope.core.compression import \
-                symmetric_uniform_quantization
-            nbits = self._cfg.quantization.nbits
-            if self.model_num > 1:
-                model_para = [
-                    symmetric_uniform_quantization(x, nbits)
-                    for x in model_para
-                ]
-            else:
-                model_para = symmetric_uniform_quantization(model_para, nbits)
+        model_para = Message.quantization(content=model_para,
+                                          model_num=self.model_num,
+                                          role='server',
+                                          msg_type=msg_type,
+                                          flag=not skip_broadcast,
+                                          method=self._cfg.quantization.method,
+                                          nbits=self._cfg.quantization.nbits,
+                                          monitor=self._monitor)
 
         # We define the evaluation happens at the end of an epoch
         rnd = self.state - 1 if msg_type == 'evaluate' else self.state
@@ -983,19 +980,11 @@ def callback_funcs_model_para(self, message: Message):
         content = message.content
         self.sampler.change_state(sender, 'idle')
 
         # dequantization
-        if self._cfg.quantization.method == 'uniform':
-            from federatedscope.core.compression import \
-                symmetric_uniform_dequantization
-            if isinstance(content[1], list):  # multiple models
-                sample_size = content[0]
-                quant_model = [
-                    symmetric_uniform_dequantization(x) for x in content[1]
-                ]
-            else:
-                sample_size = content[0]
-                quant_model = symmetric_uniform_dequantization(content[1])
-            content = (sample_size, quant_model)
+        content = Message.dequantization(content=content,
+                                         role='server',
+                                         method=self._cfg.quantization.method,
+                                         monitor=self._monitor)
 
         # update the currency timestamp according to the received message
         assert timestamp >= self.cur_timestamp  # for test
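The server-side call sites reduce to a quantize/dequantize round trip over a state_dict. A minimal sketch of that round trip (assuming torch and FederatedScope are installed, and that `symmetric_uniform_quantization` accepts a dict of tensors plus an `nbits` argument, as the call sites suggest; quantization is lossy, so only approximate reconstruction is expected):

```python
import torch
from federatedscope.core.compression import (
    symmetric_uniform_quantization, symmetric_uniform_dequantization)

# A toy state_dict standing in for self.models[0].state_dict().
state = {'w': torch.randn(8, 8), 'b': torch.randn(8)}

quantized = symmetric_uniform_quantization(state, 8)  # nbits=8
restored = symmetric_uniform_dequantization(quantized)

# Report the worst-case reconstruction error across parameters.
print(max((state[k] - restored[k]).abs().max().item() for k in state))
```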
diff --git a/federatedscope/llm/trainer/trainer.py b/federatedscope/llm/trainer/trainer.py
index 0470f80d7..fe2187006 100644
--- a/federatedscope/llm/trainer/trainer.py
+++ b/federatedscope/llm/trainer/trainer.py
@@ -24,6 +24,8 @@ def _hook_on_fit_start_numerical_precision(self, ctx):
         if not ctx.cfg.llm.deepspeed.use:
             ctx.model = ctx.model.half()
 
+    @Monitor.efficiency_compare
+    @Monitor.efficiency_training_start_time
     def _hook_on_fit_start_init(self, ctx):
         if ctx.cfg.llm.deepspeed.use:
             # Enable deepspeed
@@ -60,6 +62,7 @@ def _hook_on_fit_start_init(self, ctx):
         ctx.ys_true = CtxVar([], LIFECYCLE.ROUTINE)
         ctx.ys_prob = CtxVar([], LIFECYCLE.ROUTINE)
 
+    @Monitor.efficiency_compare
     def _hook_on_batch_forward(self, ctx):
         input_ids = ctx.data_batch['input_ids'].to(ctx.device)
         labels = ctx.data_batch['labels'].to(ctx.device)
@@ -91,6 +94,7 @@ def _hook_on_batch_forward(self, ctx):
         ctx.loss_batch = CtxVar(loss, LIFECYCLE.BATCH)
         ctx.batch_size = CtxVar(len(labels), LIFECYCLE.BATCH)
 
+    @Monitor.efficiency_compare
     def _hook_on_batch_backward(self, ctx):
         if ctx.skip_this_batch:
             return
@@ -110,6 +114,7 @@ def _hook_on_batch_backward(self, ctx):
         if ctx.scheduler is not None:
             ctx.scheduler.step()
 
+    @Monitor.efficiency_compare
    def _hook_on_batch_end(self, ctx):
         if ctx.skip_this_batch:
             if ctx.cfg.llm.retry_on_nan_loss:
@@ -124,6 +129,8 @@ def _hook_on_batch_end(self, ctx):
         ctx.loss_batch_total += ctx.loss_batch.item() * ctx.batch_size
         ctx.loss_regular_total += float(ctx.get("loss_regular", 0.))
 
+    @Monitor.efficiency_compare
+    @Monitor.efficiency_training_end_time
     def _hook_on_fit_end(self, ctx):
         avg_loss = 0 if float(
             ctx.num_samples) == 0 else ctx.loss_batch_total / float(
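Finally, the new switches are ordinary config options, so they can be enabled programmatically as well as through the usual YAML/CLI overrides. A minimal sketch (assumes a FederatedScope installation; the freq value is illustrative):

```python
from federatedscope.core.configs.config import global_cfg

cfg = global_cfg.clone()
cfg.eval.efficiency.use = True  # turn the efficiency monitor on
cfg.eval.efficiency.freq = 10   # report every 10th round and the last round
print(cfg.eval.efficiency)
```

When enabled, each reporting round logs the dict built by `format_efficiency_result`: per-round and total training time, total message compression and send time, and peak memory/GPU usage.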