@@ -57,23 +57,19 @@
 )
 
 from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda, is_hip
-
 from tritonbench.utils.path_utils import add_ld_library_path
+from tritonbench.utils.python_utils import try_import
 from tritonbench.utils.triton_op import is_fbcode
 
 
 # [Optional] flash_attn v2
-try:
+with try_import("HAS_FLASH_V2"):
     from flash_attn.flash_attn_interface import (
         flash_attn_qkvpacked_func as flash_attn_func,
     )
 
     from .test_fmha_utils import make_packed_qkv
 
-    HAS_FLASH_V2 = True
-except (ImportError, IOError, AttributeError):
-    HAS_FLASH_V2 = False
-
 HAS_CUDA_124 = (
     torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.4"
 )
@@ -83,59 +79,37 @@
 # only enabling the variants known to be working on B200 (trunk).
 if not IS_B200:
     # [Optional] flash_attn v3
-    try:
-        torch_lib_path = os.path.join(os.path.dirname(__file__), "lib")
-        with add_ld_library_path(torch_lib_path):
-            from flash_attn_interface import flash_attn_func as flash_attn_v3
-        HAS_FLASH_V3 = True
-    except (ImportError, IOError, AttributeError):
+    with try_import("HAS_FLASH_V3"):
         try:
-            from fa3.hopper.flash_attn_interface import flash_attn_func as flash_attn_v3
-
-            HAS_FLASH_V3 = True
+            torch_lib_path = os.path.join(os.path.dirname(__file__), "lib")
+            with add_ld_library_path(torch_lib_path):
+                from flash_attn_interface import flash_attn_func as flash_attn_v3
         except (ImportError, IOError, AttributeError):
-            HAS_FLASH_V3 = False
+            from fa3.hopper.flash_attn_interface import flash_attn_func as flash_attn_v3
 
-    try:
+    with try_import("HAS_TILELANG"):
         import tilelang
 
         from .tilelang_mha import tilelang_mha
 
-        HAS_TILELANG = True
-    except (ImportError, IOError, AttributeError, TypeError):
-        HAS_TILELANG = False
-
     # [Optional] ThunderKittens backend
-    try:
+    with try_import("HAS_TK"):
        from .tk import tk_attn
 
-        HAS_TK = True
-    except (ImportError, IOError, AttributeError):
-        HAS_TK = False
-
     # [Optional] JAX Pallas backend
-    try:
+    with try_import("HAS_PALLAS"):
         import jax
-
         from tritonbench.utils.jax_utils import torch_to_jax_tensor
 
         from .pallas import mha as pallas_mha
 
-        HAS_PALLAS = True
-    except (ImportError, IOError, AttributeError):
-        HAS_PALLAS = False
-
 # [Optional] xformers backend
-try:
+with try_import("HAS_XFORMERS"):
     import xformers  # @manual=//fair/xformers:xformers
     import xformers.ops.fmha as xformers_fmha  # @manual=//fair/xformers:xformers
 
     from .test_fmha_utils import permute_qkv
 
-    HAS_XFORMERS = True
-except (ImportError, IOError, AttributeError, TypeError):
-    HAS_XFORMERS = False
-
 from typing import Any, Generator, List
 
 from tritonbench.utils.input import input_filter
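The `try_import` helper imported from `tritonbench.utils.python_utils` is not shown in this diff. Below is a minimal sketch of how such a context manager could work, assuming it suppresses the same exception types the replaced try/except blocks caught and sets a module-level flag named by its argument in the caller's module; the frame lookup and the exception tuple are assumptions, not the actual tritonbench implementation.

# Hypothetical sketch of try_import; the real helper in
# tritonbench.utils.python_utils may differ.
import contextlib
import sys


@contextlib.contextmanager
def try_import(flag_name: str):
    # Globals of the module running the `with try_import(...)` block:
    # frame 0 = this generator, frame 1 = contextlib's __enter__, frame 2 = caller.
    caller_globals = sys._getframe(2).f_globals
    try:
        yield
        caller_globals[flag_name] = True
    except (ImportError, IOError, AttributeError, TypeError):
        # Swallow the failure and record that the optional backend is unavailable.
        caller_globals[flag_name] = False

With a helper along these lines, each optional-backend block collapses to a single `with try_import("HAS_X"):` statement, and the flash_attn v3 hunk keeps its inner try/except only to choose between the two candidate import paths for `flash_attn_v3`.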