1- # surprisal_intervention_scorer.py
21import functools
32import random
43import copy
98import torch .nn .functional as F
109from transformers import AutoTokenizer
1110
12- # Assuming 'delphi' is your project structure.
13- # If not, you may need to adjust these relative imports.
1411from ..scorer import Scorer , ScorerResult
1512from ...latents import LatentRecord , ActivatingExample
1613
@@ -75,11 +72,9 @@ def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs):
7572 if len (self .hookpoints ):
7673 self .hookpoint_str = self .hookpoints [0 ]
7774
78- # Ensure tokenizer is available
7975 if hasattr (subject_model , "tokenizer" ):
8076 self .tokenizer = subject_model .tokenizer
8177 else :
82- # Fallback to a standard tokenizer if not attached to the model
8378 self .tokenizer = AutoTokenizer .from_pretrained ("gpt2" )
8479
8580 if self .tokenizer .pad_token is None :
@@ -113,7 +108,6 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any:
113108 """
114109 parts = hookpoint_str .split ('.' )
115110
116- # 1. Validate the string format.
117111 is_valid_format = (
118112 len (parts ) == 3 and
119113 parts [0 ] in ['layers' , 'h' ] and
@@ -122,137 +116,75 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any:
122116 )
123117
124118 if not is_valid_format :
125- # Fallback for simple block types at the top level, e.g. 'embed_in'
126119 if len (parts ) == 1 and hasattr (model , hookpoint_str ):
127120 return getattr (model , hookpoint_str )
128121 raise ValueError (f"Hookpoint string '{ hookpoint_str } ' is not in a recognized format like 'layers.6.mlp'." )
129- # --- End of changes ---
130122
131- # 2. Heuristically find the model prefix.
        # Heuristically find the model prefix.
132124 prefix = None
133125 for p in ["gpt_neox" , "transformer" , "model" ]:
134126 if hasattr (model , p ):
135127 candidate_body = getattr (model , p )
136- # Use parts[0] to get the layer block name ('layers' or 'h')
137128 if hasattr (candidate_body , parts [0 ]):
138129 prefix = p
139130 break
140131
141132 full_path = f"{ prefix } .{ hookpoint_str } " if prefix else hookpoint_str
142133
143- # 3. Use the simple path finder to get the module.
144134 try :
145135 return self ._find_layer (model , full_path )
146136 except AttributeError as e :
147137 raise AttributeError (f"Could not resolve path '{ full_path } '. Model structure might be unexpected. Original error: { e } " )
148138
149-
150-
151-
152- # def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]:
153- # """Ensures examples are in a consistent format: a list of dictionaries with 'str_tokens'."""
154- # sanitized = []
155- # for ex in examples:
156- # if isinstance(ex, dict) and "str_tokens" in ex:
157- # sanitized.append(ex)
158- # elif hasattr(ex, "str_tokens"):
159- # sanitized.append({"str_tokens": [str(t) for t in ex.str_tokens]})
160- # elif isinstance(ex, str):
161- # sanitized.append({"str_tokens": [ex]})
162- # elif isinstance(ex, (list, tuple)):
163- # sanitized.append({"str_tokens": [str(t) for t in ex]})
164- # else:
165- # sanitized.append({"str_tokens": [str(ex)]})
166- # return sanitized
167-
168139
169140 def _sanitize_examples (self , examples : List [Any ]) -> List [Dict [str , Any ]]:
141+ """
142+ Function used for formatting results to run smoothly in the delphi pipeline
143+ """
170144 sanitized = []
171145 for ex in examples :
172- # --- NEW, MORE ROBUST LOGIC ---
173- # 1. Prioritize handling objects that have the data we need (like ActivatingExample)
174146 if hasattr (ex , 'str_tokens' ) and ex .str_tokens is not None :
175- # This correctly handles ActivatingExample objects and similar structures.
176- # It extracts the string tokens instead of converting the whole object to a string.
177147 sanitized .append ({'str_tokens' : ex .str_tokens })
178148
179- # 2. Handle cases where the item is already a correct dictionary
180149 elif isinstance (ex , dict ) and "str_tokens" in ex :
181150 sanitized .append (ex )
182151
183- # 3. Handle plain strings
184152 elif isinstance (ex , str ):
185153 sanitized .append ({"str_tokens" : [ex ]})
186154
187- # 4. Handle lists/tuples of strings as a fallback
188155 elif isinstance (ex , (list , tuple )):
189156 sanitized .append ({"str_tokens" : [str (t ) for t in ex ]})
190157
191- # 5. Handle any other unexpected type as a last resort
192158 else :
193159 sanitized .append ({"str_tokens" : [str (ex )]})
194160
195161 return sanitized
196162
197163
198- # def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]:
199-
200- # sanitized = []
201- # for i, ex in enumerate(examples):
202-
203-
204- # if isinstance(ex, dict) and "str_tokens" in ex:
205- # sanitized.append(ex)
206-
207-
208- # elif isinstance(ex, str):
209- # # This is the key conversion
210- # converted_ex = {"str_tokens": [ex]}
211- # sanitized.append(converted_ex)
212-
213-
214- # elif isinstance(ex, (list, tuple)):
215- # converted_ex = {"str_tokens": [str(t) for t in ex]}
216- # sanitized.append(converted_ex)
217-
218- # else:
219- # converted_ex = {"str_tokens": [str(ex)]}
220- # sanitized.append(converted_ex)
221-
222- # print("fin this")
223- # return sanitized
224164
225165 async def __call__ (self , record : LatentRecord ) -> ScorerResult :
226- # --- MODIFICATION START ---
227- # 1. Create a deep copy to work on, ensuring we don't interfere
228- # with other parts of the pipeline that might use the original record.
166+
229167 record_copy = copy .deepcopy (record )
230168
231- # 2. Read the raw examples from our copy.
232169 raw_examples = getattr (record_copy , "test" , []) or []
233170
234171 if not raw_examples :
235172 result = SurprisalInterventionResult (score = 0.0 , avg_kl = 0.0 , explanation = record_copy .explanation )
236- # Return the result with the original record since no changes were made.
237173 return ScorerResult (record = record , score = result )
238174
239- # 3. Sanitize the examples.
240175 examples = self ._sanitize_examples (raw_examples )
241176
242- # 4. Overwrite the attributes on the copy with the clean data.
243177 record_copy .test = examples
244178 record_copy .examples = examples
245179 record_copy .train = examples
246180
247- # Now, use the sanitized 'examples' and the 'record_copy' for all subsequent operations.
248181 prompts = ["" .join (ex ["str_tokens" ]) for ex in examples [:self .num_prompts ]]
249182
250183 total_diff = 0.0
251184 total_kl = 0.0
252185 n = 0
253186
254187 for prompt in prompts :
255- # Pass the clean record_copy to the generation methods.
256188 clean_text , clean_logp_dist = await self ._generate_with_and_without_intervention (prompt , record_copy , intervene = False )
257189 int_text , int_logp_dist = await self ._generate_with_and_without_intervention (prompt , record_copy , intervene = True )
258190
@@ -274,7 +206,6 @@ async def __call__(self, record: LatentRecord) -> ScorerResult:
274206 for ex in examples [:self .num_prompts ]:
275207 final_output_list .append ({
276208 "str_tokens" : ex ["str_tokens" ],
277- # Add the final scores. These will be duplicated for each example.
278209 "final_score" : final_score ,
279210 "avg_kl_divergence" : avg_kl ,
280211 # Add placeholder keys that the parser expects, with default values.
@@ -312,14 +243,12 @@ async def _generate_with_and_without_intervention(
312243 if hookpoint_str is None :
313244 raise ValueError ("No hookpoint string specified for intervention." )
314245
315- # Resolve the string into the actual layer module.
316246 layer_to_hook = self ._resolve_hookpoint (self .subject_model , hookpoint_str )
317247
318248 direction = self ._get_intervention_direction (record ).to (device )
319249 direction = direction .unsqueeze (0 ).unsqueeze (0 ) # Shape for broadcasting: [1, 1, D]
320250
321251 def hook_fn (module , inp , out ):
322- # Gracefully handle both tuple and tensor outputs
323252 hidden_states = out [0 ] if isinstance (out , tuple ) else out
324253
325254 # Apply intervention to the last token's hidden state
@@ -423,7 +352,6 @@ def _estimate_direction_from_examples(self, record: LatentRecord) -> torch.Tenso
423352 def capture_hook (module , inp , out ):
424353 hidden_states = out [0 ] if isinstance (out , tuple ) else out
425354
426- # Now, hidden_states is guaranteed to be the 3D activation tensor
427355 captured_activations .append (hidden_states [:, - 1 , :].detach ().cpu ())
428356
429357 hookpoint_str = self .hookpoint_str or getattr (record , "hookpoint" , None )
0 commit comments