Skip to content

Commit 4239a26

Browse files
authored
finalized t5 adapter (#1095)
* finalized t5 adapter * tested t5 architecture * fixed type issues * resolved experts mapping issues * Revert "resolved experts mapping issues" This reverts commit 9fa5125. * ran format
1 parent 0165122 commit 4239a26

File tree

15 files changed

+795
-56
lines changed

15 files changed

+795
-56
lines changed
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
"""Acceptance test for T5 compatibility mode in TransformerBridge.
2+
3+
This test verifies that T5 can be loaded with TransformerBridge and that
4+
compatibility mode can be successfully enabled with proper hook registration.
5+
"""
6+
7+
import gc
8+
9+
import pytest
10+
import torch
11+
12+
from transformer_lens.model_bridge.bridge import TransformerBridge
13+
from transformer_lens.utilities.bridge_components import apply_fn_to_all_components
14+
15+
16+
class TestT5CompatibilityMode:
17+
"""Test T5 compatibility mode functionality."""
18+
19+
@pytest.fixture(autouse=True)
20+
def cleanup_after_test(self):
21+
"""Clean up memory after each test."""
22+
yield
23+
# Force garbage collection and clear CUDA cache
24+
if torch.cuda.is_available():
25+
torch.cuda.empty_cache()
26+
for _ in range(3):
27+
gc.collect()
28+
29+
@pytest.fixture
30+
def model_name(self):
31+
"""T5 model to test."""
32+
return "google-t5/t5-small"
33+
34+
@pytest.fixture
35+
def bridge_model(self, model_name):
36+
"""Load T5 model via TransformerBridge."""
37+
return TransformerBridge.boot_transformers(model_name, device="cpu")
38+
39+
def test_t5_loads_successfully(self, bridge_model, model_name):
40+
"""Test that T5 loads successfully via TransformerBridge."""
41+
assert bridge_model is not None
42+
assert bridge_model.cfg.model_name == model_name
43+
assert hasattr(bridge_model, "encoder_blocks")
44+
assert hasattr(bridge_model, "decoder_blocks")
45+
46+
def test_linear_bridge_submodules_exist(self, bridge_model):
47+
"""Test that AttentionBridge and MLPBridge have LinearBridge submodules.
48+
49+
This is critical for compatibility mode to work - without LinearBridge
50+
submodules, hook aliases like 'hook_q -> q.hook_out' will fail.
51+
"""
52+
# Check encoder attention
53+
encoder_attn = bridge_model.encoder_blocks[0].attn
54+
assert hasattr(encoder_attn, "q"), "Encoder attention missing q submodule"
55+
assert hasattr(encoder_attn, "k"), "Encoder attention missing k submodule"
56+
assert hasattr(encoder_attn, "v"), "Encoder attention missing v submodule"
57+
assert hasattr(encoder_attn, "o"), "Encoder attention missing o submodule"
58+
59+
# Verify they are LinearBridge instances, not raw Linear layers
60+
from transformer_lens.model_bridge.generalized_components.linear import (
61+
LinearBridge,
62+
)
63+
64+
assert isinstance(encoder_attn.q, LinearBridge), "q should be LinearBridge"
65+
assert isinstance(encoder_attn.k, LinearBridge), "k should be LinearBridge"
66+
assert isinstance(encoder_attn.v, LinearBridge), "v should be LinearBridge"
67+
assert isinstance(encoder_attn.o, LinearBridge), "o should be LinearBridge"
68+
69+
# Check decoder self-attention
70+
decoder_self_attn = bridge_model.decoder_blocks[0].self_attn
71+
assert hasattr(decoder_self_attn, "q"), "Decoder self-attn missing q submodule"
72+
assert isinstance(decoder_self_attn.q, LinearBridge), "q should be LinearBridge"
73+
74+
# Check decoder cross-attention
75+
decoder_cross_attn = bridge_model.decoder_blocks[0].cross_attn
76+
assert hasattr(decoder_cross_attn, "q"), "Decoder cross-attn missing q submodule"
77+
assert isinstance(decoder_cross_attn.q, LinearBridge), "q should be LinearBridge"
78+
79+
# Check encoder MLP
80+
encoder_mlp = bridge_model.encoder_blocks[0].mlp
81+
# Use getattr since 'in' is a Python keyword
82+
mlp_in = getattr(encoder_mlp, "in", None)
83+
mlp_out = getattr(encoder_mlp, "out", None)
84+
assert mlp_in is not None, "Encoder MLP missing 'in' submodule"
85+
assert mlp_out is not None, "Encoder MLP missing 'out' submodule"
86+
assert isinstance(mlp_in, LinearBridge), "in should be LinearBridge"
87+
assert isinstance(mlp_out, LinearBridge), "out should be LinearBridge"
88+
89+
def test_linear_bridge_hooks_accessible(self, bridge_model):
90+
"""Test that LinearBridge submodules have hook_out."""
91+
encoder_attn = bridge_model.encoder_blocks[0].attn
92+
93+
assert hasattr(encoder_attn.q, "hook_out"), "LinearBridge q missing hook_out"
94+
assert hasattr(encoder_attn.k, "hook_out"), "LinearBridge k missing hook_out"
95+
assert hasattr(encoder_attn.v, "hook_out"), "LinearBridge v missing hook_out"
96+
assert hasattr(encoder_attn.o, "hook_out"), "LinearBridge o missing hook_out"
97+
98+
# Verify they are HookPoints
99+
from transformer_lens.hook_points import HookPoint
100+
101+
assert isinstance(encoder_attn.q.hook_out, HookPoint)
102+
assert isinstance(encoder_attn.k.hook_out, HookPoint)
103+
assert isinstance(encoder_attn.v.hook_out, HookPoint)
104+
assert isinstance(encoder_attn.o.hook_out, HookPoint)
105+
106+
def test_compatibility_mode_enables_successfully(self, bridge_model):
107+
"""Test that compatibility mode can be enabled for T5.
108+
109+
This is the main acceptance test - compatibility mode should enable
110+
without errors and properly register all hooks.
111+
"""
112+
# Enable compatibility mode manually (avoiding full enable_compatibility_mode
113+
# which includes weight processing that doesn't work for T5 yet)
114+
bridge_model.compatibility_mode = True
115+
116+
def set_compatibility_mode(component):
117+
component.compatibility_mode = True
118+
component.disable_warnings = False
119+
120+
apply_fn_to_all_components(bridge_model, set_compatibility_mode)
121+
122+
# Re-initialize hook registry to include aliases
123+
bridge_model.clear_hook_registry()
124+
bridge_model._initialize_hook_registry()
125+
126+
# Verify compatibility mode is enabled
127+
assert bridge_model.compatibility_mode is True
128+
129+
def test_hook_registry_populated(self, bridge_model):
130+
"""Test that hook registry is populated after enabling compatibility mode."""
131+
# Enable compatibility mode
132+
bridge_model.compatibility_mode = True
133+
134+
def set_compatibility_mode(component):
135+
component.compatibility_mode = True
136+
component.disable_warnings = False
137+
138+
apply_fn_to_all_components(bridge_model, set_compatibility_mode)
139+
bridge_model.clear_hook_registry()
140+
bridge_model._initialize_hook_registry()
141+
142+
# Check that hooks are registered
143+
assert len(bridge_model._hook_registry) > 0, "Hook registry should not be empty"
144+
145+
# Should have hundreds of hooks (encoder + decoder)
146+
assert (
147+
len(bridge_model._hook_registry) > 500
148+
), f"Expected >500 hooks, got {len(bridge_model._hook_registry)}"
149+
150+
def test_critical_hooks_accessible(self, bridge_model):
151+
"""Test that critical hooks are accessible after compatibility mode."""
152+
# Enable compatibility mode
153+
bridge_model.compatibility_mode = True
154+
155+
def set_compatibility_mode(component):
156+
component.compatibility_mode = True
157+
component.disable_warnings = False
158+
159+
apply_fn_to_all_components(bridge_model, set_compatibility_mode)
160+
bridge_model.clear_hook_registry()
161+
bridge_model._initialize_hook_registry()
162+
163+
# Test critical encoder hooks
164+
critical_hooks = [
165+
"encoder_blocks.0.hook_in",
166+
"encoder_blocks.0.attn.q.hook_out",
167+
"encoder_blocks.0.attn.hook_out",
168+
"encoder_blocks.0.mlp.in.hook_out",
169+
"encoder_blocks.0.mlp.out.hook_out",
170+
# Decoder hooks
171+
"decoder_blocks.0.hook_in",
172+
"decoder_blocks.0.self_attn.q.hook_out",
173+
"decoder_blocks.0.cross_attn.k.hook_out",
174+
"decoder_blocks.0.mlp.in.hook_out",
175+
]
176+
177+
for hook_name in critical_hooks:
178+
assert (
179+
hook_name in bridge_model._hook_registry
180+
), f"Critical hook {hook_name} not found in registry"
181+
182+
def test_encoder_decoder_hook_counts(self, bridge_model):
183+
"""Test that both encoder and decoder have reasonable hook counts."""
184+
# Enable compatibility mode
185+
bridge_model.compatibility_mode = True
186+
187+
def set_compatibility_mode(component):
188+
component.compatibility_mode = True
189+
component.disable_warnings = False
190+
191+
apply_fn_to_all_components(bridge_model, set_compatibility_mode)
192+
bridge_model.clear_hook_registry()
193+
bridge_model._initialize_hook_registry()
194+
195+
# Count encoder and decoder hooks
196+
encoder_hooks = [h for h in bridge_model._hook_registry if "encoder" in h]
197+
decoder_hooks = [h for h in bridge_model._hook_registry if "decoder" in h]
198+
199+
assert len(encoder_hooks) > 0, "Should have encoder hooks"
200+
assert len(decoder_hooks) > 0, "Should have decoder hooks"
201+
202+
# Decoder should have more hooks (has cross-attention in addition to self-attention)
203+
assert len(decoder_hooks) > len(
204+
encoder_hooks
205+
), "Decoder should have more hooks than encoder"
206+
207+
def test_t5_block_bridge_hooks(self, bridge_model):
208+
"""Test that T5BlockBridge has the expected hooks."""
209+
# Check encoder block
210+
encoder_block = bridge_model.encoder_blocks[0]
211+
assert hasattr(encoder_block, "hook_in")
212+
assert hasattr(encoder_block, "hook_out")
213+
assert hasattr(encoder_block, "hook_resid_mid")
214+
215+
# Encoder blocks should NOT have hook_resid_mid2 (only 2 layers)
216+
assert not hasattr(encoder_block, "hook_resid_mid2")
217+
218+
# Check decoder block
219+
decoder_block = bridge_model.decoder_blocks[0]
220+
assert hasattr(decoder_block, "hook_in")
221+
assert hasattr(decoder_block, "hook_out")
222+
assert hasattr(decoder_block, "hook_resid_mid")
223+
224+
# Decoder blocks SHOULD have hook_resid_mid2 (3 layers - after cross-attn)
225+
assert hasattr(decoder_block, "hook_resid_mid2")
226+
227+
def test_rms_normalization_used(self, bridge_model):
228+
"""Test that T5 uses RMSNormalizationBridge throughout."""
229+
from transformer_lens.model_bridge.generalized_components.rms_normalization import (
230+
RMSNormalizationBridge,
231+
)
232+
233+
# Check encoder
234+
assert isinstance(bridge_model.encoder_blocks[0].ln1, RMSNormalizationBridge)
235+
assert isinstance(bridge_model.encoder_blocks[0].ln2, RMSNormalizationBridge)
236+
assert isinstance(bridge_model.encoder_ln_final, RMSNormalizationBridge)
237+
238+
# Check decoder
239+
assert isinstance(bridge_model.decoder_blocks[0].ln1, RMSNormalizationBridge)
240+
assert isinstance(bridge_model.decoder_blocks[0].ln2, RMSNormalizationBridge)
241+
assert isinstance(bridge_model.decoder_blocks[0].ln3, RMSNormalizationBridge)
242+
assert isinstance(bridge_model.decoder_ln_final, RMSNormalizationBridge)

transformer_lens/model_bridge/bridge.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -845,9 +845,22 @@ def _setup_no_processing_hooks(self) -> None:
845845
1. hook_z reshaping for proper head dimensions
846846
2. Wrapping HF attention forward to capture scores before softmax
847847
"""
848-
for block in self.blocks:
849-
if hasattr(block, "attn") and hasattr(block.attn, "setup_no_processing_hooks"):
850-
block.attn.setup_no_processing_hooks()
848+
# Handle both decoder-only (blocks) and encoder-decoder (encoder_blocks, decoder_blocks)
849+
blocks_to_process = []
850+
if hasattr(self, "blocks"):
851+
blocks_to_process.extend(self.blocks)
852+
if hasattr(self, "encoder_blocks"):
853+
blocks_to_process.extend(self.encoder_blocks)
854+
if hasattr(self, "decoder_blocks"):
855+
blocks_to_process.extend(self.decoder_blocks)
856+
857+
for block in blocks_to_process:
858+
# Handle both regular attn and self_attn/cross_attn naming
859+
for attn_name in ["attn", "self_attn", "cross_attn"]:
860+
if hasattr(block, attn_name):
861+
attn = getattr(block, attn_name)
862+
if hasattr(attn, "setup_no_processing_hooks"):
863+
attn.setup_no_processing_hooks()
851864

852865
def _enable_split_qkv_attention(self) -> None:
853866
"""Enable split Q/K/V computation for attention layers in no_processing mode.
@@ -859,7 +872,16 @@ def _enable_split_qkv_attention(self) -> None:
859872
Unlike enable_ht_computation_for_bridge, this ONLY affects attention layers,
860873
leaving MLPs to use their original HF weights.
861874
"""
862-
for block in self.blocks:
875+
# Handle both decoder-only (blocks) and encoder-decoder (encoder_blocks, decoder_blocks)
876+
blocks_to_process = []
877+
if hasattr(self, "blocks"):
878+
blocks_to_process.extend(self.blocks)
879+
if hasattr(self, "encoder_blocks"):
880+
blocks_to_process.extend(self.encoder_blocks)
881+
if hasattr(self, "decoder_blocks"):
882+
blocks_to_process.extend(self.decoder_blocks)
883+
884+
for block in blocks_to_process:
863885
if hasattr(block, "attn") and hasattr(block, "original_component"):
864886
hf_block = block.original_component
865887
if hasattr(hf_block, "attn"):
@@ -903,7 +925,16 @@ def _enable_native_layernorm_autograd(self) -> None:
903925
self.ln_f.config.use_hf_autograd = True
904926

905927
# Enable for all block normalization layers
906-
for block in self.blocks:
928+
# Handle both decoder-only (blocks) and encoder-decoder (encoder_blocks, decoder_blocks)
929+
blocks_to_process = []
930+
if hasattr(self, "blocks"):
931+
blocks_to_process.extend(self.blocks)
932+
if hasattr(self, "encoder_blocks"):
933+
blocks_to_process.extend(self.encoder_blocks)
934+
if hasattr(self, "decoder_blocks"):
935+
blocks_to_process.extend(self.decoder_blocks)
936+
937+
for block in blocks_to_process:
907938
# ln1 (pre-attention norm)
908939
if hasattr(block, "ln1") and isinstance(block.ln1, NormalizationBridge):
909940
if block.ln1.config is not None:

transformer_lens/model_bridge/generalized_components/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from transformer_lens.model_bridge.generalized_components.normalization import (
1919
NormalizationBridge,
2020
)
21+
from transformer_lens.model_bridge.generalized_components.rms_normalization import (
22+
RMSNormalizationBridge,
23+
)
2124

2225
from transformer_lens.model_bridge.generalized_components.linear import (
2326
LinearBridge,
@@ -33,6 +36,9 @@
3336
from transformer_lens.model_bridge.generalized_components.unembedding import (
3437
UnembeddingBridge,
3538
)
39+
from transformer_lens.model_bridge.generalized_components.t5_block import (
40+
T5BlockBridge,
41+
)
3642

3743
__all__ = [
3844
"AttentionBridge",
@@ -41,10 +47,12 @@
4147
"RotaryEmbeddingBridge",
4248
"PosEmbedBridge",
4349
"NormalizationBridge",
50+
"RMSNormalizationBridge",
4451
"JointQKVAttentionBridge",
4552
"JointGateUpMLPBridge",
4653
"LinearBridge",
4754
"MLPBridge",
4855
"MoEBridge",
4956
"UnembeddingBridge",
57+
"T5BlockBridge",
5058
]

transformer_lens/model_bridge/generalized_components/normalization.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Normalization bridge component implementation."""
22

3-
from typing import Any, Dict, Optional
3+
from typing import Any, Dict, Optional, cast
44

55
import torch
66

@@ -101,7 +101,12 @@ def forward(
101101
hidden_states = hidden_states * self.weight
102102
else:
103103
# Add bias if using LayerNorm and the original component has a bias
104-
hidden_states = hidden_states * self.weight + self.bias
104+
hidden_states = hidden_states * self.weight
105+
if (
106+
hasattr(self.original_component, "bias")
107+
and self.original_component.bias is not None
108+
):
109+
hidden_states = hidden_states + cast(torch.Tensor, self.original_component.bias)
105110

106111
result = hidden_states
107112

0 commit comments

Comments
 (0)