
[Bug Report] Error when patching key or value heads #980

@nikolaystanishev

Description


Describe the bug
An exception is raised when key or value heads are patched in models whose number of attention heads (n_heads) differs from the number of key and value heads (n_key_value_heads), i.e. models that use grouped-query attention. One such model is meta-llama/Llama-3.2-1B. The affected methods are:

  • patching.get_act_patch_attn_head_k_all_pos
  • patching.get_act_patch_attn_head_v_all_pos
  • patching.get_act_patch_attn_head_all_pos_every
  • patching.get_act_patch_attn_head_k_by_pos
  • patching.get_act_patch_attn_head_v_by_pos
  • patching.get_act_patch_attn_head_by_pos_every
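With grouped-query attention, several query heads share one key/value head, so the key and value activations carry only n_key_value_heads entries along the head dimension, while queries carry n_heads. The minimal PyTorch sketch below uses random tensors to show the shapes involved. The sizes are assumed to match Llama-3.2-1B (32 query heads, 8 key/value heads), and the helper `kv_head_for_query_head` is hypothetical, assuming contiguous groups of query heads per key/value head.

```python
import torch

# Assumed sizes, matching Llama-3.2-1B: 32 query heads, 8 key/value heads.
n_heads, n_key_value_heads, d_head = 32, 8, 64
batch, pos = 1, 3
group_size = n_heads // n_key_value_heads  # query heads sharing each key/value head

# Grouped form: keys only have n_key_value_heads entries along dim 2.
k_grouped = torch.randn(batch, pos, n_key_value_heads, d_head)

# "Ungrouped" form: each key/value head repeated once per query head in its group.
k_ungrouped = torch.repeat_interleave(k_grouped, repeats=group_size, dim=2)

def kv_head_for_query_head(query_head: int) -> int:
    """Hypothetical helper: key/value head used by a given query head."""
    return query_head // group_size

print(k_grouped.shape)    # torch.Size([1, 3, 8, 64])
print(k_ungrouped.shape)  # torch.Size([1, 3, 32, 64])
# Query head 13 reads the same keys as key/value head 13 // 4 == 3.
assert torch.equal(k_ungrouped[:, :, 13], k_grouped[:, :, kv_head_for_query_head(13)])
```

Any head index of 8 or higher is therefore valid for queries but out of range for the grouped keys and values.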

Code example
Reproduction code

import torch
import transformer_lens.patching as patching
from transformer_lens import HookedTransformer

model_name = 'meta-llama/Llama-3.2-1B'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # fall back to CPU, e.g. on a Mac
model = HookedTransformer.from_pretrained(model_name, device=device)

prompts = ['John']
corrupted_prompts = ['Mary']
answers = [(' John', ' Mary')]

answer_token_indices = torch.tensor(
    [[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))],
    device=device,
)

def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    if len(logits.shape) == 3:
        logits = logits[:, -1, :]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()

clean_tokens = model.to_tokens(prompts)
corrupted_tokens = model.to_tokens(corrupted_prompts)
_, clean_cache = model.run_with_cache(clean_tokens)

# Each of the following calls fails with the IndexError shown below; run them one at a time.
patching.get_act_patch_attn_head_k_all_pos(model, corrupted_tokens, clean_cache, get_logit_diff)
patching.get_act_patch_attn_head_v_all_pos(model, corrupted_tokens, clean_cache, get_logit_diff)
patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, get_logit_diff)
patching.get_act_patch_attn_head_k_by_pos(model, corrupted_tokens, clean_cache, get_logit_diff)
patching.get_act_patch_attn_head_v_by_pos(model, corrupted_tokens, clean_cache, get_logit_diff)
patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, get_logit_diff)

Resulting stack trace

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[51], line 2
      1 # reproduction of the transformer_lens bug
----> 2 every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_k_all_pos(model, corrupted_tokens, clean_cache, diff_metric)

File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/patching.py:218, in generic_activation_patch(model, corrupted_tokens, clean_cache, patching_metric, patch_setter, activation_name, index_axis_names, index_df, return_index_df)
    211 current_hook = partial(
    212     patching_hook,
    213     index=index,
    214     clean_activation=clean_cache[current_activation_name],
    215 )
    217 # Run the model with the patching hook and get the logits!
--> 218 patched_logits = model.run_with_hooks(
    219     corrupted_tokens, fwd_hooks=[(current_activation_name, current_hook)]
    220 )
    222 # Calculate the patching metric and store
    223 if flattened_output:

File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/hook_points.py:456, in HookedRootModule.run_with_hooks(self, fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts, *model_args, **model_kwargs)
    451     logging.warning(
    452         "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur."
    453     )
    455 with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model:
--> 456     return hooked_model.forward(*model_args, **model_kwargs)

File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/HookedTransformer.py:612, in HookedTransformer.forward(self, input, return_type, loss_per_token, prepend_bos, padding_side, start_at_layer, tokens, shortformer_pos_embed, attention_mask, stop_at_layer, past_kv_cache)
    607     if shortformer_pos_embed is not None:
    608         shortformer_pos_embed = shortformer_pos_embed.to(
    609             devices.get_device_for_block_index(i, self.cfg)
    610         )
--> 612     residual = block(
    613         residual,
    614         # Cache contains a list of HookedTransformerKeyValueCache objects, one for each
    615         # block
    616         past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None,
    617         shortformer_pos_embed=shortformer_pos_embed,
    618         attention_mask=attention_mask,
    619     )  # [batch, pos, d_model]
    621 if stop_at_layer is not None:
    622     # When we stop at an early layer, we end here rather than doing further computation
    623     return residual

File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/components/transformer_block.py:160, in TransformerBlock.forward(self, resid_pre, shortformer_pos_embed, past_kv_cache_entry, attention_mask)
    153     key_input = attn_in
    154     value_input = attn_in
    156 attn_out = (
    157     # hook the residual stream states that are used to calculate the
    158     # queries, keys and values, independently.
    159     # Then take the layer norm of these inputs, and pass these to the attention module.
--> 160     self.attn(
    161         query_input=self.ln1(query_input)
    162         + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
    163         key_input=self.ln1(key_input)
    164         + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
    165         value_input=self.ln1(value_input),
    166         past_kv_cache_entry=past_kv_cache_entry,
    167         attention_mask=attention_mask,
    168     )
    169 )  # [batch, pos, d_model]
    170 if self.cfg.use_normalization_before_and_after:
    171     # If we use LayerNorm both before and after, then apply the second LN after the layer
    172     # and before the hook. We do it before the hook so hook_attn_out captures "that which
    173     # is added to the residual stream"
    174     attn_out = self.ln1_post(attn_out)

File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/components/abstract_attention.py:196, in AbstractAttention.forward(self, query_input, key_input, value_input, past_kv_cache_entry, additive_attention_mask, attention_mask, position_bias)
    168 def forward(
    169     self,
    170     query_input: Union[
   (...)
    187     position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
    188 ) -> Float[torch.Tensor, "batch pos d_model"]:
    189     """
    190     shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details
    191     past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None
    192     additive_attention_mask is an optional mask to add to the attention weights. Defaults to None.
    193     attention_mask is the attention mask for padded tokens. Defaults to None.
    194     """
--> 196     q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)
    198     if past_kv_cache_entry is not None:
    199         # Appends the new keys and values to the cached values, and automatically updates the cache
    200         kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)

File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/components/grouped_query_attention.py:130, in GroupedQueryAttention.calculate_qkv_matrices(self, query_input, key_input, value_input)
    120 attn_fn = (
    121     complex_attn_linear
    122     if self.cfg.use_split_qkv_input or self.cfg.use_attn_in
    123     else simple_attn_linear
    124 )
    126 q = self.hook_q(
    127     attn_fn(query_input, self.W_Q, self.b_Q)
    128 )  # [batch, pos, head_index, d_head]
--> 130 k = self.hook_k(
    131     attn_fn(key_input, self.W_K, self.b_K)
    132     if self.cfg.ungroup_grouped_query_attention
    133     else attn_fn(key_input, self._W_K, self._b_K)
    134 )  # [batch, pos, head_index, d_head]
    135 v = self.hook_v(
    136     attn_fn(value_input, self.W_V, self.b_V)
    137     if self.cfg.ungroup_grouped_query_attention
    138     else attn_fn(value_input, self._W_V, self._b_V)
    139 )  # [batch, pos, head_index, d_head]
    140 return q, k, v

File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/torch/nn/modules/module.py:1574, in Module._call_impl(self, *args, **kwargs)
   1572     hook_result = hook(self, args, kwargs, result)
   1573 else:
-> 1574     hook_result = hook(self, args, result)
   1576 if hook_result is not None:
   1577     result = hook_result

File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/hook_points.py:109, in HookPoint.add_hook.<locals>.full_hook(module, module_input, module_output)
    105 if (
    106     dir == "bwd"
    107 ):  # For a backwards hook, module_output is a tuple of (grad,) - I don't know why.
    108     module_output = module_output[0]
--> 109 return hook(module_output, hook=self)

File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/patching.py:201, in generic_activation_patch.<locals>.patching_hook(corrupted_activation, hook, index, clean_activation)
    200 def patching_hook(corrupted_activation, hook, index, clean_activation):
--> 201     return patch_setter(corrupted_activation, index, clean_activation)

File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/patching.py:276, in layer_head_vector_patch_setter(corrupted_activation, index, clean_activation)
    274 assert len(index) == 2
    275 layer, head_index = index
--> 276 corrupted_activation[:, :, head_index] = clean_activation[:, :, head_index]
    278 return corrupted_activation

File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/torch/utils/_device.py:77, in DeviceContext.__torch_function__(self, func, types, args, kwargs)
     75 if func in _device_constructors() and kwargs.get('device') is None:
     76     kwargs['device'] = self.device
---> 77 return func(*args, **kwargs)

IndexError: index 8 is out of bounds for dimension 2 with size 8
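The failure can be reproduced without downloading the model. The sketch below copies the body of `layer_head_vector_patch_setter` from the trace (patching.py:276) and applies it to random key activations shaped [batch, pos, n_key_value_heads, d_head], iterating head indices up to n_heads as the report describes. The sizes are assumed to match Llama-3.2-1B.

```python
import torch

# Assumed sizes, matching Llama-3.2-1B: 32 query heads but only 8 key/value heads.
n_heads, n_key_value_heads, d_head = 32, 8, 64
batch, pos = 1, 2

def layer_head_vector_patch_setter(corrupted_activation, index, clean_activation):
    # Same logic as the setter shown in the trace above.
    layer, head_index = index
    corrupted_activation[:, :, head_index] = clean_activation[:, :, head_index]
    return corrupted_activation

# hook_k / hook_v activations only have n_key_value_heads entries on dim 2.
clean_k = torch.randn(batch, pos, n_key_value_heads, d_head)
corrupted_k = torch.randn(batch, pos, n_key_value_heads, d_head)

failed_at = None
for head_index in range(n_heads):  # loop bound taken from n_heads, not n_key_value_heads
    try:
        layer_head_vector_patch_setter(corrupted_k, (0, head_index), clean_k)
    except IndexError as err:
        failed_at = head_index
        print(err)
        break

print(failed_at)  # 8
```

Heads 0 to 7 are patched successfully; the first index past the last key/value head raises, matching the trace.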

System Info
Describe the characteristics of your environment:

  • transformer_lens was installed from pip
  • I have reproduced it on both Mac and Google Colab
  • Python version 3.9

Additional context
The root cause is that patching.generic_activation_patch uses n_heads as the number of heads when patching key and value activations, instead of n_key_value_heads, so every head index from n_key_value_heads upward is out of bounds. A second issue affects patching.get_act_patch_attn_head_all_pos_every and patching.get_act_patch_attn_head_by_pos_every: they stack the results of patching the individual attention-head activations into a single tensor, but the key and value results have a different size along the head dimension (n_key_value_heads instead of n_heads), so they cannot be stacked with the others.
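Both issues can be illustrated with a small sketch. The helper `n_heads_for_activation` below is hypothetical (it is not part of transformer_lens and is not necessarily how #981 fixes the bug); it assumes the short activation names "z", "q", "k", "v" and "pattern", and uses a `SimpleNamespace` stand-in for the model config with assumed Llama-3.2-1B values. It picks n_key_value_heads for keys and values and n_heads otherwise, falling back to n_heads when n_key_value_heads is unset. It then shows that per-activation results with different head counts cannot be combined with `torch.stack`.

```python
from types import SimpleNamespace

import torch

def n_heads_for_activation(cfg, activation_name: str) -> int:
    """Hypothetical helper: size of the head axis for a given attention activation."""
    kv_heads = getattr(cfg, "n_key_value_heads", None)
    if activation_name in ("k", "v") and kv_heads is not None:
        return kv_heads
    return cfg.n_heads  # queries, outputs, patterns, and non-GQA models

# Stand-in config with assumed Llama-3.2-1B values.
cfg = SimpleNamespace(n_layers=16, n_heads=32, n_key_value_heads=8)

names = ("z", "q", "k", "v", "pattern")
shapes = {name: (cfg.n_layers, n_heads_for_activation(cfg, name)) for name in names}
print(shapes["q"], shapes["k"])  # (16, 32) (16, 8)

# Second issue: per-activation results of different widths cannot be stacked.
results = [torch.zeros(shape) for shape in shapes.values()]
try:
    torch.stack(results)
    stacked_ok = True
except RuntimeError:  # torch.stack requires equal-sized tensors
    stacked_ok = False
print(stacked_ok)  # False
```

Whatever the final fix, the "every" variants need some way to return results of differing widths, for example by not stacking them into one tensor.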

A fix for this bug is proposed in pull request #981.

Checklist

  • I have checked that there is no similar issue in the repo (required)
