Skip to content

Commit 74ff8a2

Browse files
committed
Address code review feedback
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
1 parent e57726b commit 74ff8a2

File tree

2 files changed

+178
-76
lines changed

2 files changed

+178
-76
lines changed

amdsharktuner/fusilli_tuner/fusilli_tuner.py

Lines changed: 65 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import traceback
1616
from datetime import datetime
1717
from pathlib import Path
18-
from typing import Optional
18+
from typing import Iterator, Optional
1919
from typing_extensions import override
2020

2121
from amdsharktuner import common, libtuner
@@ -83,13 +83,17 @@ def insert_placeholder_input_file(argv: list[str]) -> list[str]:
8383
return [argv[0], "fusilli.mlir"] + argv[1:]
8484

8585

86-
def parse_args() -> tuple[argparse.Namespace, list[str]]:
86+
def parse_args(argv: list[str]) -> tuple[argparse.Namespace, list[str]]:
8787
"""Parse command line arguments.
8888
89+
Args:
90+
argv: Command line arguments to parse (typically sys.argv).
91+
8992
Returns:
9093
A tuple of (parsed_args, fusilli_op_args) where fusilli_op_args contains
9194
the Fusilli operation arguments (conv, matmul parameters).
9295
"""
96+
9397
parser = argparse.ArgumentParser(
9498
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
9599
)
@@ -138,10 +142,17 @@ def parse_args() -> tuple[argparse.Namespace, list[str]]:
138142
# Placeholder to satisfy libtuner's required input_file argument.
139143
# Fusilli generates benchmark files at runtime, not from a pre-existing file.
140144
# TODO(Bangtian): Remove dispatch tuner's input file requirement, then use dispatch tuner.
141-
sys.argv = insert_placeholder_input_file(sys.argv)
142-
args = libtuner.parse_arguments(parser)
145+
argv_with_placeholder = insert_placeholder_input_file(argv)
143146

144-
if "--codegen-pipeline" not in sys.argv:
147+
# Temporarily override sys.argv for libtuner.parse_arguments.
148+
original_argv = sys.argv
149+
sys.argv = argv_with_placeholder
150+
try:
151+
args = libtuner.parse_arguments(parser)
152+
finally:
153+
sys.argv = original_argv
154+
155+
if "--codegen-pipeline" not in argv_with_placeholder:
145156
# Default to tile_and_fuse for Fusilli operations.
146157
args.codegen_pipeline = libtuner.CodegenPipelines.llvmgpu_tile_and_fuse
147158

@@ -174,12 +185,12 @@ def load_commands_from_file_or_args(
174185

175186

176187
def build_compile_args(compile_command: str, benchmarks_dir: Path) -> list[str]:
177-
fusilli_compile_flags = shlex.split(compile_command)
188+
fusilli_compile_flags: list[str] = shlex.split(compile_command)
178189

179190
# Start with iree-compile and filter out unwanted flags from fusilli flags.
180191
# See fusilli/include/fusilli/backend/compile_command.h for flag formats.
181-
compile_args = ["iree-compile"]
182-
args_iter = iter(fusilli_compile_flags[1:])
192+
compile_args: list[str] = ["iree-compile"]
193+
args_iter: Iterator[str] = iter(fusilli_compile_flags[1:])
183194
for arg in args_iter:
184195
# Skip output flag (Fusilli generates "-o" as separate argument + path).
185196
if arg == "-o":
@@ -211,62 +222,82 @@ def run_fusilli_benchmark_driver(
211222
) -> None:
212223
# Use --dump to generate MLIR and compile command artifacts, --iter 1 since
213224
# we only need file generation (not actual benchmarking).
214-
cmd = [fusilli_driver, "--dump", "--iter", "1"] + cli_args
225+
cmd: list[str] = [fusilli_driver, "--dump", "--iter", "1"] + cli_args
215226

216227
# Override FUSILLI_CACHE_DIR to control where Fusilli dumps the generated
217228
# MLIR and compilation flags.
218-
env = os.environ.copy()
229+
env: dict[str, str] = os.environ.copy()
219230
env["FUSILLI_CACHE_DIR"] = str(cache_dir)
220231

221232
logging.info(f"> {shlex.join(cmd)}")
222233
logging.info(f" FUSILLI_CACHE_DIR={cache_dir}")
223234

224-
result = subprocess.run(cmd, env=env, capture_output=True, text=True)
225-
226-
# Log stdout even on success for debugging
227-
if result.stdout:
228-
logging.debug(f"Fusilli benchmark driver stdout:\n{result.stdout}")
235+
result: subprocess.CompletedProcess[str] = subprocess.run(
236+
cmd, env=env, capture_output=True, text=True
237+
)
229238

230-
if result.returncode != 0:
231-
logging.error(
232-
f"Fusilli benchmark driver failed with return code {result.returncode}"
233-
)
239+
# Exit early on success.
240+
if result.returncode == 0:
234241
if result.stdout:
235-
logging.error(f"stdout: {result.stdout}")
236-
if result.stderr:
237-
logging.error(f"stderr: {result.stderr}")
238-
raise RuntimeError(
239-
f"Fusilli benchmark driver failed with code {result.returncode}"
240-
)
242+
logging.debug(f"Fusilli benchmark driver stdout:\n{result.stdout}")
243+
return
244+
245+
# Handle failure.
246+
logging.error(
247+
f"Fusilli benchmark driver failed with return code {result.returncode}"
248+
)
249+
if result.stdout:
250+
logging.error(f"stdout: {result.stdout}")
251+
if result.stderr:
252+
logging.error(f"stderr: {result.stderr}")
253+
raise RuntimeError(f"Fusilli benchmark driver failed with code {result.returncode}")
241254

242255

243-
def find_cached_artifacts(cache_dir: Path) -> tuple[Path, Path]:
256+
def find_cached_artifacts(base_dir: Path) -> tuple[Path, Path]:
244257
"""Find source MLIR and compile command from Fusilli cache.
245258
259+
Fusilli cache structure controlled by FUSILLI_CACHE_DIR environment variable:
260+
base_dir/ # User-specified or temp directory (FUSILLI_CACHE_DIR)
261+
.cache/fusilli/ # Fusilli's internal cache structure
262+
<graph_hash>/ # Graph-specific directory (one per operation)
263+
iree-compile-input.mlir # Generated MLIR for the operation
264+
iree-compile-command.txt # Compile command used by Fusilli
265+
266+
Args:
267+
base_dir: The base directory where FUSILLI_CACHE_DIR environment variable was set to.
268+
Fusilli creates its cache at base_dir/.cache/fusilli/
269+
246270
Returns:
247271
Tuple of (source_mlir_path, compile_command_path).
248272
"""
249-
fusilli_cache = cache_dir / ".cache" / "fusilli"
273+
fusilli_cache: Path = base_dir / ".cache" / "fusilli"
250274

251275
if not fusilli_cache.exists():
252276
raise FileNotFoundError(f"Fusilli cache not found at {fusilli_cache}")
253277

254278
# Find the graph directory (there should be exactly one after running
255279
# with --dump).
256-
graph_dirs = list(fusilli_cache.iterdir())
280+
graph_dirs: list[Path] = list(fusilli_cache.iterdir())
257281
if not graph_dirs:
258282
raise FileNotFoundError(f"No graph directories found in {fusilli_cache}")
259283

260-
graph_dir = graph_dirs[0]
284+
graph_dir: Path = graph_dirs[0]
261285

262-
source_mlir_path = graph_dir / "iree-compile-input.mlir"
263-
compile_command_path = graph_dir / "iree-compile-command.txt"
286+
source_mlir_path: Path = graph_dir / "iree-compile-input.mlir"
287+
compile_command_path: Path = graph_dir / "iree-compile-command.txt"
264288

265289
if not source_mlir_path.exists():
266290
raise FileNotFoundError(f"Source MLIR not found at {source_mlir_path}")
267291
if not compile_command_path.exists():
268292
raise FileNotFoundError(f"Compile command not found at {compile_command_path}")
269293

294+
# Validate paths to prevent path traversal attacks.
295+
try:
296+
source_mlir_path.resolve().relative_to(base_dir.resolve())
297+
compile_command_path.resolve().relative_to(base_dir.resolve())
298+
except ValueError as e:
299+
raise ValueError(f"Path traversal detected: {e}")
300+
270301
return source_mlir_path, compile_command_path
271302

272303

@@ -345,12 +376,11 @@ def process_fusilli_command(
345376
# Set up temporary directory.
346377
if args.tmp_dir:
347378
tmp_dir = Path(args.tmp_dir)
348-
if tmp_dir.exists():
349-
logging.warning(f"Removing existing temporary directory: {tmp_dir}")
350-
shutil.rmtree(tmp_dir)
351379
tmp_dir.mkdir(parents=True, exist_ok=True)
352380
logging.info(f"Using user-specified temporary directory: {tmp_dir}")
353381
else:
382+
# Ensure parent directory exists before creating temp directory.
383+
Path("fusilli_tuner").mkdir(exist_ok=True)
354384
tmp_dir = Path(tempfile.mkdtemp(dir="fusilli_tuner", prefix="fusilli_cache_"))
355385
logging.info(f"Created temporary directory: {tmp_dir}")
356386

@@ -445,8 +475,7 @@ def process_fusilli_command(
445475

446476
def main() -> None:
447477
"""Main entry point for the Fusilli tuner."""
448-
parsed_args = parse_args()
449-
args, fusilli_op_args = parsed_args
478+
args, fusilli_op_args = parse_args(sys.argv)
450479

451480
if args.commands_file and fusilli_op_args:
452481
raise ValueError("Cannot specify both --commands-file and --fusilli-args")

0 commit comments

Comments
 (0)