Skip to content

Commit 7d8375b

Browse files
committed
Update FlashInfer JIT header lookup
1 parent ce6182f commit 7d8375b

File tree

2 files changed

+41
-6
lines changed

2 files changed

+41
-6
lines changed

python/tvm/libinfo.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,9 @@ def find_include_path(name=None, search_path=None, optional=False):
195195
include_path : list(string)
196196
List of all found paths to header files.
197197
"""
198-
if os.environ.get("TVM_HOME", None):
198+
if os.environ.get("TVM_SOURCE_DIR", None):
199+
source_dir = os.environ["TVM_SOURCE_DIR"]
200+
elif os.environ.get("TVM_HOME", None):
199201
source_dir = os.environ["TVM_HOME"]
200202
else:
201203
ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
@@ -204,7 +206,7 @@ def find_include_path(name=None, search_path=None, optional=False):
204206
if os.path.isdir(os.path.join(source_dir, "include")):
205207
break
206208
else:
207-
raise AssertionError("Cannot find the source directory given ffi_dir: {ffi_dir}")
209+
raise AssertionError(f"Cannot find the source directory given ffi_dir: {ffi_dir}")
208210
third_party_dir = os.path.join(source_dir, "3rdparty")
209211

210212
header_path = []

python/tvm/relax/backend/cuda/flashinfer.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,45 @@ def get_object_file_path(src: Path) -> Path:
129129
FLASHINFER_INCLUDE_DIR,
130130
FLASHINFER_CSRC_DIR,
131131
FLASHINFER_TVM_BINDING_DIR,
132-
Path(tvm_home).resolve() / "include",
133-
Path(tvm_home).resolve() / "ffi" / "include",
134-
Path(tvm_home).resolve() / "ffi" / "3rdparty" / "dlpack" / "include",
135-
Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include",
136132
] + CUTLASS_INCLUDE_DIRS
137133

134+
if "TVM_SOURCE_DIR" in os.environ or "TVM_HOME" in os.environ:
135+
tvm_home = (
136+
os.environ["TVM_SOURCE_DIR"]
137+
if "TVM_SOURCE_DIR" in os.environ
138+
else os.environ["TVM_HOME"]
139+
)
140+
include_paths += [
141+
Path(tvm_home).resolve() / "include",
142+
Path(tvm_home).resolve() / "ffi" / "include",
143+
Path(tvm_home).resolve() / "ffi" / "3rdparty" / "dlpack" / "include",
144+
Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include",
145+
]
146+
else:
147+
tvm_package_path = Path(tvm.__file__).resolve().parent
148+
if (tvm_package_path / "include").exists():
149+
import tvm_ffi
150+
151+
tvm_ffi_package_path = Path(tvm_ffi.__file__).resolve().parent
152+
include_paths += [
153+
tvm_package_path / "include",
154+
tvm_package_path / "3rdparty" / "dmlc-core" / "include",
155+
tvm_ffi_package_path / "include",
156+
]
157+
elif (tvm_package_path.parent.parent / "include").exists():
158+
include_paths += [
159+
tvm_package_path.parent.parent / "include",
160+
tvm_package_path.parent.parent / "ffi" / "include",
161+
tvm_package_path.parent.parent / "ffi" / "3rdparty" / "dlpack" / "include",
162+
tvm_package_path.parent.parent / "3rdparty" / "dmlc-core" / "include",
163+
]
164+
else:
165+
# warning: TVM is not installed in the system.
166+
print(
167+
"Warning: Include path for TVM cannot be found. "
168+
"FlashInfer kernel compilation may fail due to missing headers."
169+
)
170+
138171
# ------------------------------------------------------------------------
139172
# 3) Function to compile a single source file
140173
# ------------------------------------------------------------------------

0 commit comments

Comments
 (0)