Fix incorrect logit lens implementation in accumulated_resid

Fixes #[issue_number]

What was wrong

The accumulated_resid(apply_ln=True) method is documented for logit lens analysis, but its implementation was incorrect:

The bug:

  • Used a single cached normalization scale for all intermediate layers instead of recomputing statistics per layer
  • Applied the cached scale from ln_final.hook_scale (computed during the original forward pass) to all intermediate states
  • Did not apply the layer norm's affine parameters (weights/biases). This could be worked around with fold_ln=True, but that workaround was undocumented and still didn't address the core problem of using cached statistics

Why this matters:

The cached normalization scale acts like a temperature parameter in softmax. Since the residual stream norm grows exponentially through the network, early-layer activations divided by the final layer's large cached scale become artificially small. After unembedding and softmax, this creates near-uniform distributions—not because the model is uncertain, but because of mismatched normalization statistics.
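
A toy illustration of this temperature effect (made-up numbers, not real model values):

# Toy illustration: dividing by a normalization scale that is far too large
# flattens the softmax, exactly like raising the temperature.
import torch

unembedded_direction = torch.tensor([4.0, 1.0, 0.5, 0.2])  # hypothetical logits

fresh_scale = 1.0    # scale recomputed from the intermediate state itself
cached_scale = 50.0  # much larger scale cached from the final layer

print(torch.softmax(unembedded_direction / fresh_scale, dim=-1))   # peaked
print(torch.softmax(unembedded_direction / cached_scale, dim=-1))  # near-uniform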

The logic of using cached normalization scales comes from Direct Logit Attribution (DLA), where it's correct and reasonable: you want a faithful decomposition showing how each component contributes to the final output through the lens of the final layer's normalization frame. However, for logit lens analysis, this is not what we want. Logit lens asks "what would the model predict if we stopped processing at layer L?" which requires simulating what would happen if that intermediate state were decoded immediately—meaning we must recompute fresh normalization statistics.

Additionally, accumulated_resid doesn't attribute to specific components (individual heads or MLP layers) but to the complete accumulated residual stream up to layer L. The previous cached-scale approach answered a confusing hybrid question: "what is the contribution of all components up to layer L, viewed through layer L's (or the final layer's) normalization frame?" This is neither standard logit lens (which asks about intermediate predictions) nor standard DLA (which attributes to individual components).

The confusion was compounded by the layer parameter behavior: when layer=L < n_layers, it used layer L's input normalization (ln1 or ln2), not the final layer norm. This makes no sense for logit lens, where you always want to decode through the final layer norm. The return_labels parameter further hints at the function's true intention: to map intermediate states to vocabulary space for analysis across layers—which requires consistent use of the final layer norm.

Concretely:

  • Early layers appeared artificially uniform: Layer 0 showed entropy of ~15.6 bits (near-maximum for 50K vocab) when it should show ~0.5 bits (peaked on the input token)
  • Wrong statistics were used: The cached scale was computed on the final layer's activations during the forward pass, not on the intermediate states being normalized
  • Contradicts standard logit lens: The literature (nostalgebraist, Tuned Lens, Patchscopes) defines logit lens as recomputing normalization fresh to answer "what would the model predict if we stopped here?"
  • Doesn't match documentation: The method's docstring states it applies "the final layer norm" for logit lens, implying proper normalization

The model could encode identical concepts at intermediate and final layers—just at different scales—but the cached-scale approach would report completely different entropy and probabilities, even though the representations are functionally equivalent (all subsequent layers normalize their inputs anyway).
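
A minimal sketch of that scale-invariance point, using a plain PyTorch LayerNorm without affine parameters as a stand-in for the model's normalization:

# Layer norm maps a residual vector and any positively rescaled copy of it to the
# same output, so states that differ only in norm decode identically under fresh
# normalization.
import torch

d_model = 768
x = torch.randn(1, 1, d_model)
ln = torch.nn.LayerNorm(d_model, elementwise_affine=False)

assert torch.allclose(ln(x), ln(100.0 * x), atol=1e-4)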

Example (GPT-2 XL, "The capital of France is"):

  • Layer 0 before: " is" at 0.002% (near-uniform distribution)
  • Layer 0 after: " is" at 99.997% (correct: uncontextualized input embedding)
  • Note: Top-1 predictions remain unchanged (normalization preserves direction), but probability distributions and entropy measurements now reflect actual model behavior rather than normalization artifacts
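
A sketch of how these numbers can be reproduced with the fixed method, using the public TransformerLens API (the probability/entropy bookkeeping is added here for illustration and is not part of the change itself):

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-xl", fold_ln=True)
_, cache = model.run_with_cache("The capital of France is")

# [num_components, batch, d_model]: accumulated residual at the final position,
# each state normalized with freshly computed final-layer-norm statistics.
accum, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, apply_ln=True, return_labels=True
)

logits = accum @ model.W_U  # no-bias projection; fold_ln=True was used above
probs = torch.softmax(logits, dim=-1).squeeze(1)  # [num_components, d_vocab]
entropy = -(probs * probs.clamp_min(1e-12).log2()).sum(dim=-1)

is_token = model.to_single_token(" is")
for label, p, h in zip(labels, probs[:, is_token], entropy):
    print(f"{label:>12}: P(' is') = {p.item():.5f}, entropy = {h.item():.2f} bits")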

What changed

Code (activation_cache.py)

Changed one line in accumulated_resid():

# Before
if apply_ln:
    components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice, mlp_input=mlp_input)

# After  
if apply_ln:
    components = self.model.ln_final(components)

This now:

  • Recomputes normalization statistics (mean/variance or RMS) for each intermediate layer state
  • Applies the final layer norm's affine parameters correctly
  • Works via PyTorch broadcasting over the [num_layers, batch, pos, d_model] stack
  • Implements logit lens as defined in the literature
  • Always uses the final layer norm, regardless of the layer parameter value
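
A small sanity check of the broadcasting claim (a sketch; gpt2-small is used here only to keep it cheap):

# ln_final normalizes over the last (d_model) dimension only, so applying it to
# the whole [num_components, batch, pos, d_model] stack is equivalent to
# normalizing each accumulated state independently.
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2", fold_ln=True)
_, cache = model.run_with_cache("The capital of France is")

stack = cache.accumulated_resid(layer=-1, incl_mid=True)  # [num_components, batch, pos, d_model]
assert torch.allclose(
    model.ln_final(stack),
    torch.stack([model.ln_final(component) for component in stack]),
    atol=1e-5,
)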

Documentation (accumulated_resid docstring)

Fixed ambiguity about bias terms:

The docstring now explicitly documents two valid approaches for projecting to vocabulary space:

  1. With bias terms: normalized_resid @ model.W_U + model.b_U applies both W_U and b_U

    • Works correctly with both fold_ln=True and fold_ln=False
    • When fold_ln=False: layer norm bias applied via ln_final, then unembedding bias via b_U
    • When fold_ln=True: layer norm bias is folded into b_U, so adding b_U applies both biases together
  2. Without bias terms: normalized_resid @ model.W_U only

    • Should use fold_ln=True when loading model
    • With fold_ln=True: layer norm has no bias parameter (folded into b_U), and you skip b_U, so no bias terms are applied
    • With fold_ln=False: layer norm bias would still be applied via ln_final (usually undesired when intentionally excluding biases)

Previously, the docstring only mentioned "multiply by the unembedding matrix" without clarifying the bias term handling or its interaction with fold_ln. This is important because the layer norm bias and unembedding bias are applied at different points in the computation, and fold_ln determines whether they're kept separate or combined.
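
A sketch of the two projection routes, assuming a small GPT-2 model loaded as in the docstring example (illustrative, not the docstring verbatim):

from transformer_lens import HookedTransformer

prompt = "The capital of France is"

# Approach 2: no bias terms. Load with fold_ln=True so the layer norm has no
# separate bias (it is folded into b_U, which we then skip).
model = HookedTransformer.from_pretrained("gpt2", fold_ln=True)
_, cache = model.run_with_cache(prompt)
normalized_resid = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, apply_ln=True
)
logits_no_bias = normalized_resid @ model.W_U

# Approach 1: with bias terms. Works with either fold_ln setting; with
# fold_ln=True, b_U already contains the folded layer norm bias.
logits_with_bias = normalized_resid @ model.W_U + model.b_U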

Other documentation improvements:

  • Clarified that apply_ln recomputes normalization statistics to transform activations into the format expected by unembedding
  • Updated the code example to demonstrate the no-bias approach with fold_ln=True, with a commented alternative showing how to add bias terms, so both approaches appear clearly in the example code

Tests (test_activation_cache.py)

Updated test_accumulated_resid_with_apply_ln to verify correct behavior:

Before:

accumulated_residual = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1)
ref_scaled_residual_stack = cache.apply_ln_to_stack(
    accumulated_residual, layer=-1, pos_slice=-1
)
scaled_residual_stack = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, apply_ln=True
)
assert torch.isclose(ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7).all()

After:

accumulated_residual = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1)
ref_scaled_residual_stack = model.ln_final(accumulated_residual)
scaled_residual_stack = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, apply_ln=True
)
assert torch.isclose(ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7).all()

The key change: instead of comparing against apply_ln_to_stack (the buggy cached-scale behavior), we now compare against direct application of model.ln_final() (the correct fresh-normalization behavior). This ensures apply_ln=True produces the same result as manually applying the final layer norm.

Breaking change

This is a breaking change that fixes incorrect behavior.

Users calling accumulated_resid(apply_ln=True) will get different numerical outputs:

  • Probability distributions now reflect actual model beliefs, not normalization artifacts
  • Entropy and KL divergence measurements are now meaningful
  • Top-1 predictions and rankings remain unchanged (direction preserved)

The previous behavior was producing measurement artifacts. Users relying on those outputs were getting misleading interpretability results. Code will continue to run without errors, but the numerical results will be different (and correct).

Why this doesn't affect DLA

Direct Logit Attribution (DLA) decomposes outputs into individual component contributions (specific heads, MLP layers). For that, use decompose_resid() to get individual components, not accumulated_resid() which returns cumulative sums of all components up to each layer.

The cached normalization logic that was used before makes sense in DLA contexts where you want to see component contributions through a consistent normalization frame. But accumulated_resid is designed for logit lens analysis (mapping intermediate states to predictions), not DLA decomposition. This fix only affects the logit lens use case (apply_ln=True), not DLA decomposition logic elsewhere in the codebase.
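
For contrast, a sketch of the DLA-style workflow this PR leaves untouched (per-component contributions via decompose_resid, scaled with the cached statistics via apply_ln_to_stack):

# Per-component decomposition that keeps the cached-scale logic.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2", fold_ln=True)
_, cache = model.run_with_cache("The capital of France is")

# Per-component residual stream contributions at the final position.
per_component, labels = cache.decompose_resid(
    layer=-1, pos_slice=-1, return_labels=True
)
# Cached-scale normalization: attribution viewed in the final layer norm's frame.
scaled = cache.apply_ln_to_stack(per_component, layer=-1, pos_slice=-1)
logit_attribution = scaled @ model.W_U  # [num_components, batch, d_vocab]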

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Note: While this is technically a breaking change (numerical outputs differ), existing code will continue to run without errors. The change only affects the numerical values returned, which are now correct rather than artifacts. Users won't experience crashes or API changes, just different (and more accurate) results.

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes (the one existing test that directly exercised the previous behavior no longer passes as written; it has been updated to reflect the fix)
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility (the only test I rewrote is test_accumulated_resid_with_apply_ln)

Notes:

  • Updated one existing test (test_accumulated_resid_with_apply_ln) to verify correct behavior instead of buggy behavior
  • The interface itself (accumulated_resid signature and parameters) remains unchanged, so backward compatibility is maintained at the API level
  • Only the numerical outputs differ (which is the intended fix)
  • Notebook tests do not pass locally (but this appears to be a pre-existing environment issue unrelated to these changes)
