From 1fdfb5d71af0e8b4d8ad1de79b6fbb91ff74e64c Mon Sep 17 00:00:00 2001
From: Harli Wu
Date: Thu, 29 Jun 2023 06:55:30 +0800
Subject: [PATCH 1/5] Support flops calculation on LLM

---
 federatedscope/llm/model/adapter_builder.py |  4 +-
 federatedscope/llm/trainer/trainer.py       | 62 +++++++++++++++++++++
 2 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/federatedscope/llm/model/adapter_builder.py b/federatedscope/llm/model/adapter_builder.py
index b81624fdb..52157dc34 100644
--- a/federatedscope/llm/model/adapter_builder.py
+++ b/federatedscope/llm/model/adapter_builder.py
@@ -134,11 +134,11 @@ def forward(self, *args, **kwargs):
     def generate(self, *args, **kwargs):
         return self.model.generate(*args, **kwargs)
 
-    def state_dict(self, return_trainable=True):
+    def state_dict(self, return_trainable=True, *args, **kwargs):
         if return_trainable:
             return self.get_trainable_state_dict()
         else:
-            return self.model.state_dict()
+            return self.model.state_dict(*args, **kwargs)
 
     def load_state_dict(self, state_dict, strict=False):
         return self.model.load_state_dict(state_dict, strict=False)
diff --git a/federatedscope/llm/trainer/trainer.py b/federatedscope/llm/trainer/trainer.py
index 42df6efdb..6914c66cb 100644
--- a/federatedscope/llm/trainer/trainer.py
+++ b/federatedscope/llm/trainer/trainer.py
@@ -4,6 +4,8 @@
 from federatedscope.core.trainers import GeneralTorchTrainer
 from federatedscope.core.trainers.context import CtxVar
 from federatedscope.core.trainers.enums import LIFECYCLE
+from federatedscope.core.monitors.monitor import Monitor
+from federatedscope.llm.model.adapter_builder import AdapterModel
 
 logger = logging.getLogger(__name__)
 
@@ -67,6 +69,66 @@ def _hook_on_fit_end(self, ctx):
         }
         setattr(ctx, 'eval_metrics', eval_results)
 
+    def _hook_on_batch_forward_flop_count(self, ctx):
+        """
+        The monitoring hook to calculate the flops during the fl course
+
+        Note:
+          For customized cases where the forward process is not only \
+          based on ctx.model, please override this function (inheritance \
+          case) or replace this hook (plug-in case)
+
+        The modified attributes and according operations are shown below:
+            ==================================  ===========================
+            Attribute                           Operation
+            ==================================  ===========================
+            ``ctx.monitor``                     Track average flops
+            ==================================  ===========================
+        """
+        if not isinstance(ctx.monitor, Monitor):
+            logger.warning(
+                f"The trainer {type(self)} does not contain a valid monitor, "
+                f"this may be caused by initializing trainer subclasses "
+                f"without passing a valid monitor instance. "
+                f"Please check whether this is what you want.")
+            return
+
+        if self.cfg.eval.count_flops and ctx.monitor.flops_per_sample == 0:
+            # calculate the flops_per_sample
+            try:
+                input_ids = ctx.data_batch['input_ids'].to(ctx.device)
+                labels = ctx.data_batch['labels'].to(ctx.device)
+                attention_mask = ctx.data_batch['attention_mask'].to(
+                    ctx.device)
+                from fvcore.nn import FlopCountAnalysis
+                if isinstance(ctx.model, AdapterModel):
+                    flops_one_batch = FlopCountAnalysis(
+                        ctx.model.model,
+                        inputs=(input_ids, attention_mask)).total()
+                else:
+                    flops_one_batch = FlopCountAnalysis(
+                        ctx.model, inputs=(input_ids, attention_mask)).total()
+                ctx.monitor.track_avg_flops(flops_one_batch, ctx.batch_size)
+            except Exception as e:
+                logger.info(e)
+                # Raise warning at the first failure
+                logger.warning(
+                    "current flop count implementation is for general LLM "
+                    "trainer case: "
+                    "1) ctx.data_batch contains [input_ids, labels, "
+                    "attn_mask]; and 2) the first two arguments of "
+                    "ctx.model should be input_ids and attention_mask. "
+                    "If ctx.model is an adapter model, the model in 2) has "
+                    "been replaced by ctx.model.model. "
+                    "Please check the forward format or implement your own "
+                    "flop_count function")
+                ctx.monitor.flops_per_sample = -1
+
+        # by default, we assume the data has the same input shape,
+        # thus simply multiply the flops to avoid redundant forward
+        ctx.monitor.total_flops += ctx.monitor.flops_per_sample * \
+            ctx.batch_size
+
 
 def call_llm_trainer(trainer_type):
     if trainer_type == 'llmtrainer':

From bc125fb3b05d3a4af1212edc776f86ca3fd37025 Mon Sep 17 00:00:00 2001
From: Harli Wu
Date: Fri, 14 Jul 2023 15:46:57 +0800
Subject: [PATCH 2/5] Fix bugs for human_eval

---
 federatedscope/llm/eval/eval_for_code/humaneval.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/federatedscope/llm/eval/eval_for_code/humaneval.py b/federatedscope/llm/eval/eval_for_code/humaneval.py
index 8f1cebb0d..d8553c70d 100644
--- a/federatedscope/llm/eval/eval_for_code/humaneval.py
+++ b/federatedscope/llm/eval/eval_for_code/humaneval.py
@@ -31,10 +31,12 @@ def pad_spaces(s, num=4):
             s = " " * num + s[n:]
         return s
 
-    # 1. remove everything after "\n\n"
-    code = code.split("\n\n")[0]
-    # 2. remove everything after the "def "
-    code = code.split("def ")[0]
+    # 1. remove the special char \u00a0
+    code = code.replace('\u00a0', '')
+    # 2. remove everything after the following stop sequences
+    # Reference: https://github.com/openai/human-eval
+    for stop_seq in ['\nclass', '\ndef', '\n#', '\nif', '\nprint']:
+        code = code.split(stop_seq)[0]
     # 3. pad to four space to avoid `unindent` error
     code = pad_spaces(code, 4)
     return code
@@ -53,10 +55,12 @@ def main():
     update_logger(init_cfg, clear_before_add=True)
     setup_seed(init_cfg.seed)
 
+    init_cfg.freeze()
+
     # load your finetuned model (saved as xxx.ckpt)
     # in yaml file federate.save_to
     fschatbot = FSChatBot(init_cfg)
-    out_file = f'{init_cfg.federate.save_to}_humaneval_answer.jsonl'
+    out_file = os.path.join(init_cfg.outdir, 'humaneval_answer.jsonl')
 
     # Get test file
     fp = os.path.join(init_cfg.data.root, 'HumanEval.jsonl.gz')

From d728b26241fe164e51f208718b74071b46fb31ee Mon Sep 17 00:00:00 2001
From: Harli Wu
Date: Tue, 18 Jul 2023 21:45:58 +0800
Subject: [PATCH 3/5] Fix bugs on HumanEval

---
 federatedscope/llm/eval/eval_for_code/humaneval.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/federatedscope/llm/eval/eval_for_code/humaneval.py b/federatedscope/llm/eval/eval_for_code/humaneval.py
index d8553c70d..013d57786 100644
--- a/federatedscope/llm/eval/eval_for_code/humaneval.py
+++ b/federatedscope/llm/eval/eval_for_code/humaneval.py
@@ -33,11 +33,13 @@ def pad_spaces(s, num=4):
 
     # 1. remove the special char \u00a0
     code = code.replace('\u00a0', '')
-    # 2. remove everything after the following stop sequences
+    # 2. remove everything after "\n\n"
+    code = code.split("\n\n")[0]
+    # 3. remove everything after the following stop sequences
     # Reference: https://github.com/openai/human-eval
-    for stop_seq in ['\nclass', '\ndef', '\n#', '\nif', '\nprint']:
+    for stop_seq in ['\nclass', '\ndef', '\n#', '\nif', '\nprint', '\nassert']:
         code = code.split(stop_seq)[0]
-    # 3. pad to four space to avoid `unindent` error
+    # 4. pad to four space to avoid `unindent` error
     code = pad_spaces(code, 4)
     return code
@@ -55,12 +57,10 @@ def main():
     update_logger(init_cfg, clear_before_add=True)
     setup_seed(init_cfg.seed)
 
-    init_cfg.freeze()
-
     # load your finetuned model (saved as xxx.ckpt)
     # in yaml file federate.save_to
     fschatbot = FSChatBot(init_cfg)
-    out_file = os.path.join(init_cfg.outdir, 'humaneval_answer.jsonl')
+    out_file = f'{init_cfg.federate.save_to}_humaneval_answer.jsonl'
 
     # Get test file
     fp = os.path.join(init_cfg.data.root, 'HumanEval.jsonl.gz')

From ed96262d27036a8d3593f3bf36fe9e69e053fb37 Mon Sep 17 00:00:00 2001
From: Harli Wu
Date: Tue, 18 Jul 2023 21:51:18 +0800
Subject: [PATCH 4/5] Remove \n\n in HumanEval

---
 federatedscope/llm/eval/eval_for_code/humaneval.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/federatedscope/llm/eval/eval_for_code/humaneval.py b/federatedscope/llm/eval/eval_for_code/humaneval.py
index 013d57786..e6968ff4c 100644
--- a/federatedscope/llm/eval/eval_for_code/humaneval.py
+++ b/federatedscope/llm/eval/eval_for_code/humaneval.py
@@ -33,8 +33,8 @@ def pad_spaces(s, num=4):
 
     # 1. remove the special char \u00a0
     code = code.replace('\u00a0', '')
-    # 2. remove everything after "\n\n"
-    code = code.split("\n\n")[0]
+    # # 2. remove everything after "\n\n"
+    # code = code.split("\n\n")[0]
     # 3. remove everything after the following stop sequences
     # Reference: https://github.com/openai/human-eval
     for stop_seq in ['\nclass', '\ndef', '\n#', '\nif', '\nprint', '\nassert']:
         code = code.split(stop_seq)[0]

From f41614b8c8d91fdc61a6c22f6ef8e79088e1eeb7 Mon Sep 17 00:00:00 2001
From: Harli Wu
Date: Tue, 8 Aug 2023 09:16:16 +0800
Subject: [PATCH 5/5] method-oriented offsite-tuning model generation

---
 federatedscope/llm/misc/fschat.py           | 66 +++++++++++++++------
 federatedscope/llm/offsite_tuning/server.py | 17 ++----
 federatedscope/llm/offsite_tuning/utils.py  | 60 ++++++++++++-------
 3 files changed, 91 insertions(+), 52 deletions(-)

diff --git a/federatedscope/llm/misc/fschat.py b/federatedscope/llm/misc/fschat.py
index f55da9a98..d3a53271a 100644
--- a/federatedscope/llm/misc/fschat.py
+++ b/federatedscope/llm/misc/fschat.py
@@ -2,6 +2,8 @@
 import logging
 import torch
 import transformers
+import os
+import gc
 
 transformers.logging.set_verbosity(40)
 
@@ -18,29 +20,57 @@ class FSChatBot(object):
     def __init__(self, config):
-        model_name, _ = config.model.type.split('@')
-        self.tokenizer, _ = get_tokenizer(model_name, config.data.root,
-                                          config.llm.tok_len)
-        self.model = get_llm(config)
+        self.config = config
         self.device = f'cuda:{config.device}'
         self.add_special_tokens = True
 
-        if config.llm.offsite_tuning.use:
-            from federatedscope.llm.offsite_tuning.utils import \
-                wrap_offsite_tuning_for_eval
-            self.model = wrap_offsite_tuning_for_eval(self.model, config)
-        else:
-            try:
-                ckpt = torch.load(config.federate.save_to, map_location='cpu')
+        self.prefix = ['']
+        self.dirname, self.filename = os.path.split(config.federate.save_to)
+        self.next_model()
+
+    def next_model(self):
+        if hasattr(self, 'model'):
+            delattr(self, 'model')
+            gc.collect()
+
+        model_name, _ = self.config.model.type.split('@')
+        self.tokenizer, _ = get_tokenizer(model_name, self.config.data.root,
+                                          self.config.llm.tok_len)
+        self.model = get_llm(self.config)
+
+        self.curpfx = None
+        for pre in self.prefix:
+            if os.path.exists(os.path.join(self.dirname, pre + self.filename)):
+                self.curpfx = pre
+                break
+
+        # Load model from the checkpoints
+        if self.curpfx is not None:
+            ckpt_path = os.path.join(self.dirname, self.curpfx + self.filename)
+            if self.config.llm.offsite_tuning.use:
+                from federatedscope.llm.offsite_tuning.utils import \
+                    wrap_offsite_tuning_for_eval
+                self.model = wrap_offsite_tuning_for_eval(
+                    self.model, self.config, ckpt_path)
+            else:
+                ckpt = torch.load(ckpt_path, map_location='cpu')
                 if 'model' and 'cur_round' in ckpt:
                     self.model.load_state_dict(ckpt['model'])
+                    logger.info(
+                        f"Load with the model of Round {ckpt['cur_round']}")
                 else:
                     self.model.load_state_dict(ckpt)
-            except Exception as error:
-                print(f"{error}, will use raw model.")
+            logger.info(f'Model loaded from the checkpoint {ckpt_path}')
+
+            # remove the prefix up to the current one
+            self.prefix = self.prefix[self.prefix.index(self.curpfx) + 1:]
+        elif len(self.prefix) > 1:
+            logger.info("will use raw model.")
+        else:
+            raise ValueError('No more models are available to use')
 
-        if config.train.is_enable_half:
+        if self.config.train.is_enable_half:
             self.model.half()
 
         self.model = self.model.to(self.device)
 
         if torch.__version__ >= "2" and sys.platform != "win32":
             self.model = torch.compile(self.model)
 
-        self.max_history_len = config.llm.chat.max_history_len
-        self.max_len = config.llm.chat.max_len
+        self.max_history_len = self.config.llm.chat.max_history_len
+        self.max_len = self.config.llm.chat.max_len
         self.history = []
 
     def _build_prompt(self, input_text):
@@ -123,8 +153,8 @@ def main():
     setup_seed(init_cfg.seed)
 
     chat_bot = FSChatBot(init_cfg)
-    welcome = "Welcome to FSChatBot," \
-              "`clear` to clear history," \
+    welcome = "Welcome to FSChatBot, " \
+              "`clear` to clear history, " \
               "`quit` to end chat."
     print(welcome)
     while True:
diff --git a/federatedscope/llm/offsite_tuning/server.py b/federatedscope/llm/offsite_tuning/server.py
index 264e73e4b..66a6fe8d5 100644
--- a/federatedscope/llm/offsite_tuning/server.py
+++ b/federatedscope/llm/offsite_tuning/server.py
@@ -9,7 +9,7 @@
 from federatedscope.core.workers.server import Server
 
 from federatedscope.llm.offsite_tuning.utils import \
-    generate_emulator_and_adapter, align_student_with_teacher
+    generate_adap_model, align_student_with_teacher
 
 logger = logging.getLogger(__name__)
 
@@ -30,17 +30,8 @@ def __init__(self,
                  device='cpu',
                  strategy=None,
                  **kwargs):
-        compress_strategy = config.llm.offsite_tuning.strategy
-        emulator_l = config.llm.offsite_tuning.emu_l
-        emulator_r = config.llm.offsite_tuning.emu_r
-        offsite_tuning_kwargs = config.llm.offsite_tuning.kwargs[0]
         logger.info('Server: Generating emulator and adapter...')
-        adap_model = \
-            generate_emulator_and_adapter(model,
-                                          strategy=compress_strategy,
-                                          emulator_l=emulator_l,
-                                          emulator_r=emulator_r,
-                                          **offsite_tuning_kwargs)
+        adap_model = generate_adap_model(model, config.llm.offsite_tuning)
         # Emulator alignment
         if config.llm.offsite_tuning.emu_align.use:
             adap_model = align_student_with_teacher(raw_model=model,
@@ -54,7 +45,11 @@ def __init__(self,
                 os._exit(0)
         # No need for this attr
         if hasattr(adap_model, 'teacher'):
+            import gc
+            import torch
             del adap_model.teacher
+            gc.collect()
+            torch.cuda.empty_cache()
 
         self.raw_model = model
         super(OffsiteTuningServer,
diff --git a/federatedscope/llm/offsite_tuning/utils.py b/federatedscope/llm/offsite_tuning/utils.py
index 5aff9d172..b71cd4a30 100644
--- a/federatedscope/llm/offsite_tuning/utils.py
+++ b/federatedscope/llm/offsite_tuning/utils.py
@@ -95,7 +95,7 @@ def get_layers(adapter_model):
     return layers
 
 
-def set_layers(adapter_model, layers, emu_l=0, emu_r=-1):
+def set_layers(adapter_model, layers):
     if isinstance(adapter_model.model, OPTForCausalLM):
         adapter_model.model.model.decoder.layers = layers
     elif isinstance(adapter_model.model, GPT2LMHeadModel):
@@ -109,12 +109,6 @@ def set_layers(adapter_model, layers, emu_l=0, emu_r=-1):
         logger.warning(f'Model {type(adapter_model.model)} not support, '
                        f'use default setting.')
         adapter_model.model.transformer.h = layers
-    adapter_model.student = layers[emu_l:emu_r]
-    adapter_model.adapter = layers[:emu_l] + layers[emu_r:]
-    add_prologue(adapter_model.student[0], None)
-    add_epilogue(adapter_model.student[-1], None)
-    adapter_model.student_l = adapter_model.student[0]
-    adapter_model.student_r = adapter_model.student[-1]
     return adapter_model
 
 
@@ -152,13 +146,31 @@ def model_distillation(model, **kwargs):
     }
 
 
+def generate_adap_model(model: AdapterModel, offsite_tuning_cfg):
+    if offsite_tuning_cfg.strategy in COMP_FUNC_MAPPING.keys():
+        compress_strategy = offsite_tuning_cfg.strategy
+        emulator_l = offsite_tuning_cfg.emu_l
+        emulator_r = offsite_tuning_cfg.emu_r
+        emu_align = offsite_tuning_cfg.emu_align.use
+        offsite_tuning_kwargs = offsite_tuning_cfg.kwargs[0]
+        return generate_emulator_and_adapter(model,
+                                             strategy=compress_strategy,
+                                             emulator_l=emulator_l,
+                                             emulator_r=emulator_r,
+                                             emulator_alignment=emu_align,
+                                             **offsite_tuning_kwargs)
+    else:
+        raise NotImplementedError
+
+
 def generate_emulator_and_adapter(model: AdapterModel,
                                   strategy='drop_layer',
-                                  emulator_l=1,
+                                  emulator_l=0,
                                   emulator_r=1000,
+                                  emulator_alignment=False,
                                   **kwargs):
     layers = get_layers(model)
-    l, r = max(emulator_l, 1), min(emulator_r, len(layers) - 1)
+    l, r = max(emulator_l, 0), min(emulator_r, len(layers) - 1)
 
     # Set the to-compress part untrainable
     for layer in layers[l:r]:
@@ -186,7 +198,14 @@ def generate_emulator_and_adapter(model: AdapterModel,
     new_model = copy.deepcopy(model)
 
     # Set student model
-    new_model = set_layers(new_model, emulator_and_adapter, l, r)
+    new_model = set_layers(new_model, emulator_and_adapter)
+
+    if emulator_alignment:
+        new_model.student = layers
+        add_prologue(new_model.student[0], None)
+        add_epilogue(new_model.student[-1], None)
+        new_model.student_l = new_model.student[0]
+        new_model.student_r = new_model.student[-1]
 
     gc.collect()
     torch.cuda.empty_cache()
@@ -303,20 +322,11 @@ def build_cfg_for_alignment(config):
     return adap_model
 
 
-def wrap_offsite_tuning_for_eval(model, config):
+def wrap_offsite_tuning_for_eval(model, config, ckpt_path=None):
     logger.info('===============use offsite tuning===============')
     # We use offsite-tuning in this experiment
     # Use adapter model instead
-    compress_strategy = config.llm.offsite_tuning.strategy
-    emulator_l = config.llm.offsite_tuning.emu_l
-    emulator_r = config.llm.offsite_tuning.emu_r
-    offsite_tuning_kwargs = config.llm.offsite_tuning.kwargs[0]
-    adap_model = \
-        generate_emulator_and_adapter(model,
-                                      strategy=compress_strategy,
-                                      emulator_l=emulator_l,
-                                      emulator_r=emulator_r,
-                                      **offsite_tuning_kwargs)
+    adap_model = generate_adap_model(model, config.llm.offsite_tuning)
     # Load kd model if ckpt exits
     if config.llm.offsite_tuning.emu_align.use and \
             config.llm.offsite_tuning.eval_type == 'emu':
@@ -333,9 +343,12 @@ def build_cfg_for_alignment(config):
 
     # Load ckpt for eval
     try:
-        ckpt = torch.load(config.federate.save_to, map_location='cpu')
+        if ckpt_path is None:
+            ckpt_path = config.federate.save_to
+        ckpt = torch.load(ckpt_path, map_location='cpu')
         if 'model' and 'cur_round' in ckpt:
             adap_model.load_state_dict(ckpt['model'])
+            logger.info(f"Load with the model of Round {ckpt['cur_round']}")
         else:
             adap_model.load_state_dict(ckpt)
     except Exception as error:
 
     if config.llm.offsite_tuning.eval_type == 'emu':
         model = adap_model
-        del model.teacher
+        if hasattr(model, 'teacher'):
+            del model.teacher
     elif config.llm.offsite_tuning.eval_type == 'full':
         # Raw model load adapter from adapter_and_emulator
         new_model_state_dict = model.state_dict()