diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index bf19388d7e..fce04bfa2d 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -716,10 +716,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
         )
         use_unfused_attention = False
         if qkv_format == "thd":
-            logger.debug(
-                "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
-            )
-            use_fused_attention = False
+            if cudnn_version < (9, 18, 0):
+                logger.debug(
+                    "Disabling FusedAttention for softmax_type = %s and qkv_format = thd",
+                    softmax_type,
+                )
+                use_fused_attention = False
             logger.debug(
                 "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
                 softmax_type,