Skip to content

Commit aa12cf2

Browse files
committed
Code cleaning
1 parent 0ad4424 commit aa12cf2

File tree

4 files changed

+9
-102
lines changed

4 files changed

+9
-102
lines changed

delphi/__main__.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from delphi.latents.neighbours import NeighbourCalculator
3131
from delphi.log.result_analysis import log_results
3232
from delphi.pipeline import Pipe, Pipeline, process_wrapper
33-
from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator, InterventionScorer, LogProbInterventionScorer, SurprisalInterventionScorer
33+
from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator, SurprisalInterventionScorer
3434
from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders
3535
from delphi.utils import assert_type, load_tokenized_data
3636

@@ -252,8 +252,6 @@ def scorer_postprocess(result, score_dir, scorer_name=None):
252252
safe_latent_name = str(result.record.latent).replace("/", "--")
253253

254254
with open(score_dir / f"{safe_latent_name}.txt", "wb") as f:
255-
# This line now works universally. For other scorers, it saves their simple
256-
# score. For surprisal_intervention, it saves the rich 'final_payload'.
257255
f.write(orjson.dumps(result.score, default=custom_serializer))
258256

259257

@@ -278,20 +276,7 @@ def scorer_postprocess(result, score_dir, scorer_name=None):
278276
verbose=run_cfg.verbose,
279277
log_prob=run_cfg.log_probs,
280278
)
281-
elif scorer_name == "intervention":
282-
scorer = InterventionScorer(
283-
llm_client,
284-
n_examples_shown=run_cfg.num_examples_per_scorer_prompt,
285-
verbose=run_cfg.verbose,
286-
log_prob=run_cfg.log_probs,
287-
)
288-
elif scorer_name == "logprob_intervention":
289-
scorer = LogProbInterventionScorer(
290-
llm_client,
291-
n_examples_shown=run_cfg.num_examples_per_scorer_prompt,
292-
verbose=run_cfg.verbose,
293-
log_prob=run_cfg.log_probs,
294-
)
279+
295280
elif scorer_name == "surprisal_intervention":
296281
scorer = SurprisalInterventionScorer(
297282
model,

delphi/config.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,6 @@ class RunConfig(Serializable):
152152
"fuzz",
153153
"detection",
154154
"simulation",
155-
"intervention",
156-
"logprob_intervention",
157155
"surprisal_intervention"
158156
],
159157
default=[
@@ -162,7 +160,7 @@ class RunConfig(Serializable):
162160
],
163161
)
164162
"""Scorer methods to score latent explanations. Options are 'fuzz', 'detection',
165-
'simulation' and 'intervention'."""
163+
'simulation' and 'surprisal_intervention'."""
166164

167165
name: str = ""
168166
"""The name of the run. Results are saved in a directory with this name."""

delphi/log/result_analysis.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ def parse_score_file(path: Path) -> pd.DataFrame:
150150

151151
latent_idx = int(path.stem.split("latent")[-1])
152152

153-
# --- MODIFICATION 1: PARSE THE NEW METRICS ---
154153
# Updated to extract all possible keys safely using .get()
155154
return pd.DataFrame(
156155
[
@@ -254,11 +253,9 @@ def log_results(
254253
dead = sum((counts[m] == 0).sum().item() for m in modules)
255254
print(f"Number of dead features: {dead}")
256255

257-
# --- MODIFICATION 2: ADD CONDITIONAL REPORTING ---
258-
# Loop through all scorer types found in the data
256+
259257
for score_type in latent_df["score_type"].unique():
260258

261-
# Handle the new scorer with its specific metrics
262259
if score_type == 'surprisal_intervention':
263260
# Drop duplicates since score is per-latent, not per-example
264261
unique_latents = surprisal_df.drop_duplicates(subset=['module', 'latent_idx'])
@@ -269,7 +266,6 @@ def log_results(
269266
print(f"Average Normalized Score: {avg_score:.3f}")
270267
print(f"Average KL Divergence: {avg_kl:.3f}")
271268

272-
# Handle all other scorers with the original classification metrics
273269
else:
274270
if not classification_df.empty:
275271
score_type_summary = processed_df[processed_df.score_type == score_type].iloc[0]

delphi/scorers/intervention/surprisal_intervention_scorer.py

Lines changed: 5 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# surprisal_intervention_scorer.py
21
import functools
32
import random
43
import copy
@@ -9,8 +8,6 @@
98
import torch.nn.functional as F
109
from transformers import AutoTokenizer
1110

12-
# Assuming 'delphi' is your project structure.
13-
# If not, you may need to adjust these relative imports.
1411
from ..scorer import Scorer, ScorerResult
1512
from ...latents import LatentRecord, ActivatingExample
1613

@@ -75,11 +72,9 @@ def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs):
7572
if len(self.hookpoints):
7673
self.hookpoint_str = self.hookpoints[0]
7774

78-
# Ensure tokenizer is available
7975
if hasattr(subject_model, "tokenizer"):
8076
self.tokenizer = subject_model.tokenizer
8177
else:
82-
# Fallback to a standard tokenizer if not attached to the model
8378
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
8479

8580
if self.tokenizer.pad_token is None:
@@ -113,7 +108,6 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any:
113108
"""
114109
parts = hookpoint_str.split('.')
115110

116-
# 1. Validate the string format.
117111
is_valid_format = (
118112
len(parts) == 3 and
119113
parts[0] in ['layers', 'h'] and
@@ -122,137 +116,75 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any:
122116
)
123117

124118
if not is_valid_format:
125-
# Fallback for simple block types at the top level, e.g. 'embed_in'
126119
if len(parts) == 1 and hasattr(model, hookpoint_str):
127120
return getattr(model, hookpoint_str)
128121
raise ValueError(f"Hookpoint string '{hookpoint_str}' is not in a recognized format like 'layers.6.mlp'.")
129-
# --- End of changes ---
130122

131-
# 2. Heuristically find the model prefix.
123+
#Heuristically find the model prefix.
132124
prefix = None
133125
for p in ["gpt_neox", "transformer", "model"]:
134126
if hasattr(model, p):
135127
candidate_body = getattr(model, p)
136-
# Use parts[0] to get the layer block name ('layers' or 'h')
137128
if hasattr(candidate_body, parts[0]):
138129
prefix = p
139130
break
140131

141132
full_path = f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str
142133

143-
# 3. Use the simple path finder to get the module.
144134
try:
145135
return self._find_layer(model, full_path)
146136
except AttributeError as e:
147137
raise AttributeError(f"Could not resolve path '{full_path}'. Model structure might be unexpected. Original error: {e}")
148138

149-
150-
151-
152-
# def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]:
153-
# """Ensures examples are in a consistent format: a list of dictionaries with 'str_tokens'."""
154-
# sanitized = []
155-
# for ex in examples:
156-
# if isinstance(ex, dict) and "str_tokens" in ex:
157-
# sanitized.append(ex)
158-
# elif hasattr(ex, "str_tokens"):
159-
# sanitized.append({"str_tokens": [str(t) for t in ex.str_tokens]})
160-
# elif isinstance(ex, str):
161-
# sanitized.append({"str_tokens": [ex]})
162-
# elif isinstance(ex, (list, tuple)):
163-
# sanitized.append({"str_tokens": [str(t) for t in ex]})
164-
# else:
165-
# sanitized.append({"str_tokens": [str(ex)]})
166-
# return sanitized
167-
168139

169140
def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]:
141+
"""
142+
Function used for formatting results to run smoothly in the delphi pipeline
143+
"""
170144
sanitized = []
171145
for ex in examples:
172-
# --- NEW, MORE ROBUST LOGIC ---
173-
# 1. Prioritize handling objects that have the data we need (like ActivatingExample)
174146
if hasattr(ex, 'str_tokens') and ex.str_tokens is not None:
175-
# This correctly handles ActivatingExample objects and similar structures.
176-
# It extracts the string tokens instead of converting the whole object to a string.
177147
sanitized.append({'str_tokens': ex.str_tokens})
178148

179-
# 2. Handle cases where the item is already a correct dictionary
180149
elif isinstance(ex, dict) and "str_tokens" in ex:
181150
sanitized.append(ex)
182151

183-
# 3. Handle plain strings
184152
elif isinstance(ex, str):
185153
sanitized.append({"str_tokens": [ex]})
186154

187-
# 4. Handle lists/tuples of strings as a fallback
188155
elif isinstance(ex, (list, tuple)):
189156
sanitized.append({"str_tokens": [str(t) for t in ex]})
190157

191-
# 5. Handle any other unexpected type as a last resort
192158
else:
193159
sanitized.append({"str_tokens": [str(ex)]})
194160

195161
return sanitized
196162

197163

198-
# def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]:
199-
200-
# sanitized = []
201-
# for i, ex in enumerate(examples):
202-
203-
204-
# if isinstance(ex, dict) and "str_tokens" in ex:
205-
# sanitized.append(ex)
206-
207-
208-
# elif isinstance(ex, str):
209-
# # This is the key conversion
210-
# converted_ex = {"str_tokens": [ex]}
211-
# sanitized.append(converted_ex)
212-
213-
214-
# elif isinstance(ex, (list, tuple)):
215-
# converted_ex = {"str_tokens": [str(t) for t in ex]}
216-
# sanitized.append(converted_ex)
217-
218-
# else:
219-
# converted_ex = {"str_tokens": [str(ex)]}
220-
# sanitized.append(converted_ex)
221-
222-
# print("fin this")
223-
# return sanitized
224164

225165
async def __call__(self, record: LatentRecord) -> ScorerResult:
226-
# --- MODIFICATION START ---
227-
# 1. Create a deep copy to work on, ensuring we don't interfere
228-
# with other parts of the pipeline that might use the original record.
166+
229167
record_copy = copy.deepcopy(record)
230168

231-
# 2. Read the raw examples from our copy.
232169
raw_examples = getattr(record_copy, "test", []) or []
233170

234171
if not raw_examples:
235172
result = SurprisalInterventionResult(score=0.0, avg_kl=0.0, explanation=record_copy.explanation)
236-
# Return the result with the original record since no changes were made.
237173
return ScorerResult(record=record, score=result)
238174

239-
# 3. Sanitize the examples.
240175
examples = self._sanitize_examples(raw_examples)
241176

242-
# 4. Overwrite the attributes on the copy with the clean data.
243177
record_copy.test = examples
244178
record_copy.examples = examples
245179
record_copy.train = examples
246180

247-
# Now, use the sanitized 'examples' and the 'record_copy' for all subsequent operations.
248181
prompts = ["".join(ex["str_tokens"]) for ex in examples[:self.num_prompts]]
249182

250183
total_diff = 0.0
251184
total_kl = 0.0
252185
n = 0
253186

254187
for prompt in prompts:
255-
# Pass the clean record_copy to the generation methods.
256188
clean_text, clean_logp_dist = await self._generate_with_and_without_intervention(prompt, record_copy, intervene=False)
257189
int_text, int_logp_dist = await self._generate_with_and_without_intervention(prompt, record_copy, intervene=True)
258190

@@ -274,7 +206,6 @@ async def __call__(self, record: LatentRecord) -> ScorerResult:
274206
for ex in examples[:self.num_prompts]:
275207
final_output_list.append({
276208
"str_tokens": ex["str_tokens"],
277-
# Add the final scores. These will be duplicated for each example.
278209
"final_score": final_score,
279210
"avg_kl_divergence": avg_kl,
280211
# Add placeholder keys that the parser expects, with default values.
@@ -312,14 +243,12 @@ async def _generate_with_and_without_intervention(
312243
if hookpoint_str is None:
313244
raise ValueError("No hookpoint string specified for intervention.")
314245

315-
# Resolve the string into the actual layer module.
316246
layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str)
317247

318248
direction = self._get_intervention_direction(record).to(device)
319249
direction = direction.unsqueeze(0).unsqueeze(0) # Shape for broadcasting: [1, 1, D]
320250

321251
def hook_fn(module, inp, out):
322-
# Gracefully handle both tuple and tensor outputs
323252
hidden_states = out[0] if isinstance(out, tuple) else out
324253

325254
# Apply intervention to the last token's hidden state
@@ -423,7 +352,6 @@ def _estimate_direction_from_examples(self, record: LatentRecord) -> torch.Tenso
423352
def capture_hook(module, inp, out):
424353
hidden_states = out[0] if isinstance(out, tuple) else out
425354

426-
# Now, hidden_states is guaranteed to be the 3D activation tensor
427355
captured_activations.append(hidden_states[:, -1, :].detach().cpu())
428356

429357
hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None)

0 commit comments

Comments
 (0)