Skip to content

Commit fe11272

Browse files
committed
use jax config for debug_symbols
1 parent d4ecfa1 commit fe11272

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

jax_rocm_plugin/build/build.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,12 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
245245
default="",
246246
help="Path to the ROCm toolkit.",
247247
)
248+
rocm_group.add_argument(
249+
"--preserve_debug_symbols",
250+
type=bool,
251+
default=True,
252+
help="Preserve debug symbols in the generated SO files",
253+
)
248254

249255
# Compile Options
250256
compile_group = parser.add_argument_group("Compile Options")
@@ -598,6 +604,8 @@ async def main():
598604
)
599605

600606
if "rocm" in args.wheels:
607+
if args.preserve_debug_symbols:
608+
wheel_build_command_base.append("--config=debug_symbols")
601609
wheel_build_command_base.append("--config=rocm_base")
602610
if args.use_clang:
603611
wheel_build_command_base.append("--config=rocm")

jax_rocm_plugin/build/rocm/tools/build_wheels.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ def build_jaxlib_wheel(
159159
"python",
160160
"build/build.py",
161161
"build",
162-
"--bazel_options=--config=debug_symbols",
163162
"--wheels=jax-rocm-plugin,jax-rocm-pjrt",
164163
"--rocm_path=%s" % rocm_path,
165164
"--rocm_version=%s" % version_string,

0 commit comments

Comments
 (0)