From d101a0b533dd00e62576245b1b97631d1b7e532f Mon Sep 17 00:00:00 2001
From: Masao Taketani
Date: Fri, 20 Dec 2024 02:39:06 +0900
Subject: [PATCH] fix model load for pv unsupported models

---
 README.md            |  8 ++++++++
 pyreft/reft_model.py | 29 ++++++++++++++++++++++++++++-
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 2f9ce53..68fdd75 100644
--- a/README.md
+++ b/README.md
@@ -271,6 +271,14 @@ reft_model = pyreft.ReftModel.load(
 )
 ```
 
+> [!WARNING]
+> When you try to load a model that pyvene does not support, `ReftModel.load` raises a `KeyError`. To avoid that, set up a ReFT model first, and then load the saved interventions into it in the following way.
+
+```py
+reft_model = pyreft.ReftModel.load_pv_undefined_model(
+    "./reft_to_share", reft_model)
+```
+
 ### LM training and serving with ReFT.
 
 ReFT enables intervention-based model training and serving at scale. It allows continuous batching while only keeping a single copy of the base LM. The base LM, when intervened, can solve different user tasks with batched inputs.
diff --git a/pyreft/reft_model.py b/pyreft/reft_model.py
index 9ff6e0d..95ccf4a 100644
--- a/pyreft/reft_model.py
+++ b/pyreft/reft_model.py
@@ -1,4 +1,6 @@
 import pyvene as pv
+import torch
+import os
 
 
 def count_parameters(model):
@@ -23,9 +25,34 @@ def _convert_to_reft_model(intervenable_model):
 
     @staticmethod
     def load(*args, **kwargs):
-        model = pv.IntervenableModel.load(*args, **kwargs)
+        try:
+            model = pv.IntervenableModel.load(*args, **kwargs)
+        except KeyError as err:
+            # Re-raise with guidance: swallowing the error here would leave
+            # `model` unbound for the return statement below.
+            raise KeyError(
+                "This model is unsupported by pyvene. Set up a reft model "
+                "and use `load_pv_undefined_model` instead."
+            ) from err
         return ReftModel._convert_to_reft_model(model)
 
+    @staticmethod
+    def load_pv_undefined_model(load_directory, reft_model):
+        """
+        Load saved intervention weights into an already set up ReFT model.
+
+        Use this for a model which is unsupported by pyvene. The state dict
+        of each intervention in `reft_model.interventions` is read from
+        `<load_directory>/intkey_<key>.bin`; `reft_model` is updated in
+        place and returned.
+        """
+        for key in reft_model.interventions.keys():
+            state_dict = torch.load(
+                os.path.join(load_directory, f"intkey_{key}.bin"),
+                weights_only=True,
+            )
+            reft_model.interventions[key][0].load_state_dict(state_dict)
+        return reft_model
+
     def print_trainable_parameters(self):
         """
         Print trainable parameters.