Description
Describe the bug
There is an exception when key or value head patching is applied to models whose number of attention heads (n_heads) differs from the number of key and value heads (n_key_value_heads), i.e. models using grouped-query attention. One example of such a model is meta-llama/Llama-3.2-1B. The affected methods are listed below; a short sketch of the head-count mismatch follows the list.
- patching.get_act_patch_attn_head_k_all_pos
- patching.get_act_patch_attn_head_v_all_pos
- patching.get_act_patch_attn_head_all_pos_every
- patching.get_act_patch_attn_head_k_by_pos
- patching.get_act_patch_attn_head_v_by_pos
- patching.get_act_patch_attn_head_by_pos_every
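A minimal sketch of the mismatch, assuming `model` and `clean_cache` are built exactly as in the reproduction code below: the cached key/value activations have only n_key_value_heads entries along the head dimension, while the patching helpers iterate head_index up to n_heads.

```python
# Minimal sketch of the head-count mismatch behind the IndexError reported below.
# Assumes `model` and `clean_cache` from the reproduction code.
from transformer_lens import utils

print(model.cfg.n_heads)            # 32 attention heads for Llama-3.2-1B
print(model.cfg.n_key_value_heads)  # 8 key/value heads (grouped-query attention)

k = clean_cache[utils.get_act_name("k", 0)]  # [batch, pos, head_index, d_head]
print(k.shape[2])                   # 8 -- but the patching loop indexes heads 0..n_heads-1
```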
Code example
Reproduction code
import torch
import transformer_lens.patching as patching
from transformer_lens import HookedTransformer
model_name = 'meta-llama/Llama-3.2-1B'
device = torch.device('cuda:0')
model = HookedTransformer.from_pretrained(model_name, device=device)
prompts = ['John']
corrupted_prompts = ['Mary']
answers = [(' John', ' Mary')]
answer_token_indices = torch.tensor([[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))], device=device)
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    if len(logits.shape) == 3:
        logits = logits[:, -1, :]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()
clean_tokens = model.to_tokens(prompts)
corrupted_tokens = model.to_tokens(corrupted_prompts)
_, clean_cache = model.run_with_cache(clean_tokens)
patching.get_act_patch_attn_head_k_all_pos(model, corrupted_tokens, clean_cache, get_logit_diff)
patching.get_act_patch_attn_head_v_all_pos(model, corrupted_tokens, clean_cache, get_logit_diff)
patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, get_logit_diff)
patching.get_act_patch_attn_head_k_by_pos(model, corrupted_tokens, clean_cache, get_logit_diff)
patching.get_act_patch_attn_head_v_by_pos(model, corrupted_tokens, clean_cache, get_logit_diff)
patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, get_logit_diff)
Resulting stack trace
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[51], line 2
1 # reproduction of the transformer_lens bug
----> 2 every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_k_all_pos(model, corrupted_tokens, clean_cache, diff_metric)
File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/patching.py:218, in generic_activation_patch(model, corrupted_tokens, clean_cache, patching_metric, patch_setter, activation_name, index_axis_names, index_df, return_index_df)
211 current_hook = partial(
212 patching_hook,
213 index=index,
214 clean_activation=clean_cache[current_activation_name],
215 )
217 # Run the model with the patching hook and get the logits!
--> 218 patched_logits = model.run_with_hooks(
219 corrupted_tokens, fwd_hooks=[(current_activation_name, current_hook)]
220 )
222 # Calculate the patching metric and store
223 if flattened_output:
File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/hook_points.py:456, in HookedRootModule.run_with_hooks(self, fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts, *model_args, **model_kwargs)
451 logging.warning(
452 "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur."
453 )
455 with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model:
--> 456 return hooked_model.forward(*model_args, **model_kwargs)
File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/HookedTransformer.py:612, in HookedTransformer.forward(self, input, return_type, loss_per_token, prepend_bos, padding_side, start_at_layer, tokens, shortformer_pos_embed, attention_mask, stop_at_layer, past_kv_cache)
607 if shortformer_pos_embed is not None:
608 shortformer_pos_embed = shortformer_pos_embed.to(
609 devices.get_device_for_block_index(i, self.cfg)
610 )
--> 612 residual = block(
613 residual,
614 # Cache contains a list of HookedTransformerKeyValueCache objects, one for each
615 # block
616 past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None,
617 shortformer_pos_embed=shortformer_pos_embed,
618 attention_mask=attention_mask,
619 ) # [batch, pos, d_model]
621 if stop_at_layer is not None:
622 # When we stop at an early layer, we end here rather than doing further computation
623 return residual
File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/components/transformer_block.py:160, in TransformerBlock.forward(self, resid_pre, shortformer_pos_embed, past_kv_cache_entry, attention_mask)
153 key_input = attn_in
154 value_input = attn_in
156 attn_out = (
157 # hook the residual stream states that are used to calculate the
158 # queries, keys and values, independently.
159 # Then take the layer norm of these inputs, and pass these to the attention module.
--> 160 self.attn(
161 query_input=self.ln1(query_input)
162 + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
163 key_input=self.ln1(key_input)
164 + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
165 value_input=self.ln1(value_input),
166 past_kv_cache_entry=past_kv_cache_entry,
167 attention_mask=attention_mask,
168 )
169 ) # [batch, pos, d_model]
170 if self.cfg.use_normalization_before_and_after:
171 # If we use LayerNorm both before and after, then apply the second LN after the layer
172 # and before the hook. We do it before the hook so hook_attn_out captures "that which
173 # is added to the residual stream"
174 attn_out = self.ln1_post(attn_out)
File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/components/abstract_attention.py:196, in AbstractAttention.forward(self, query_input, key_input, value_input, past_kv_cache_entry, additive_attention_mask, attention_mask, position_bias)
168 def forward(
169 self,
170 query_input: Union[
(...)
187 position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
188 ) -> Float[torch.Tensor, "batch pos d_model"]:
189 """
190 shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details
191 past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None
192 additive_attention_mask is an optional mask to add to the attention weights. Defaults to None.
193 attention_mask is the attention mask for padded tokens. Defaults to None.
194 """
--> 196 q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)
198 if past_kv_cache_entry is not None:
199 # Appends the new keys and values to the cached values, and automatically updates the cache
200 kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/components/grouped_query_attention.py:130, in GroupedQueryAttention.calculate_qkv_matrices(self, query_input, key_input, value_input)
120 attn_fn = (
121 complex_attn_linear
122 if self.cfg.use_split_qkv_input or self.cfg.use_attn_in
123 else simple_attn_linear
124 )
126 q = self.hook_q(
127 attn_fn(query_input, self.W_Q, self.b_Q)
128 ) # [batch, pos, head_index, d_head]
--> 130 k = self.hook_k(
131 attn_fn(key_input, self.W_K, self.b_K)
132 if self.cfg.ungroup_grouped_query_attention
133 else attn_fn(key_input, self._W_K, self._b_K)
134 ) # [batch, pos, head_index, d_head]
135 v = self.hook_v(
136 attn_fn(value_input, self.W_V, self.b_V)
137 if self.cfg.ungroup_grouped_query_attention
138 else attn_fn(value_input, self._W_V, self._b_V)
139 ) # [batch, pos, head_index, d_head]
140 return q, k, v
File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/torch/nn/modules/module.py:1574, in Module._call_impl(self, *args, **kwargs)
1572 hook_result = hook(self, args, kwargs, result)
1573 else:
-> 1574 hook_result = hook(self, args, result)
1576 if hook_result is not None:
1577 result = hook_result
File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/hook_points.py:109, in HookPoint.add_hook.<locals>.full_hook(module, module_input, module_output)
105 if (
106 dir == "bwd"
107 ): # For a backwards hook, module_output is a tuple of (grad,) - I don't know why.
108 module_output = module_output[0]
--> 109 return hook(module_output, hook=self)
File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/patching.py:201, in generic_activation_patch.<locals>.patching_hook(corrupted_activation, hook, index, clean_activation)
200 def patching_hook(corrupted_activation, hook, index, clean_activation):
--> 201 return patch_setter(corrupted_activation, index, clean_activation)
File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/transformer_lens/patching.py:276, in layer_head_vector_patch_setter(corrupted_activation, index, clean_activation)
274 assert len(index) == 2
275 layer, head_index = index
--> 276 corrupted_activation[:, :, head_index] = clean_activation[:, :, head_index]
278 return corrupted_activation
File ~/opt/anaconda3/envs/hallucination/lib/python3.9/site-packages/torch/utils/_device.py:77, in DeviceContext.__torch_function__(self, func, types, args, kwargs)
75 if func in _device_constructors() and kwargs.get('device') is None:
76 kwargs['device'] = self.device
---> 77 return func(*args, **kwargs)
IndexError: index 8 is out of bounds for dimension 2 with size 8
System Info
Describe the characteristics of your environment:
- transformer_lens was installed from pip
- I have reproduced it both on Mac and on Google Colab
- Python version 3.9
Additional context
The problem is that patching.generic_activation_patch iterates the head index up to n_heads even for key and value activations, which in grouped-query attention models only have n_key_value_heads entries along the head dimension (Llama-3.2-1B has 32 attention heads but only 8 key/value heads, hence the IndexError above). A second issue is in the methods that stack the results of patching the different attention head activations (patching.get_act_patch_attn_head_all_pos_every and patching.get_act_patch_attn_head_by_pos_every): the key and value head results end up with a different head dimension than the others, so they cannot be stacked.
A fix for this bug is proposed in pull request #981.
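Until that fix is merged, a possible workaround (an assumption on my part, not verified against all six affected methods) is to set TransformerLens's ungroup_grouped_query_attention flag, so that hook_k and hook_v expose n_heads entries along the head dimension, matching what the patching helpers index into:

```python
# Possible workaround (untested assumption): expand K/V so hook_k / hook_v carry one
# entry per attention head, keeping the patching head indices in bounds.
model.cfg.ungroup_grouped_query_attention = True

# Re-build the cache so clean_cache also stores the ungrouped key/value activations.
_, clean_cache = model.run_with_cache(clean_tokens)
patching.get_act_patch_attn_head_k_all_pos(model, corrupted_tokens, clean_cache, get_logit_diff)
```

Note that this stores a repeated key/value copy per attention head, so activation memory for k and v grows accordingly.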
Checklist
- I have checked that there is no similar issue in the repo (required)