
Commit ce7a42f

[https://nvbugs/5731717][fix] fixed flashinfer build race condition during test (#9983)
Signed-off-by: Eran Geva <[email protected]>
1 parent 8ba8699 commit ce7a42f

File tree: 3 files changed, +85 -2 lines


tests/integration/defs/test_unittests.py

Lines changed: 14 additions & 2 deletions
@@ -125,7 +125,7 @@ def test_unittests_v2(llm_root, llm_venv, case: str, output_dir, request):
                              f'results-sub-unittests-{case_fn}.xml')
 
     command = [
-        '-m', 'pytest', ignore_opt, "-v", "--timeout=2400",
+        '-m', 'pytest', ignore_opt, "-v", "--tb=short", "-rF", "--timeout=2400",
         "--timeout-method=thread"
     ]
     if test_prefix:
@@ -153,7 +153,19 @@ def run_command(cmd, num_workers=1):
             cwd=test_root,
             env=env,
         )
-    except CalledProcessError:
+    except CalledProcessError as e:
+        print(f"\n{'='*60}")
+        print(f"UNITTEST FAILED with exit code: {e.returncode}")
+        print(f"Command: {' '.join(cmd)}")
+        if hasattr(e, 'stdout') and e.stdout:
+            print(
+                f"STDOUT:\n{e.stdout.decode() if isinstance(e.stdout, bytes) else e.stdout}"
+            )
+        if hasattr(e, 'stderr') and e.stderr:
+            print(
+                f"STDERR:\n{e.stderr.decode() if isinstance(e.stderr, bytes) else e.stderr}"
+            )
+        print(f"{'='*60}\n")
         return False
     return True
 
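A note on the new except branch: `CalledProcessError.stdout` and `.stderr` are only populated when the subprocess was launched with its streams captured, which is why each print is guarded. A minimal, self-contained sketch of the same pattern (the function name `run_and_report` and the command below are hypothetical, not the repo's `run_command`):

import subprocess

def run_and_report(cmd):
    """Run cmd; on failure, echo the captured output (illustrative sketch)."""
    try:
        # capture_output=True is what makes e.stdout / e.stderr non-None below
        subprocess.run(cmd, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        print(f"command failed with exit code {e.returncode}")
        if e.stdout:
            print("STDOUT:", e.stdout.decode(errors="replace"))
        if e.stderr:
            print("STDERR:", e.stderr.decode(errors="replace"))
        return False
    return True

ok = run_and_report(["python", "-c", "import sys; sys.exit(1)"])  # hypothetical usage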

tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_allreduce_strategies.py

Lines changed: 68 additions & 0 deletions
@@ -23,6 +23,9 @@
 from tensorrt_llm.commands.bench import main
 from tensorrt_llm.functional import AllReduceStrategy
 
+# needed since LLM API uses MPI executor pool internally for TP>1, which leaks a thread on shutdown
+pytestmark = pytest.mark.threadleak(enabled=False)
+
 
 class TimeoutError(Exception):
     """Exception raised when a test times out."""
@@ -55,6 +58,71 @@ def timeout_handler(signum, frame):
         signal.signal(signal.SIGALRM, old_handler)
 
 
+@pytest.fixture(scope="module", autouse=True)
+def prewarm_flashinfer_jit():
+    """Pre-warm FlashInfer JIT kernels before multi-GPU tests.
+
+    This prevents a race condition where multiple MPI ranks try to JIT-compile
+    FlashInfer kernels simultaneously to the same cache directory, causing
+    Ninja build failures like: "ninja: error: opening build log: No such file or directory"
+
+    By triggering the compilation in the main process first, the kernels are
+    cached and available for all worker ranks.
+    """
+    try:
+        import flashinfer
+        import flashinfer.page
+        import flashinfer.sampling
+
+        if torch.cuda.is_available():
+            # Prevent concurrent JIT warmup across multiple pytest processes (e.g., xdist).
+            try:
+                import fcntl  # Linux-only
+            except ImportError:
+                fcntl = None
+
+            lock_f = None
+            if fcntl is not None:
+                import pathlib
+                import tempfile
+
+                lock_path = pathlib.Path(tempfile.gettempdir()) / "flashinfer_jit_prewarm.lock"
+                lock_f = open(lock_path, "w")
+                fcntl.flock(lock_f.fileno(), fcntl.LOCK_EX)
+            # Create dummy tensors to trigger kernel JIT compilation
+            with torch.no_grad():
+                device = torch.device("cuda:0")
+
+                # Trigger page kernel compilation
+                try:
+                    # Force module loading (this triggers JIT compilation)
+                    _ = flashinfer.page.gen_page_module()
+                except Exception as exc:  # noqa: BLE001
+                    import warnings
+
+                    warnings.warn(f"FlashInfer page-kernel prewarm failed: {exc!r}", RuntimeWarning)
+
+                # Trigger sampling kernel compilation
+                try:
+                    dummy_probs = torch.softmax(torch.randn(1, 100, device=device), dim=-1)
+                    _ = flashinfer.sampling.sampling_from_probs(dummy_probs, deterministic=True)
+                except Exception as exc:  # noqa: BLE001
+                    import warnings
+
+                    warnings.warn(
+                        f"FlashInfer sampling-kernel prewarm failed: {exc!r}", RuntimeWarning
+                    )
+
+                torch.cuda.empty_cache()
+            if lock_f is not None:
+                lock_f.close()
+
+    except ImportError:
+        pass  # FlashInfer not available
+
+    yield
+
+
 @pytest.fixture(scope="module")
 def shared_dataset(llm_root):  # noqa: F811
     """Prepare dataset once for all tests in this module."""

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@
 from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
 from tensorrt_llm.llmapi.mpi_session import MpiPoolSession
 
+# needed since MPI executor pool leaks a thread (_manager_spawn) on shutdown
+pytestmark = pytest.mark.threadleak(enabled=False)
+
 
 class RMSNorm(torch.nn.Module):
     """Implementation of LlamaRMSNorm."""
