|
15 | 15 | import traceback |
16 | 16 | from datetime import datetime |
17 | 17 | from pathlib import Path |
18 | | -from typing import Optional |
| 18 | +from typing import Iterator, Optional |
19 | 19 | from typing_extensions import override |
20 | 20 |
|
21 | 21 | from amdsharktuner import common, libtuner |
@@ -83,13 +83,17 @@ def insert_placeholder_input_file(argv: list[str]) -> list[str]: |
83 | 83 | return [argv[0], "fusilli.mlir"] + argv[1:] |
84 | 84 |
|
85 | 85 |
|
86 | | -def parse_args() -> tuple[argparse.Namespace, list[str]]: |
| 86 | +def parse_args(argv: list[str]) -> tuple[argparse.Namespace, list[str]]: |
87 | 87 | """Parse command line arguments. |
88 | 88 |
|
| 89 | + Args: |
| 90 | + argv: Command line arguments to parse (typically sys.argv). |
| 91 | +
|
89 | 92 | Returns: |
90 | 93 | A tuple of (parsed_args, fusilli_op_args) where fusilli_op_args contains |
91 | 94 | the Fusilli operation arguments (conv, matmul parameters). |
92 | 95 | """ |
| 96 | + |
93 | 97 | parser = argparse.ArgumentParser( |
94 | 98 | description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter |
95 | 99 | ) |
@@ -138,10 +142,17 @@ def parse_args() -> tuple[argparse.Namespace, list[str]]: |
138 | 142 | # Placeholder to satisfy libtuner's required input_file argument. |
139 | 143 | # Fusilli generates benchmark files at runtime, not from a pre-existing file. |
140 | 144 | # TODO(Bangtian): Remove dispatch tuner's input file requirement, then use dispatch tuner. |
141 | | - sys.argv = insert_placeholder_input_file(sys.argv) |
142 | | - args = libtuner.parse_arguments(parser) |
| 145 | + argv_with_placeholder = insert_placeholder_input_file(argv) |
143 | 146 |
|
144 | | - if "--codegen-pipeline" not in sys.argv: |
| 147 | + # Temporarily override sys.argv while libtuner.parse_arguments runs; the original value is restored afterwards. |
| 148 | + original_argv = sys.argv |
| 149 | + sys.argv = argv_with_placeholder |
| 150 | + try: |
| 151 | + args = libtuner.parse_arguments(parser) |
| 152 | + finally: |
| 153 | + sys.argv = original_argv |
| 154 | + |
| 155 | + if "--codegen-pipeline" not in argv_with_placeholder: |
145 | 156 | # Default to tile_and_fuse for Fusilli operations. |
146 | 157 | args.codegen_pipeline = libtuner.CodegenPipelines.llvmgpu_tile_and_fuse |
147 | 158 |
|
@@ -174,12 +185,12 @@ def load_commands_from_file_or_args( |
174 | 185 |
|
175 | 186 |
|
176 | 187 | def build_compile_args(compile_command: str, benchmarks_dir: Path) -> list[str]: |
177 | | - fusilli_compile_flags = shlex.split(compile_command) |
| 188 | + fusilli_compile_flags: list[str] = shlex.split(compile_command) |
178 | 189 |
|
179 | 190 | # Start with iree-compile and filter out unwanted flags from fusilli flags. |
180 | 191 | # See fusilli/include/fusilli/backend/compile_command.h for flag formats. |
181 | | - compile_args = ["iree-compile"] |
182 | | - args_iter = iter(fusilli_compile_flags[1:]) |
| 192 | + compile_args: list[str] = ["iree-compile"] |
| 193 | + args_iter: Iterator[str] = iter(fusilli_compile_flags[1:]) |
183 | 194 | for arg in args_iter: |
184 | 195 | # Skip output flag (Fusilli generates "-o" as separate argument + path). |
185 | 196 | if arg == "-o": |
@@ -211,62 +222,82 @@ def run_fusilli_benchmark_driver( |
211 | 222 | ) -> None: |
212 | 223 | # Use --dump to generate MLIR and compile command artifacts, --iter 1 since |
213 | 224 | # we only need file generation (not actual benchmarking). |
214 | | - cmd = [fusilli_driver, "--dump", "--iter", "1"] + cli_args |
| 225 | + cmd: list[str] = [fusilli_driver, "--dump", "--iter", "1"] + cli_args |
215 | 226 |
|
216 | 227 | # Override FUSILLI_CACHE_DIR to control where Fusilli dumps the generated |
217 | 228 | # MLIR and compilation flags. |
218 | | - env = os.environ.copy() |
| 229 | + env: dict[str, str] = os.environ.copy() |
219 | 230 | env["FUSILLI_CACHE_DIR"] = str(cache_dir) |
220 | 231 |
|
221 | 232 | logging.info(f"> {shlex.join(cmd)}") |
222 | 233 | logging.info(f" FUSILLI_CACHE_DIR={cache_dir}") |
223 | 234 |
|
224 | | - result = subprocess.run(cmd, env=env, capture_output=True, text=True) |
225 | | - |
226 | | - # Log stdout even on success for debugging |
227 | | - if result.stdout: |
228 | | - logging.debug(f"Fusilli benchmark driver stdout:\n{result.stdout}") |
| 235 | + result: subprocess.CompletedProcess[str] = subprocess.run( |
| 236 | + cmd, env=env, capture_output=True, text=True |
| 237 | + ) |
229 | 238 |
|
230 | | - if result.returncode != 0: |
231 | | - logging.error( |
232 | | - f"Fusilli benchmark driver failed with return code {result.returncode}" |
233 | | - ) |
| 239 | + # Exit early on success. |
| 240 | + if result.returncode == 0: |
234 | 241 | if result.stdout: |
235 | | - logging.error(f"stdout: {result.stdout}") |
236 | | - if result.stderr: |
237 | | - logging.error(f"stderr: {result.stderr}") |
238 | | - raise RuntimeError( |
239 | | - f"Fusilli benchmark driver failed with code {result.returncode}" |
240 | | - ) |
| 242 | + logging.debug(f"Fusilli benchmark driver stdout:\n{result.stdout}") |
| 243 | + return |
| 244 | + |
| 245 | + # Handle failure. |
| 246 | + logging.error( |
| 247 | + f"Fusilli benchmark driver failed with return code {result.returncode}" |
| 248 | + ) |
| 249 | + if result.stdout: |
| 250 | + logging.error(f"stdout: {result.stdout}") |
| 251 | + if result.stderr: |
| 252 | + logging.error(f"stderr: {result.stderr}") |
| 253 | + raise RuntimeError(f"Fusilli benchmark driver failed with code {result.returncode}") |
241 | 254 |
|
242 | 255 |
|
243 | | -def find_cached_artifacts(cache_dir: Path) -> tuple[Path, Path]: |
| 256 | +def find_cached_artifacts(base_dir: Path) -> tuple[Path, Path]: |
244 | 257 | """Find source MLIR and compile command from Fusilli cache. |
245 | 258 |
|
| 259 | + The Fusilli cache layout is controlled by the FUSILLI_CACHE_DIR environment variable: |
| 260 | + base_dir/ # User-specified or temp directory (FUSILLI_CACHE_DIR) |
| 261 | + .cache/fusilli/ # Fusilli's internal cache structure |
| 262 | + <graph_hash>/ # Graph-specific directory (one per operation) |
| 263 | + iree-compile-input.mlir # Generated MLIR for the operation |
| 264 | + iree-compile-command.txt # Compile command used by Fusilli |
| 265 | +
|
| 266 | + Args: |
| 267 | + base_dir: The directory that the FUSILLI_CACHE_DIR environment variable was set to. |
| 268 | + Fusilli creates its cache at base_dir/.cache/fusilli/ |
| 269 | +
|
246 | 270 | Returns: |
247 | 271 | Tuple of (source_mlir_path, compile_command_path). |
248 | 272 | """ |
249 | | - fusilli_cache = cache_dir / ".cache" / "fusilli" |
| 273 | + fusilli_cache: Path = base_dir / ".cache" / "fusilli" |
250 | 274 |
|
251 | 275 | if not fusilli_cache.exists(): |
252 | 276 | raise FileNotFoundError(f"Fusilli cache not found at {fusilli_cache}") |
253 | 277 |
|
254 | 278 | # Find the graph directory (there should be exactly one after running |
255 | 279 | # with --dump). |
256 | | - graph_dirs = list(fusilli_cache.iterdir()) |
| 280 | + graph_dirs: list[Path] = list(fusilli_cache.iterdir()) |
257 | 281 | if not graph_dirs: |
258 | 282 | raise FileNotFoundError(f"No graph directories found in {fusilli_cache}") |
259 | 283 |
|
260 | | - graph_dir = graph_dirs[0] |
| 284 | + graph_dir: Path = graph_dirs[0] |
261 | 285 |
|
262 | | - source_mlir_path = graph_dir / "iree-compile-input.mlir" |
263 | | - compile_command_path = graph_dir / "iree-compile-command.txt" |
| 286 | + source_mlir_path: Path = graph_dir / "iree-compile-input.mlir" |
| 287 | + compile_command_path: Path = graph_dir / "iree-compile-command.txt" |
264 | 288 |
|
265 | 289 | if not source_mlir_path.exists(): |
266 | 290 | raise FileNotFoundError(f"Source MLIR not found at {source_mlir_path}") |
267 | 291 | if not compile_command_path.exists(): |
268 | 292 | raise FileNotFoundError(f"Compile command not found at {compile_command_path}") |
269 | 293 |
|
| 294 | + # Verify that both resolved artifact paths stay inside base_dir to prevent path traversal. |
| 295 | + try: |
| 296 | + source_mlir_path.resolve().relative_to(base_dir.resolve()) |
| 297 | + compile_command_path.resolve().relative_to(base_dir.resolve()) |
| 298 | + except ValueError as e: |
| 299 | + raise ValueError(f"Path traversal detected: {e}") |
| 300 | + |
270 | 301 | return source_mlir_path, compile_command_path |
271 | 302 |
|
272 | 303 |
|
@@ -345,12 +376,11 @@ def process_fusilli_command( |
345 | 376 | # Set up temporary directory. |
346 | 377 | if args.tmp_dir: |
347 | 378 | tmp_dir = Path(args.tmp_dir) |
348 | | - if tmp_dir.exists(): |
349 | | - logging.warning(f"Removing existing temporary directory: {tmp_dir}") |
350 | | - shutil.rmtree(tmp_dir) |
351 | 379 | tmp_dir.mkdir(parents=True, exist_ok=True) |
352 | 380 | logging.info(f"Using user-specified temporary directory: {tmp_dir}") |
353 | 381 | else: |
| 382 | + # Ensure parent directory exists before creating temp directory. |
| 383 | + Path("fusilli_tuner").mkdir(exist_ok=True) |
354 | 384 | tmp_dir = Path(tempfile.mkdtemp(dir="fusilli_tuner", prefix="fusilli_cache_")) |
355 | 385 | logging.info(f"Created temporary directory: {tmp_dir}") |
356 | 386 |
|
@@ -445,8 +475,7 @@ def process_fusilli_command( |
445 | 475 |
|
446 | 476 | def main() -> None: |
447 | 477 | """Main entry point for the Fusilli tuner.""" |
448 | | - parsed_args = parse_args() |
449 | | - args, fusilli_op_args = parsed_args |
| 478 | + args, fusilli_op_args = parse_args(sys.argv) |
450 | 479 |
|
451 | 480 | if args.commands_file and fusilli_op_args: |
452 | 481 | raise ValueError("Cannot specify both --commands-file and --fusilli-args") |
|
0 commit comments