 from tensorrt_llm.commands.bench import main
 from tensorrt_llm.functional import AllReduceStrategy
 
+# Needed since the LLM API uses an MPI executor pool internally for TP > 1, which leaks a thread on shutdown.
+pytestmark = pytest.mark.threadleak(enabled=False)
+
 
 class TimeoutError(Exception):
     """Exception raised when a test times out."""
@@ -55,6 +58,71 @@ def timeout_handler(signum, frame):
         signal.signal(signal.SIGALRM, old_handler)
 
 
+@pytest.fixture(scope="module", autouse=True)
+def prewarm_flashinfer_jit():
+    """Pre-warm FlashInfer JIT kernels before multi-GPU tests.
+
+    This prevents a race condition where multiple MPI ranks try to JIT-compile
+    FlashInfer kernels simultaneously in the same cache directory, causing Ninja
+    build failures such as "ninja: error: opening build log: No such file or directory".
+
+    Triggering the compilation in the main process first leaves the kernels
+    cached and available to all worker ranks.
+    """
+    try:
+        import flashinfer
+        import flashinfer.page
+        import flashinfer.sampling
+
+        if torch.cuda.is_available():
+            # Prevent concurrent JIT warmup across multiple pytest processes (e.g., xdist).
+            try:
+                import fcntl  # Linux-only
+            except ImportError:
+                fcntl = None
+
+            lock_f = None
+            if fcntl is not None:
+                import pathlib
+                import tempfile
+
+                lock_path = pathlib.Path(tempfile.gettempdir()) / "flashinfer_jit_prewarm.lock"
+                lock_f = open(lock_path, "w")
+                fcntl.flock(lock_f.fileno(), fcntl.LOCK_EX)
+            # Create dummy tensors to trigger kernel JIT compilation.
+            with torch.no_grad():
+                device = torch.device("cuda:0")
+
+                # Trigger page kernel compilation.
+                try:
+                    # Force module loading (this triggers JIT compilation).
+                    _ = flashinfer.page.gen_page_module()
+                except Exception as exc:  # noqa: BLE001
+                    import warnings
+
+                    warnings.warn(f"FlashInfer page-kernel prewarm failed: {exc!r}", RuntimeWarning)
+
+                # Trigger sampling kernel compilation.
+                try:
+                    dummy_probs = torch.softmax(torch.randn(1, 100, device=device), dim=-1)
+                    _ = flashinfer.sampling.sampling_from_probs(dummy_probs, deterministic=True)
+                except Exception as exc:  # noqa: BLE001
+                    import warnings
+
+                    warnings.warn(
+                        f"FlashInfer sampling-kernel prewarm failed: {exc!r}", RuntimeWarning
+                    )
+
+            torch.cuda.empty_cache()
+            if lock_f is not None:
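+                # Closing the lock file also releases the exclusive flock taken above.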
+                lock_f.close()
+
+    except ImportError:
+        pass  # FlashInfer not available
+
+    yield
+
+
 @pytest.fixture(scope="module")
 def shared_dataset(llm_root): # noqa: F811
     """Prepare dataset once for all tests in this module."""