diff --git a/tests/unit/components/test_attention.py b/tests/unit/components/test_attention.py
index 572529cec..38cb3e0d3 100644
--- a/tests/unit/components/test_attention.py
+++ b/tests/unit/components/test_attention.py
@@ -128,3 +128,45 @@ def test_remove_einsum_from_complex_attn_linear():
 
     # Check if the results are the same
     assert torch.allclose(result_new, result_old, atol=1e-4)
+
+
+@pytest.mark.skipif(
+    not torch.backends.mps.is_available(),
+    reason="Requires MPS; regression test for F.linear issue on MPS with PyTorch 2.8: "
+    "https://github.com/pytorch/pytorch/issues/161640",
+)
+def test_cpu_mps_outputs_match():
+    torch.manual_seed(0)
+
+    cfg = {
+        "n_layers": 1,
+        "d_model": 48,
+        "n_ctx": 256,
+        "d_head": 16,
+        "n_heads": 3,
+        "load_in_4bit": False,
+        "dtype": torch.float32,
+        "act_fn": "relu",
+    }
+
+    def init_weights(attn_layer: nn.Module):
+        nn.init.normal_(attn_layer.W_Q, mean=0.0, std=0.02)
+        nn.init.normal_(attn_layer.W_K, mean=0.0, std=0.02)
+        nn.init.normal_(attn_layer.W_V, mean=0.0, std=0.02)
+        nn.init.normal_(attn_layer.W_O, mean=0.0, std=0.02)
+        return attn_layer
+
+    attn_cpu = Attention(cfg)
+    attn_cpu = init_weights(attn_cpu)
+
+    attn_mps = Attention(cfg).to("mps")
+    attn_mps.load_state_dict(attn_cpu.state_dict(), strict=True)
+
+    batch = 1
+    input_cpu = torch.randn(batch, cfg["n_ctx"], cfg["d_model"])
+    input_mps = input_cpu.to("mps")
+
+    cpu_output = attn_cpu(input_cpu, input_cpu, input_cpu)
+    mps_output = attn_mps(input_mps, input_mps, input_mps)
+
+    assert torch.allclose(cpu_output, mps_output.cpu())
diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index 0aee43814..1f227500b 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -299,10 +299,13 @@ def forward(
                 if self.b_O.device != z.device:
                     z = z.to(self.b_O.device)
 
-                out = F.linear(
-                    z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
-                    w,
-                    self.b_O,
+                z = einops.rearrange(
+                    z, "batch pos head_index d_head -> batch pos (head_index d_head)"
+                )
+
+                out = (
+                    einops.einsum(z, w, "batch pos d_heads, d_model d_heads -> batch pos d_model")
+                    + self.b_O
                 )
         else:
             # Explicitly calculate the attention result so it can be accessed by a hook