"""Acceptance test for T5 compatibility mode in TransformerBridge.

This test verifies that T5 can be loaded with TransformerBridge and that
compatibility mode can be successfully enabled with proper hook registration.
"""
| 6 | + |
| 7 | +import gc |
| 8 | + |
| 9 | +import pytest |
| 10 | +import torch |
| 11 | + |
| 12 | +from transformer_lens.model_bridge.bridge import TransformerBridge |
| 13 | +from transformer_lens.utilities.bridge_components import apply_fn_to_all_components |
| 14 | + |
| 15 | + |
class TestT5CompatibilityMode:
    """Acceptance tests for T5 loaded through TransformerBridge.

    The tests cover three areas: that the model loads and exposes encoder and
    decoder blocks, that the attention/MLP projections are wrapped in
    ``LinearBridge`` components exposing ``hook_out``, and that the hook
    registry is populated once compatibility mode is switched on.

    Every test receives its own freshly loaded model (the ``bridge_model``
    fixture is function-scoped), so tests that mutate the model by enabling
    compatibility mode do not affect one another.
    """

    @pytest.fixture(autouse=True)
    def cleanup_after_test(self):
        """Release memory held by the model after each test."""
        yield
        # Collect first: tensors kept alive only by reference cycles must be
        # freed before the CUDA caching allocator can hand their blocks back.
        for _ in range(3):
            gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @pytest.fixture
    def model_name(self):
        """T5 model to test."""
        return "google-t5/t5-small"

    @pytest.fixture
    def bridge_model(self, model_name):
        """Load the T5 model via TransformerBridge on CPU."""
        return TransformerBridge.boot_transformers(model_name, device="cpu")

    @staticmethod
    def _enable_compatibility_mode(bridge_model):
        """Turn on compatibility mode for ``bridge_model`` and rebuild its hooks.

        This is done manually instead of calling the full
        ``enable_compatibility_mode``, which includes weight processing that
        doesn't work for T5 yet.

        Args:
            bridge_model: The bridge whose components are switched into
                compatibility mode. It is modified in place.
        """
        bridge_model.compatibility_mode = True

        def set_compatibility_mode(component):
            component.compatibility_mode = True
            component.disable_warnings = False

        apply_fn_to_all_components(bridge_model, set_compatibility_mode)

        # Re-initialize hook registry to include aliases
        bridge_model.clear_hook_registry()
        bridge_model._initialize_hook_registry()

    def test_t5_loads_successfully(self, bridge_model, model_name):
        """Test that T5 loads successfully via TransformerBridge."""
        assert bridge_model is not None
        assert bridge_model.cfg.model_name == model_name
        assert hasattr(bridge_model, "encoder_blocks")
        assert hasattr(bridge_model, "decoder_blocks")

    def test_linear_bridge_submodules_exist(self, bridge_model):
        """Test that AttentionBridge and MLPBridge have LinearBridge submodules.

        This is critical for compatibility mode to work - without LinearBridge
        submodules, hook aliases like 'hook_q -> q.hook_out' will fail.
        """
        from transformer_lens.model_bridge.generalized_components.linear import (
            LinearBridge,
        )

        projections = ("q", "k", "v", "o")

        # Check encoder attention
        encoder_attn = bridge_model.encoder_blocks[0].attn
        for name in projections:
            assert hasattr(
                encoder_attn, name
            ), f"Encoder attention missing {name} submodule"

        # Verify they are LinearBridge instances, not raw Linear layers
        for name in projections:
            assert isinstance(
                getattr(encoder_attn, name), LinearBridge
            ), f"{name} should be LinearBridge"

        # Check decoder self-attention
        decoder_self_attn = bridge_model.decoder_blocks[0].self_attn
        assert hasattr(decoder_self_attn, "q"), "Decoder self-attn missing q submodule"
        assert isinstance(decoder_self_attn.q, LinearBridge), "q should be LinearBridge"

        # Check decoder cross-attention
        decoder_cross_attn = bridge_model.decoder_blocks[0].cross_attn
        assert hasattr(decoder_cross_attn, "q"), "Decoder cross-attn missing q submodule"
        assert isinstance(decoder_cross_attn.q, LinearBridge), "q should be LinearBridge"

        # Check encoder MLP
        encoder_mlp = bridge_model.encoder_blocks[0].mlp
        # Use getattr since 'in' is a Python keyword
        mlp_in = getattr(encoder_mlp, "in", None)
        mlp_out = getattr(encoder_mlp, "out", None)
        assert mlp_in is not None, "Encoder MLP missing 'in' submodule"
        assert mlp_out is not None, "Encoder MLP missing 'out' submodule"
        assert isinstance(mlp_in, LinearBridge), "in should be LinearBridge"
        assert isinstance(mlp_out, LinearBridge), "out should be LinearBridge"

    def test_linear_bridge_hooks_accessible(self, bridge_model):
        """Test that LinearBridge submodules have hook_out."""
        from transformer_lens.hook_points import HookPoint

        encoder_attn = bridge_model.encoder_blocks[0].attn
        projections = ("q", "k", "v", "o")

        for name in projections:
            assert hasattr(
                getattr(encoder_attn, name), "hook_out"
            ), f"LinearBridge {name} missing hook_out"

        # Verify they are HookPoints
        for name in projections:
            assert isinstance(
                getattr(encoder_attn, name).hook_out, HookPoint
            ), f"{name}.hook_out should be a HookPoint"

    def test_compatibility_mode_enables_successfully(self, bridge_model):
        """Test that compatibility mode can be enabled for T5.

        This is the main acceptance test - compatibility mode should enable
        without errors and properly register all hooks.
        """
        self._enable_compatibility_mode(bridge_model)

        # Verify compatibility mode is enabled
        assert bridge_model.compatibility_mode is True
        # The flag alone only echoes what the helper set; also make sure the
        # registry rebuild actually registered something.
        assert len(bridge_model._hook_registry) > 0, "Hook registry should not be empty"

    def test_hook_registry_populated(self, bridge_model):
        """Test that hook registry is populated after enabling compatibility mode."""
        self._enable_compatibility_mode(bridge_model)

        hook_count = len(bridge_model._hook_registry)

        # Check that hooks are registered
        assert hook_count > 0, "Hook registry should not be empty"

        # Should have hundreds of hooks (encoder + decoder)
        assert hook_count > 500, f"Expected >500 hooks, got {hook_count}"

    def test_critical_hooks_accessible(self, bridge_model):
        """Test that critical hooks are accessible after compatibility mode."""
        self._enable_compatibility_mode(bridge_model)

        critical_hooks = [
            # Encoder hooks
            "encoder_blocks.0.hook_in",
            "encoder_blocks.0.attn.q.hook_out",
            "encoder_blocks.0.attn.hook_out",
            "encoder_blocks.0.mlp.in.hook_out",
            "encoder_blocks.0.mlp.out.hook_out",
            # Decoder hooks
            "decoder_blocks.0.hook_in",
            "decoder_blocks.0.self_attn.q.hook_out",
            "decoder_blocks.0.cross_attn.k.hook_out",
            "decoder_blocks.0.mlp.in.hook_out",
        ]

        for hook_name in critical_hooks:
            assert (
                hook_name in bridge_model._hook_registry
            ), f"Critical hook {hook_name} not found in registry"

    def test_encoder_decoder_hook_counts(self, bridge_model):
        """Test that both encoder and decoder have reasonable hook counts."""
        self._enable_compatibility_mode(bridge_model)

        # Count encoder and decoder hooks
        encoder_hooks = [h for h in bridge_model._hook_registry if "encoder" in h]
        decoder_hooks = [h for h in bridge_model._hook_registry if "decoder" in h]

        assert len(encoder_hooks) > 0, "Should have encoder hooks"
        assert len(decoder_hooks) > 0, "Should have decoder hooks"

        # Decoder should have more hooks (has cross-attention in addition to self-attention)
        assert len(decoder_hooks) > len(
            encoder_hooks
        ), "Decoder should have more hooks than encoder"

    def test_t5_block_bridge_hooks(self, bridge_model):
        """Test that T5BlockBridge has the expected hooks."""
        shared_hooks = ("hook_in", "hook_out", "hook_resid_mid")

        # Check encoder block
        encoder_block = bridge_model.encoder_blocks[0]
        for hook in shared_hooks:
            assert hasattr(encoder_block, hook), f"Encoder block missing {hook}"

        # Encoder blocks should NOT have hook_resid_mid2 (only 2 layers)
        assert not hasattr(
            encoder_block, "hook_resid_mid2"
        ), "Encoder block should not have hook_resid_mid2"

        # Check decoder block
        decoder_block = bridge_model.decoder_blocks[0]
        for hook in shared_hooks:
            assert hasattr(decoder_block, hook), f"Decoder block missing {hook}"

        # Decoder blocks SHOULD have hook_resid_mid2 (3 layers - after cross-attn)
        assert hasattr(
            decoder_block, "hook_resid_mid2"
        ), "Decoder block missing hook_resid_mid2"

    def test_rms_normalization_used(self, bridge_model):
        """Test that T5 uses RMSNormalizationBridge throughout."""
        from transformer_lens.model_bridge.generalized_components.rms_normalization import (
            RMSNormalizationBridge,
        )

        encoder_block = bridge_model.encoder_blocks[0]
        decoder_block = bridge_model.decoder_blocks[0]

        norm_layers = [
            # Check encoder
            ("encoder_blocks[0].ln1", encoder_block.ln1),
            ("encoder_blocks[0].ln2", encoder_block.ln2),
            ("encoder_ln_final", bridge_model.encoder_ln_final),
            # Check decoder
            ("decoder_blocks[0].ln1", decoder_block.ln1),
            ("decoder_blocks[0].ln2", decoder_block.ln2),
            ("decoder_blocks[0].ln3", decoder_block.ln3),
            ("decoder_ln_final", bridge_model.decoder_ln_final),
        ]

        for label, layer in norm_layers:
            assert isinstance(
                layer, RMSNormalizationBridge
            ), f"{label} should be RMSNormalizationBridge"
0 commit comments