diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 0f47767e4141..7fedc4e75441 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -3872,11 +3872,13 @@ def test_attn_implementation_composite_models(self):
             for name, submodule in model.named_modules():
                 class_name = submodule.__class__.__name__
                 if (
-                    "SdpaAttention" in class_name
-                    or "SdpaSelfAttention" in class_name
-                    or "FlashAttention" in class_name
+                    class_name.endswith("Attention")
+                    and getattr(submodule, "config", None)
+                    and submodule.config._attn_implementation != "eager"
                 ):
-                    raise ValueError(f"The eager model should not have SDPA/FA2 attention layers but got {class_name}")
+                    raise ValueError(
+                        f"The eager model should not have SDPA/FA2 attention layers but got `{class_name}.config._attn_implementation={submodule.config._attn_implementation}`"
+                    )

     @require_torch_sdpa
     def test_sdpa_can_dispatch_non_composite_models(self):
@@ -3907,8 +3909,14 @@ def test_sdpa_can_dispatch_non_composite_models(self):

                 for name, submodule in model_eager.named_modules():
                     class_name = submodule.__class__.__name__
-                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
-                        raise ValueError(f"The eager model should not have SDPA attention layers but got {class_name}")
+                    if (
+                        class_name.endswith("Attention")
+                        and getattr(submodule, "config", None)
+                        and submodule.config._attn_implementation == "sdpa"
+                    ):
+                        raise ValueError(
+                            f"The eager model should not have SDPA attention layers but got `{class_name}.config._attn_implementation={submodule.config._attn_implementation}`"
+                        )

     @require_torch_sdpa
     def test_sdpa_can_dispatch_composite_models(self):
@@ -3959,7 +3967,11 @@ def test_sdpa_can_dispatch_composite_models(self):

                 for name, submodule in model_eager.named_modules():
                     class_name = submodule.__class__.__name__
-                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+                    if (
+                        class_name.endswith("Attention")
+                        and getattr(submodule, "config", None)
+                        and submodule.config._attn_implementation == "sdpa"
+                    ):
                         raise ValueError("The eager model should not have SDPA attention layers")

     @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@@ -4446,7 +4458,11 @@ def test_flash_attn_2_can_dispatch_composite_models(self):
             has_fa2 = False
             for name, submodule in model_fa2.named_modules():
                 class_name = submodule.__class__.__name__
-                if "FlashAttention" in class_name:
+                if (
+                    "Attention" in class_name
+                    and getattr(submodule, "config", None)
+                    and submodule.config._attn_implementation == "flash_attention_2"
+                ):
                     has_fa2 = True
                     break
             if not has_fa2:
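
The hunks above replace class-name substring matching (e.g. `"SdpaAttention" in class_name`) with a check on each attention submodule's `config._attn_implementation` attribute. The sketch below restates that pattern as a standalone helper for illustration only: `assert_attention_implementation` is a hypothetical name that is not part of `transformers` or this test suite, and it assumes an already-loaded model whose attention modules expose a `config` carrying `_attn_implementation`, which is what the diff relies on.

```python
# Minimal sketch of the new check pattern, assuming a loaded `transformers` model.
# `assert_attention_implementation` is a hypothetical helper, not a library API.
import torch.nn as nn


def assert_attention_implementation(model: nn.Module, expected: str) -> None:
    """Raise if any attention submodule reports an unexpected implementation.

    Rather than matching class names such as "SdpaAttention" or "FlashAttention",
    the check inspects the `_attn_implementation` attribute on each submodule's
    config, so it keeps working when all backends share a single `*Attention` class.
    """
    for name, submodule in model.named_modules():
        class_name = submodule.__class__.__name__
        config = getattr(submodule, "config", None)
        if class_name.endswith("Attention") and config is not None:
            actual = getattr(config, "_attn_implementation", None)
            if actual != expected:
                raise ValueError(
                    f"{name} ({class_name}) uses attn_implementation={actual!r}, expected {expected!r}"
                )
```

A call like `assert_attention_implementation(model_eager, "eager")` mirrors the eager-model assertions in the first three hunks; the last hunk inverts the logic, only requiring that at least one attention module reports `"flash_attention_2"`.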