Commit a8dfdf8

[python_utils] add try import utils (#387)
1 parent: f99db62

2 files changed: +24, -37 lines

tritonbench/operators/flash_attention/operator.py

Lines changed: 11 additions & 37 deletions
@@ -57,23 +57,19 @@
 )

 from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda, is_hip
-
 from tritonbench.utils.path_utils import add_ld_library_path
+from tritonbench.utils.python_utils import try_import
 from tritonbench.utils.triton_op import is_fbcode


 # [Optional] flash_attn v2
-try:
+with try_import("HAS_FLASH_V2"):
     from flash_attn.flash_attn_interface import (
         flash_attn_qkvpacked_func as flash_attn_func,
     )

     from .test_fmha_utils import make_packed_qkv

-    HAS_FLASH_V2 = True
-except (ImportError, IOError, AttributeError):
-    HAS_FLASH_V2 = False
-
 HAS_CUDA_124 = (
     torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.4"
 )
@@ -83,59 +79,37 @@
 # only enabling the variants known to be working on B200 (trunk).
 if not IS_B200:
     # [Optional] flash_attn v3
-    try:
-        torch_lib_path = os.path.join(os.path.dirname(__file__), "lib")
-        with add_ld_library_path(torch_lib_path):
-            from flash_attn_interface import flash_attn_func as flash_attn_v3
-        HAS_FLASH_V3 = True
-    except (ImportError, IOError, AttributeError):
+    with try_import("HAS_FLASH_V3"):
         try:
-            from fa3.hopper.flash_attn_interface import flash_attn_func as flash_attn_v3
-
-            HAS_FLASH_V3 = True
+            torch_lib_path = os.path.join(os.path.dirname(__file__), "lib")
+            with add_ld_library_path(torch_lib_path):
+                from flash_attn_interface import flash_attn_func as flash_attn_v3
         except (ImportError, IOError, AttributeError):
-            HAS_FLASH_V3 = False
+            from fa3.hopper.flash_attn_interface import flash_attn_func as flash_attn_v3

-    try:
+    with try_import("HAS_TILELANG"):
         import tilelang

         from .tilelang_mha import tilelang_mha

-        HAS_TILELANG = True
-    except (ImportError, IOError, AttributeError, TypeError):
-        HAS_TILELANG = False
-
     # [Optional] ThunderKittens backend
-    try:
+    with try_import("HAS_TK"):
         from .tk import tk_attn

-        HAS_TK = True
-    except (ImportError, IOError, AttributeError):
-        HAS_TK = False
-
     # [Optional] JAX Pallas backend
-    try:
+    with try_import("HAS_PALLAS"):
         import jax
-
         from tritonbench.utils.jax_utils import torch_to_jax_tensor

         from .pallas import mha as pallas_mha

-        HAS_PALLAS = True
-    except (ImportError, IOError, AttributeError):
-        HAS_PALLAS = False
-
     # [Optional] xformers backend
-    try:
+    with try_import("HAS_XFORMERS"):
         import xformers  # @manual=//fair/xformers:xformers
         import xformers.ops.fmha as xformers_fmha  # @manual=//fair/xformers:xformers

         from .test_fmha_utils import permute_qkv

-        HAS_XFORMERS = True
-    except (ImportError, IOError, AttributeError, TypeError):
-        HAS_XFORMERS = False
-
 from typing import Any, Generator, List

 from tritonbench.utils.input import input_filter
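
Each optional backend thus condenses from a multi-line try/except that manually set a HAS_* flag into a single with try_import("HAS_*"): block; the context manager writes the flag into the caller's module globals. One behavioral note visible in the diff: try_import catches only ImportError and ModuleNotFoundError (the latter is already a subclass of the former), while the replaced blocks also caught IOError, AttributeError, and in two cases TypeError. A minimal standalone sketch of the call-site pattern, using a hypothetical module name:

from tritonbench.utils.python_utils import try_import

# The import below fails; try_import swallows the error and writes the flag
# into this module's globals instead of letting the exception propagate.
with try_import("HAS_FAKE_BACKEND"):
    import some_module_that_does_not_exist

print(HAS_FAKE_BACKEND)  # False here; True when the block completes cleanly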

tritonbench/utils/python_utils.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+import inspect
+from contextlib import contextmanager
+
+
+@contextmanager
+def try_import(cond_name: str):
+    frame = inspect.currentframe().f_back.f_back
+    _caller_globals = frame.f_globals
+    try:
+        yield
+        _caller_globals[cond_name] = True
+    except (ImportError, ModuleNotFoundError) as e:
+        _caller_globals[cond_name] = False
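
The double .f_back is what makes the flag land in the right namespace: try_import is a generator-based context manager, so when its body runs, the immediate calling frame belongs to contextlib's _GeneratorContextManager.__enter__, and the frame above that is the module that opened the with block. A commented sketch of the same frame walk (a standalone illustration, not the committed code):

import inspect
from contextlib import contextmanager

@contextmanager
def show_caller():
    # Frame 0: this generator body.
    # Frame 1 (.f_back): contextlib's _GeneratorContextManager.__enter__.
    # Frame 2 (.f_back.f_back): the frame that opened the `with` block.
    caller = inspect.currentframe().f_back.f_back
    print(caller.f_globals["__name__"])  # prints "__main__" when run as a script
    yield

with show_caller():
    pass

Because the generator catches the exception thrown into it and does not re-raise, contextlib treats the exception as suppressed, which is why the call sites need no except clause of their own.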
