Commit db7c676

Merge pull request #135 from jiudingsun01/main
[Minor] Support MistralModel and MistralForCausalLM
2 parents b57b660 + bf09440

File tree

3 files changed: +83 −8 lines changed

pyvene/models/intervenable_modelcard.py
pyvene/models/mistral/__init__.py
pyvene/models/mistral/modellings_intervenable_mistral.py

pyvene/models/intervenable_modelcard.py (+5 −8)
@@ -1,12 +1,12 @@
 from .constants import *
 from .llama.modelings_intervenable_llama import *
+from .mistral.modellings_intervenable_mistral import *
 from .gpt2.modelings_intervenable_gpt2 import *
 from .gpt_neo.modelings_intervenable_gpt_neo import *
 from .gpt_neox.modelings_intervenable_gpt_neox import *
 from .mlp.modelings_intervenable_mlp import *
 from .gru.modelings_intervenable_gru import *
 from .blip.modelings_intervenable_blip import *
-from .blip.modelings_intervenable_blip_itm import *
 from .backpack_gpt2.modelings_intervenable_backpack_gpt2 import *
 
 
@@ -21,7 +21,6 @@
 
 import transformers.models as hf_models
 from .blip.modelings_blip import BlipWrapper
-from .blip.modelings_blip_itm import BlipITMWrapper
 from .mlp.modelings_mlp import MLPModel, MLPForClassification
 from .gru.modelings_gru import GRUModel, GRULMHeadModel, GRUForClassification
 from .backpack_gpt2.modelings_backpack_gpt2 import BackpackGPT2LMHeadModel
@@ -35,17 +34,16 @@
 type_to_module_mapping = {
     hf_models.gpt2.modeling_gpt2.GPT2Model: gpt2_type_to_module_mapping,
     hf_models.gpt2.modeling_gpt2.GPT2LMHeadModel: gpt2_lm_type_to_module_mapping,
-    hf_models.gpt2.modeling_gpt2.GPT2ForSequenceClassification: gpt2_classifier_type_to_module_mapping,
     hf_models.llama.modeling_llama.LlamaModel: llama_type_to_module_mapping,
     hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_module_mapping,
     hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_module_mapping,
     hf_models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM: gpt_neo_lm_type_to_module_mapping,
     hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXModel: gpt_neox_type_to_module_mapping,
     hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM: gpt_neox_lm_type_to_module_mapping,
+    hf_models.mistral.modeling_mistral.MistralModel: mistral_type_to_module_mapping,
+    hf_models.mistral.modeling_mistral.MistralForCausalLM: mistral_lm_type_to_module_mapping,
     hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_module_mapping,
-    hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_module_mapping,
     BlipWrapper: blip_wrapper_type_to_module_mapping,
-    BlipITMWrapper: blip_itm_wrapper_type_to_module_mapping,
     MLPModel: mlp_type_to_module_mapping,
     MLPForClassification: mlp_classifier_type_to_module_mapping,
     GRUModel: gru_type_to_module_mapping,
@@ -59,17 +57,16 @@
 type_to_dimension_mapping = {
     hf_models.gpt2.modeling_gpt2.GPT2Model: gpt2_type_to_dimension_mapping,
     hf_models.gpt2.modeling_gpt2.GPT2LMHeadModel: gpt2_lm_type_to_dimension_mapping,
-    hf_models.gpt2.modeling_gpt2.GPT2ForSequenceClassification: gpt2_classifier_type_to_dimension_mapping,
     hf_models.llama.modeling_llama.LlamaModel: llama_type_to_dimension_mapping,
     hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_dimension_mapping,
     hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_dimension_mapping,
     hf_models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM: gpt_neo_lm_type_to_dimension_mapping,
     hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXModel: gpt_neox_type_to_dimension_mapping,
     hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM: gpt_neox_lm_type_to_dimension_mapping,
+    hf_models.mistral.modeling_mistral.MistralModel: mistral_type_to_dimension_mapping,
+    hf_models.mistral.modeling_mistral.MistralForCausalLM: mistral_lm_type_to_dimension_mapping,
     hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_dimension_mapping,
-    hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_dimension_mapping,
     BlipWrapper: blip_wrapper_type_to_dimension_mapping,
-    BlipITMWrapper: blip_itm_wrapper_type_to_dimension_mapping,
     MLPModel: mlp_type_to_dimension_mapping,
     MLPForClassification: mlp_classifier_type_to_dimension_mapping,
     GRUModel: gru_type_to_dimension_mapping,
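
For context on what the edits above accomplish: both registries are keyed by the model's Python class, so adding the two Mistral entries is all pyvene needs to resolve abstract anchor points to concrete submodules of a wrapped Mistral model. A minimal lookup sketch (the registry and class names come from this diff; the surrounding script is illustrative, not part of the commit):

import transformers.models as hf_models
from pyvene.models.intervenable_modelcard import (
    type_to_module_mapping,
    type_to_dimension_mapping,
)

# The wrapped model's class is the dictionary key.
model_cls = hf_models.mistral.modeling_mistral.MistralForCausalLM
module_map = type_to_module_mapping[model_cls]
dim_map = type_to_dimension_mapping[model_cls]

# Anchor points resolve to module paths by %-substituting the layer index,
# e.g. the residual-stream output of decoder layer 3:
path_template, hook_type = module_map["block_output"]
print(path_template % 3)        # -> model.layers[3]
print(dim_map["block_output"])  # -> ("hidden_size",)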

pyvene/models/mistral/__init__.py

Whitespace-only changes.

pyvene/models/mistral/modellings_intervenable_mistral.py (+78 −0)

@@ -0,0 +1,78 @@
+"""
+Each modeling file in this library is a mapping between
+the abstract naming of intervention anchor points and the
+actual model modules defined in the huggingface library.
+
+We also want to let the intervention library know how to
+configure the dimensions of an intervention based on the
+model config defined in the huggingface library.
+"""
+
+
+import torch
+from ..constants import *
+
+
+mistral_type_to_module_mapping = {
+    "block_input": ("layers[%s]", CONST_INPUT_HOOK),
+    "block_output": ("layers[%s]", CONST_OUTPUT_HOOK),
+    "mlp_activation": ("layers[%s].mlp.act_fn", CONST_OUTPUT_HOOK),
+    "mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK),
+    "mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK),
+    "attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
+    "head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
+    "attention_output": ("layers[%s].self_attn", CONST_OUTPUT_HOOK),
+    "attention_input": ("layers[%s].self_attn", CONST_INPUT_HOOK),
+    "query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
+    "key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
+    "value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
+    "head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
+    "head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
+    "head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
+}
+
+
+mistral_type_to_dimension_mapping = {
+    "block_input": ("hidden_size",),
+    "block_output": ("hidden_size",),
+    "mlp_activation": ("intermediate_size",),
+    "mlp_output": ("hidden_size",),
+    "mlp_input": ("hidden_size",),
+    "attention_value_output": ("hidden_size",),
+    "head_attention_value_output": ("hidden_size/num_attention_heads",),
+    "attention_output": ("hidden_size",),
+    "attention_input": ("hidden_size",),
+    "query_output": ("hidden_size",),
+    "key_output": ("hidden_size",),
+    "value_output": ("hidden_size",),
+    "head_query_output": ("hidden_size/num_attention_heads",),
+    "head_key_output": ("hidden_size/num_attention_heads",),
+    "head_value_output": ("hidden_size/num_attention_heads",),
+}
+
+
+"""mistral model with LM head"""
+mistral_lm_type_to_module_mapping = {}
+for k, v in mistral_type_to_module_mapping.items():
+    mistral_lm_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])
+
+
+mistral_lm_type_to_dimension_mapping = mistral_type_to_dimension_mapping
+
+
+def create_mistral(
+    name="mistralai/Mistral-7B-v0.1", cache_dir=None
+):
+    """Creates a Mistral Causal LM model, config, and tokenizer from the given name"""
+    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+
+    config = AutoConfig.from_pretrained(name, cache_dir=cache_dir)
+    tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
+    mistral = AutoModelForCausalLM.from_pretrained(
+        name,
+        config=config,
+        cache_dir=cache_dir,
+        torch_dtype=torch.bfloat16,  # save memory
+    )
+    print("loaded model")
+    return config, tokenizer, mistral
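
A possible usage sketch for the new helper (network access and the standard Mistral-7B-v0.1 config are assumed; with hidden_size=4096 and num_attention_heads=32, the head-level anchor points get a 128-dimensional slice):

from pyvene.models.mistral.modellings_intervenable_mistral import (
    create_mistral,
    mistral_lm_type_to_module_mapping,
)

# Downloads ~7B parameters in bfloat16; needs commensurate memory.
config, tokenizer, mistral = create_mistral()

# "head_query_output" and friends are sized hidden_size/num_attention_heads
# per the dimension mapping above.
print(config.hidden_size // config.num_attention_heads)  # 128 for Mistral-7B-v0.1

# The LM-head variant only prefixes each module path with "model.", matching
# MistralForCausalLM, which holds the base MistralModel in its .model attribute.
print(mistral_lm_type_to_module_mapping["block_output"][0])  # model.layers[%s]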
