diff --git a/federatedscope/llm/misc/fschat.py b/federatedscope/llm/misc/fschat.py
index f55da9a98..d3a53271a 100644
--- a/federatedscope/llm/misc/fschat.py
+++ b/federatedscope/llm/misc/fschat.py
@@ -2,6 +2,8 @@
 import logging
 import torch
 import transformers
+import os
+import gc
 
 transformers.logging.set_verbosity(40)
 
@@ -18,29 +20,57 @@ class FSChatBot(object):
     def __init__(self, config):
-        model_name, _ = config.model.type.split('@')
-        self.tokenizer, _ = get_tokenizer(model_name, config.data.root,
-                                          config.llm.tok_len)
-        self.model = get_llm(config)
+        self.config = config
         self.device = f'cuda:{config.device}'
         self.add_special_tokens = True
 
-        if config.llm.offsite_tuning.use:
-            from federatedscope.llm.offsite_tuning.utils import \
-                wrap_offsite_tuning_for_eval
-            self.model = wrap_offsite_tuning_for_eval(self.model, config)
-        else:
-            try:
-                ckpt = torch.load(config.federate.save_to, map_location='cpu')
+        self.prefix = ['']
+        self.dirname, self.filename = os.path.split(config.federate.save_to)
+        self.next_model()
+
+    def next_model(self):
+        if hasattr(self, 'model'):
+            delattr(self, 'model')
+            gc.collect()
+
+        model_name, _ = self.config.model.type.split('@')
+        self.tokenizer, _ = get_tokenizer(model_name, self.config.data.root,
+                                          self.config.llm.tok_len)
+        self.model = get_llm(self.config)
+
+        self.curpfx = None
+        for pre in self.prefix:
+            if os.path.exists(os.path.join(self.dirname, pre + self.filename)):
+                self.curpfx = pre
+                break
+
+        # Load model from the checkpoints
+        if self.curpfx is not None:
+            ckpt_path = os.path.join(self.dirname, self.curpfx + self.filename)
+            if self.config.llm.offsite_tuning.use:
+                from federatedscope.llm.offsite_tuning.utils import \
+                    wrap_offsite_tuning_for_eval
+                self.model = wrap_offsite_tuning_for_eval(
+                    self.model, self.config, ckpt_path)
+            else:
+                ckpt = torch.load(ckpt_path, map_location='cpu')
                 if 'model' and 'cur_round' in ckpt:
                     self.model.load_state_dict(ckpt['model'])
+                    logger.info(
+                        f"Loaded the model of Round {ckpt['cur_round']}")
                 else:
                     self.model.load_state_dict(ckpt)
-            except Exception as error:
-                print(f"{error}, will use raw model.")
+            logger.info(f'Model loaded from the checkpoint {ckpt_path}')
+
+            # remove the prefix up to the current one
+            self.prefix = self.prefix[self.prefix.index(self.curpfx) + 1:]
+        elif len(self.prefix) > 1:
+            logger.info("will use raw model.")
+        else:
+            raise ValueError('No more model is available to use')
 
-        if config.train.is_enable_half:
+        if self.config.train.is_enable_half:
             self.model.half()
 
         self.model = self.model.to(self.device)
@@ -48,8 +78,8 @@ def __init__(self, config):
         if torch.__version__ >= "2" and sys.platform != "win32":
             self.model = torch.compile(self.model)
 
-        self.max_history_len = config.llm.chat.max_history_len
-        self.max_len = config.llm.chat.max_len
+        self.max_history_len = self.config.llm.chat.max_history_len
+        self.max_len = self.config.llm.chat.max_len
         self.history = []
 
     def _build_prompt(self, input_text):
@@ -123,8 +153,8 @@ def main():
     setup_seed(init_cfg.seed)
     chat_bot = FSChatBot(init_cfg)
-    welcome = "Welcome to FSChatBot," \
-              "`clear` to clear history," \
+    welcome = "Welcome to FSChatBot, " \
+              "`clear` to clear history, " \
               "`quit` to end chat."
     print(welcome)
     while True:
diff --git a/federatedscope/llm/offsite_tuning/server.py b/federatedscope/llm/offsite_tuning/server.py
index ad0a0dd8f..8741760ca 100644
--- a/federatedscope/llm/offsite_tuning/server.py
+++ b/federatedscope/llm/offsite_tuning/server.py
@@ -9,7 +9,7 @@
 from federatedscope.core.workers.server import Server
 from federatedscope.llm.offsite_tuning.utils import \
-    generate_emulator_and_adapter, align_student_with_teacher
+    generate_adap_model, align_student_with_teacher
 
 logger = logging.getLogger(__name__)
 
@@ -30,17 +30,8 @@ def __init__(self,
                  device='cpu',
                  strategy=None,
                  **kwargs):
-        compress_strategy = config.llm.offsite_tuning.strategy
-        emulator_l = config.llm.offsite_tuning.emu_l
-        emulator_r = config.llm.offsite_tuning.emu_r
-        offsite_tuning_kwargs = config.llm.offsite_tuning.kwargs[0]
         logger.info('Server: Generating emulator and adapter...')
-        adap_model = \
-            generate_emulator_and_adapter(model,
-                                          strategy=compress_strategy,
-                                          emulator_l=emulator_l,
-                                          emulator_r=emulator_r,
-                                          **offsite_tuning_kwargs)
+        adap_model = generate_adap_model(model, config.llm.offsite_tuning)
         # Emulator alignment
         if config.llm.offsite_tuning.emu_align.use:
             adap_model = align_student_with_teacher(raw_model=model,
@@ -54,7 +45,11 @@ def __init__(self,
             os._exit(0)
         # No need for this attr
         if hasattr(adap_model, 'teacher'):
+            import gc
+            import torch
             del adap_model.teacher
+            gc.collect()
+            torch.cuda.empty_cache()
         self.raw_model = model
         super(OffsiteTuningServer,
diff --git a/federatedscope/llm/offsite_tuning/utils.py b/federatedscope/llm/offsite_tuning/utils.py
index 5aff9d172..b71cd4a30 100644
--- a/federatedscope/llm/offsite_tuning/utils.py
+++ b/federatedscope/llm/offsite_tuning/utils.py
@@ -95,7 +95,7 @@ def get_layers(adapter_model):
     return layers
 
 
-def set_layers(adapter_model, layers, emu_l=0, emu_r=-1):
+def set_layers(adapter_model, layers):
     if isinstance(adapter_model.model, OPTForCausalLM):
         adapter_model.model.model.decoder.layers = layers
     elif isinstance(adapter_model.model, GPT2LMHeadModel):
@@ -109,12 +109,6 @@ def set_layers(adapter_model, layers, emu_l=0, emu_r=-1):
         logger.warning(f'Model {type(adapter_model.model)} not support, '
                        f'use default setting.')
         adapter_model.model.transformer.h = layers
-    adapter_model.student = layers[emu_l:emu_r]
-    adapter_model.adapter = layers[:emu_l] + layers[emu_r:]
-    add_prologue(adapter_model.student[0], None)
-    add_epilogue(adapter_model.student[-1], None)
-    adapter_model.student_l = adapter_model.student[0]
-    adapter_model.student_r = adapter_model.student[-1]
     return adapter_model
 
 
@@ -152,13 +146,31 @@ def model_distillation(model, **kwargs):
 }
 
 
+def generate_adap_model(model: AdapterModel, offsite_tuning_cfg):
+    if offsite_tuning_cfg.strategy in COMP_FUNC_MAPPING.keys():
+        compress_strategy = offsite_tuning_cfg.strategy
+        emulator_l = offsite_tuning_cfg.emu_l
+        emulator_r = offsite_tuning_cfg.emu_r
+        emu_align = offsite_tuning_cfg.emu_align.use
+        offsite_tuning_kwargs = offsite_tuning_cfg.kwargs[0]
+        return generate_emulator_and_adapter(model,
+                                             strategy=compress_strategy,
+                                             emulator_l=emulator_l,
+                                             emulator_r=emulator_r,
+                                             emulator_alignment=emu_align,
+                                             **offsite_tuning_kwargs)
+    else:
+        raise NotImplementedError
+
+
 def generate_emulator_and_adapter(model: AdapterModel,
                                   strategy='drop_layer',
-                                  emulator_l=1,
+                                  emulator_l=0,
                                   emulator_r=1000,
+                                  emulator_alignment=False,
                                   **kwargs):
     layers = get_layers(model)
-    l, r = max(emulator_l, 1), min(emulator_r, len(layers) - 1)
+    l, r = max(emulator_l, 0), min(emulator_r, len(layers) - 1)
 
     # Set the to-compress part untrainable
     for layer in layers[l:r]:
@@ -186,7 +198,14 @@ def generate_emulator_and_adapter(model: AdapterModel,
     new_model = copy.deepcopy(model)
 
     # Set student model
-    new_model = set_layers(new_model, emulator_and_adapter, l, r)
+    new_model = set_layers(new_model, emulator_and_adapter)
+
+    if emulator_alignment:
+        new_model.student = layers
+        add_prologue(new_model.student[0], None)
+        add_epilogue(new_model.student[-1], None)
+        new_model.student_l = new_model.student[0]
+        new_model.student_r = new_model.student[-1]
 
     gc.collect()
     torch.cuda.empty_cache()
@@ -303,20 +322,11 @@ def build_cfg_for_alignment(config):
     return adap_model
 
 
-def wrap_offsite_tuning_for_eval(model, config):
+def wrap_offsite_tuning_for_eval(model, config, ckpt_path=None):
     logger.info('===============use offsite tuning===============')
     # We use offsite-tuning in this experiment
     # Use adapter model instead
-    compress_strategy = config.llm.offsite_tuning.strategy
-    emulator_l = config.llm.offsite_tuning.emu_l
-    emulator_r = config.llm.offsite_tuning.emu_r
-    offsite_tuning_kwargs = config.llm.offsite_tuning.kwargs[0]
-    adap_model = \
-        generate_emulator_and_adapter(model,
-                                      strategy=compress_strategy,
-                                      emulator_l=emulator_l,
-                                      emulator_r=emulator_r,
-                                      **offsite_tuning_kwargs)
+    adap_model = generate_adap_model(model, config.llm.offsite_tuning)
     # Load kd model if ckpt exits
     if config.llm.offsite_tuning.emu_align.use and \
             config.llm.offsite_tuning.eval_type == 'emu':
@@ -333,9 +343,12 @@ def wrap_offsite_tuning_for_eval(model, config):
     # Load ckpt for eval
     try:
-        ckpt = torch.load(config.federate.save_to, map_location='cpu')
+        if ckpt_path is None:
+            ckpt_path = config.federate.save_to
+        ckpt = torch.load(ckpt_path, map_location='cpu')
         if 'model' and 'cur_round' in ckpt:
             adap_model.load_state_dict(ckpt['model'])
+            logger.info(f"Loaded the model of Round {ckpt['cur_round']}")
         else:
             adap_model.load_state_dict(ckpt)
     except Exception as error:
@@ -343,7 +356,8 @@ def wrap_offsite_tuning_for_eval(model, config):
     if config.llm.offsite_tuning.eval_type == 'emu':
         model = adap_model
-        del model.teacher
+        if hasattr(model, 'teacher'):
+            del model.teacher
     elif config.llm.offsite_tuning.eval_type == 'full':
         # Raw model load adapter from adapter_and_emulator
         new_model_state_dict = model.state_dict()