Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
cf7ddca
Surprisal intervention and config
saireddythfc Aug 28, 2025
0ad4424
Add metrics for surprisal_intervention
saireddythfc Aug 28, 2025
aa12cf2
Code cleaning
saireddythfc Aug 29, 2025
bbf915a
Fix results
saireddythfc Aug 29, 2025
90e8a34
Remove output-based intervention
saireddythfc Aug 29, 2025
9fca7db
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2025
d3d269e
Fix pre-commit
saireddythfc Aug 29, 2025
126b978
Fix pre-commit
saireddythfc Aug 29, 2025
9e878ca
Fix conflicts
saireddythfc Aug 29, 2025
8e893e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2025
2a546e0
Fix EOFs
saireddythfc Aug 29, 2025
73bb341
Fix EOFs
saireddythfc Aug 29, 2025
6a6368c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2025
88f1b35
Tuned Kl divergence
saireddythfc Sep 10, 2025
4ffc891
Tuned KL divergence
saireddythfc Sep 10, 2025
6db120f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2025
1a6fa0c
Pre-commit clears
saireddythfc Sep 11, 2025
572da19
Merge branch 'main' of https://github.com/saireddythfc/delphi
saireddythfc Sep 11, 2025
6e18bba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 11, 2025
ba6533b
Fix intervention point
saireddythfc Nov 17, 2025
5161d9a
Merge remote-tracking branch 'upstream/main'
saireddythfc Nov 17, 2025
68d6c63
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2025
101257e
Line-spacing fix
saireddythfc Nov 17, 2025
8c65dbe
Fix line-spacing issues
saireddythfc Nov 17, 2025
c6b901a
Merge branch 'main' of https://github.com/saireddythfc/delphi
saireddythfc Nov 17, 2025
cabb151
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions delphi/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from simple_parsing import ArgumentParser
from torch import Tensor
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
PreTrainedModel,
Expand All @@ -27,7 +27,12 @@
from delphi.latents.neighbours import NeighbourCalculator
from delphi.log.result_analysis import log_results
from delphi.pipeline import Pipe, Pipeline, process_wrapper
from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator
from delphi.scorers import (
DetectionScorer,
FuzzingScorer,
OpenAISimulator,
SurprisalInterventionScorer,
)
from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders
from delphi.utils import assert_type, load_tokenized_data

Expand All @@ -40,7 +45,7 @@ def load_artifacts(run_cfg: RunConfig):
else:
dtype = "auto"

model = AutoModel.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
run_cfg.model,
device_map={"": "cuda"},
quantization_config=(
Expand Down Expand Up @@ -118,6 +123,8 @@ async def process_cache(
hookpoints: list[str],
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
latent_range: Tensor | None,
model,
hookpoint_to_sparse_encode,
):
"""
Converts SAE latent activations in on-disk cache in the `latents_path` directory
Expand Down Expand Up @@ -219,6 +226,12 @@ def none_postprocessor(result):
)
)

def custom_serializer(obj):
"""A custom serializer for orjson to handle specific types."""
if isinstance(obj, Tensor):
return obj.tolist()
raise TypeError

# Builds the record from result returned by the pipeline
def scorer_preprocess(result):
if isinstance(result, list):
Expand All @@ -230,11 +243,18 @@ def scorer_preprocess(result):
return record

# Saves the score to a file
def scorer_postprocess(result, score_dir):
# In your __main__.py file

def scorer_postprocess(result, score_dir, scorer_name=None):
if isinstance(result, list):
if not result:
return
result = result[0]

safe_latent_name = str(result.record.latent).replace("/", "--")

with open(score_dir / f"{safe_latent_name}.txt", "wb") as f:
f.write(orjson.dumps(result.score))
f.write(orjson.dumps(result.score, default=custom_serializer))

scorers = []
for scorer_name in run_cfg.scorers:
Expand Down Expand Up @@ -265,6 +285,16 @@ def scorer_postprocess(result, score_dir):
verbose=run_cfg.verbose,
log_prob=run_cfg.log_probs,
)

elif scorer_name == "surprisal_intervention":
scorer = SurprisalInterventionScorer(
model,
hookpoint_to_sparse_encode,
hookpoints=run_cfg.hookpoints,
n_examples_shown=run_cfg.num_examples_per_scorer_prompt,
verbose=run_cfg.verbose,
log_prob=run_cfg.log_probs,
)
else:
raise ValueError(f"Scorer {scorer_name} not supported")

Expand Down Expand Up @@ -404,6 +434,8 @@ async def run(
hookpoints, hookpoint_to_sparse_encode, model, transcode = load_artifacts(run_cfg)
tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token)

model.tokenizer = tokenizer

nrh = assert_type(
dict,
non_redundant_hookpoints(
Expand All @@ -420,7 +452,6 @@ async def run(
transcode,
)

del model, hookpoint_to_sparse_encode
if run_cfg.constructor_cfg.non_activating_source == "neighbours":
nrh = assert_type(
list,
Expand Down Expand Up @@ -453,8 +484,12 @@ async def run(
nrh,
tokenizer,
latent_range,
model,
hookpoint_to_sparse_encode,
)

del model, hookpoint_to_sparse_encode

if run_cfg.verbose:
log_results(scores_path, visualize_path, run_cfg.hookpoints, run_cfg.scorers)

Expand Down
10 changes: 3 additions & 7 deletions delphi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,14 @@ class RunConfig(Serializable):
the default single token explainer, and 'none' for no explanation generation."""

scorers: list[str] = list_field(
choices=[
"fuzz",
"detection",
"simulation",
],
choices=["fuzz", "detection", "simulation", "surprisal_intervention"],
default=[
"fuzz",
"detection",
],
)
"""Scorer methods to score latent explanations. Options are 'fuzz', 'detection', and
'simulation'."""
"""Scorer methods to score latent explanations. Options are 'fuzz', 'detection',
'simulation' and 'surprisal_intervention'."""
fuzz_type: Literal["default", "active"] = "default"
"""Type of fuzzing to use for the fuzz scorer. Default uses non-activating
examples and highlights n_incorrect tokens. Active uses activating examples
Expand Down
7 changes: 7 additions & 0 deletions delphi/latents/latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ class LatentRecord:
"""Frequency of the latent. Number of activations in a context per total
number of contexts."""

@property
def feature_id(self) -> int:
"""
Returns the unique feature index for this latent.
"""
return self.latent.latent_index

@property
def max_activation(self) -> float:
"""
Expand Down
Loading