forked from Trustworthy-ML-Lab/ThinkEdit
extract_tl.py
import gc
import os
import argparse
import json
import math
import pickle
import re
import time

import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import logging

from utils import model_dict

logging.set_verbosity_error()

np.random.seed(20)
torch.manual_seed(20)
torch.cuda.manual_seed_all(20)

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="deepseek_llama3")
parser.add_argument("--control", type=str, default="attn", choices=["attn", "mlp"])
args = parser.parse_args()

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# -- load model and tokenizer --
model_path = model_dict[args.model]
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
model.generation_config.do_sample = True
tokenizer.pad_token = tokenizer.eos_token
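
# extract_tl_dir computes, for one contrast set, the mean residual-stream
# activation per layer over the thinking-token span of each response.
# With --control attn, a forward hook on each layer's post_attention_layernorm
# captures its *input*, which for Llama-style decoder layers is the residual
# stream right after the attention output has been added (pre-MLP). With
# --control mlp, the per-layer hidden_states (the residual stream after the
# full layer) are used instead. The slice start-1:end-1 shifts positions one
# token left, i.e. each averaged activation is the state from which a
# thinking token is generated.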
def extract_tl_dir(examples):
    # -- attach attn hook to retrieve residuals --
    if args.control == "attn":
        print("attaching attn hook")
    attn_outputs = []
    def capture_residual_hook():
        def hook_fn(module, input, output):
            attn_outputs.append(input[0].detach())
        return hook_fn
    hook_handles = []
    for layer in model.model.layers:
        if args.control == "attn":
            hook_handles.append(layer.post_attention_layernorm.register_forward_hook(capture_residual_hook()))
    embeddings = []
    for example in tqdm(examples):
        # token index where the assistant's thinking starts: length of the prompt alone
        message = [{"role": "user", "content": example['question']}, {"role": "assistant", "content": ""}]
        question = tokenizer.apply_chat_template(
            message,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True  # Switches between thinking and non-thinking modes. Default is True.
        )
        start = len(tokenizer(question).input_ids)
        # token index where the thinking ends: length of prompt + chain of thought
        message = [{"role": "user", "content": example['question']}, {"role": "assistant", "content": example['thinking']}]
        cot = tokenizer.apply_chat_template(
            message,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True  # Switches between thinking and non-thinking modes. Default is True.
        )
        end = len(tokenizer(cot).input_ids)
        toks = tokenizer(cot, return_tensors="pt")
        # toks = tokenizer(f"<|User|>{example['question']}<|Assistant|>").input_ids
        # start = len(toks)
        # toks = tokenizer(f"<|User|>{example['question']}<|Assistant|>{example['thinking']}").input_ids
        # end = len(toks)
        # toks = tokenizer(f"<|User|>{example['question']}<|Assistant|>{example['thinking']}", return_tensors="pt")
        with torch.no_grad():
            if args.control == "attn":
                _ = model(input_ids=toks['input_ids'].to(device), attention_mask=toks['attention_mask'].to(device))
                # (num_layers, 1, seq, hidden) -> mean over the thinking span -> (num_layers, 1, hidden)
                embeddings.append(torch.stack(attn_outputs, dim=0)[:, :, start-1:end-1, :].mean(dim=2).cpu())
                attn_outputs = []
            elif args.control == "mlp":
                residual_outputs = model(input_ids=toks['input_ids'].to(device), attention_mask=toks['attention_mask'].to(device), output_hidden_states=True).hidden_states[1:]
                embeddings.append(torch.stack(residual_outputs, dim=0)[:, :, start-1:end-1, :].mean(dim=2).cpu())
    # remove the hooks so a second call doesn't leave stale hooks appending to a dead list
    for handle in hook_handles:
        handle.remove()
    # average over examples: (num_layers, 1, hidden)
    return torch.stack(embeddings, dim=0).mean(dim=0)

# Load JSON file with response data
json_file_path = f"responses/{args.model}_gsm8k.json"
with open(json_file_path, 'r') as f:
    responses_data = json.load(f)
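# Each record is expected to provide 'question', 'thinking', and 'thinking_length'
# fields; thinking_length == -1 presumably marks responses where no thinking
# segment could be extracted.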

# Filter examples based on thinking length: keep only the top and bottom deciles
valid_responses = [ex for ex in responses_data if ex['thinking_length'] != -1]
valid_lengths = [ex['thinking_length'] for ex in valid_responses]
tenth_percentile = np.percentile(valid_lengths, 10)
ninetieth_percentile = np.percentile(valid_lengths, 90)
long_thinking_examples = [ex for ex in valid_responses if ex['thinking_length'] > ninetieth_percentile]
short_thinking_examples = [ex for ex in valid_responses if ex['thinking_length'] < tenth_percentile]

# -- long examples --
print("number of long examples:", len(long_thinking_examples))
mean_embedding_long = extract_tl_dir(long_thinking_examples)

# -- short examples --
print("number of short examples:", len(short_thinking_examples))
mean_embedding_short = extract_tl_dir(short_thinking_examples)

# -- save the thinking-length direction --
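# Difference of means between the two contrast sets: one vector per layer,
# pointing from "short thinking" activations toward "long thinking" activations.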
thinking_length_direction = mean_embedding_long - mean_embedding_short
os.makedirs("directions", exist_ok=True)
torch.save(thinking_length_direction, f"directions/{args.model}_thinking_length_direction_gsm8k_{args.control}.pt")
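
# -- usage (illustrative) --
# Run, e.g.:
#   python extract_tl.py --model deepseek_llama3 --control attn
# The saved tensor has shape (num_layers, 1, hidden_dim): one direction per layer.
#
# A minimal steering sketch (hypothetical; not this repo's steering code): adding a
# scaled copy of the long-minus-short direction to each layer's output should, in
# principle, push generations toward longer thinking (negative alpha toward shorter).
#
#   direction = torch.load(f"directions/{args.model}_thinking_length_direction_gsm8k_{args.control}.pt")
#   def make_steering_hook(layer_idx, alpha=1.0):
#       def hook_fn(module, inputs, output):
#           # Llama-style decoder layers return a tuple; steer the hidden states only
#           hidden = output[0] + alpha * direction[layer_idx].to(output[0].device, output[0].dtype)
#           return (hidden,) + output[1:]
#       return hook_fn
#   for i, layer in enumerate(model.model.layers):
#       layer.register_forward_hook(make_steering_hook(i))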