From c12346f0eea8baf09c6340f07510e6667dd8eb00 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Thu, 4 Dec 2025 14:19:32 -0600 Subject: [PATCH 1/8] initial dump --- DEVDOCS.md | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 DEVDOCS.md diff --git a/DEVDOCS.md b/DEVDOCS.md new file mode 100644 index 000000000..bd93110cd --- /dev/null +++ b/DEVDOCS.md @@ -0,0 +1,85 @@ +# Enzyme-JAX Developer Documentation + +## Overview + +Enzyme-JAX is a C++ project that integrates the Enzyme automatic differentiation tool with JAX, enabling automatic differentiation of external C++ code within JAX. The project uses LLVM's MLIR framework for intermediate representation and transformation of code. + +## Building the Project + +### Quick Build +```bash +bazel build --repo_env=CC=clang-18 --color=yes --copt=-fbracket-depth=1024 --host_copt=-fbracket-depth=1024 -c dbg :enzymexlamlir-opt +``` + +### Build Artifacts +- **Main tool**: `enzymexlamlir-opt` (bazel target: `:enzymexlamlir-opt`) + - This is the MLIR optimization tool driver for Enzyme-XLA + - Analogous to `mlir-opt`, drives compiler passes and transformations + - Located in: `src/enzyme_ad/jax/enzymexlamlir-opt.cpp` + +- **Python wheel**: `bazel build :wheel` + +### Generate LSP Support +```bash +bazel run :refresh_compile_commands +``` + +## Project Structure + +### Core Components + +#### 1. **Dialects** (`src/enzyme_ad/jax/Dialect/`) +MLIR dialects define custom operations and types for Enzyme-JAX. + +- **EnzymeXLAOps.td** - Dialect operation definitions + - GPU operations: `kernel_call`, `memcpy`, `gpu_wrapper`, `gpu_block`, `gpu_thread` + - JIT/XLA operations: `jit_call`, `xla_wrapper` + - Linear algebra (BLAS/LAPACK): `symm`, `syrk`, `trmm`, `lu`, `getrf`, `gesvd`, etc. + - Special functions: Bessel functions, GELU, ReLU + - Utility operations: `memref2pointer`, `pointer2memref`, `subindex` + +- **EnzymeXLAAttrs.td** - Custom attribute definitions (LAPACK enums, etc.) + +#### 2. **Passes** (`src/enzyme_ad/jax/Passes/`) +MLIR passes implement transformations and optimizations. + +- Tablegen definitions in `src/enzyme_ad/jax/Passes/Passes.td` +- **EnzymeHLOOpt.cpp** - Core optimization patterns for StableHLO and EnzymeXLA operations + +#### 3. **Transform Operations** (`src/enzyme_ad/jax/TransformOps/`) +In order to have more granular control over which pattern is applied, patterns are also registered as transform operations. +For example: +``` +def AndPadPad : EnzymeHLOPatternOp< + "and_pad_pad"> { + let patterns = ["AndPadPad"]; +} +``` +Exposes the `AndPadPad` pattern (defined in `EnzymeHLOOpt.cpp`) to `enzymexlamlir-opt`, so it can be used as: +``` +enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=and_pad_pad" --transform-interpreter --enzyme-hlo-remove-transform -allow-unregistered-dialect input.mlir +``` + +## Common Development Tasks + +### Adding a New Optimization Pattern + +1. Define the pattern class in `src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp` +2. Inherit from `mlir::OpRewritePattern` +3. Implement `matchAndRewrite()` method +4. Register in `EnzymeHLOOptPass::runOnOperation()` +5. Register as Transform operation in `TransformOps.td` + +### Adding a New Dialect Operation + +1. Define operation in `src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td` +2. Specify arguments, results, and traits +3. Implement operation class if needed in `src/enzyme_ad/jax/Dialect/Ops.cpp` +4. TODO: write about derivative rules? + +## Testing + +Run tests with: +```bash +bazel test //test/... +``` From a7895e11510bdf2d4eebcda762bcbd726fd3670a Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Thu, 4 Dec 2025 15:44:46 -0600 Subject: [PATCH 2/8] remove fbracked_depth from example command --- DEVDOCS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DEVDOCS.md b/DEVDOCS.md index bd93110cd..18d0edbd6 100644 --- a/DEVDOCS.md +++ b/DEVDOCS.md @@ -8,7 +8,7 @@ Enzyme-JAX is a C++ project that integrates the Enzyme automatic differentiation ### Quick Build ```bash -bazel build --repo_env=CC=clang-18 --color=yes --copt=-fbracket-depth=1024 --host_copt=-fbracket-depth=1024 -c dbg :enzymexlamlir-opt +bazel build --repo_env=CC=clang-18 --color=yes -c dbg :enzymexlamlir-opt ``` ### Build Artifacts From 0557e0512edf70a87ec5fecbcb3447e050c608a3 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Thu, 4 Dec 2025 15:46:51 -0600 Subject: [PATCH 3/8] move updated project overview to README Co-authored-by: Roman Lee --- DEVDOCS.md | 6 +----- README.md | 5 +---- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/DEVDOCS.md b/DEVDOCS.md index 18d0edbd6..6b589df47 100644 --- a/DEVDOCS.md +++ b/DEVDOCS.md @@ -1,8 +1,4 @@ -# Enzyme-JAX Developer Documentation - -## Overview - -Enzyme-JAX is a C++ project that integrates the Enzyme automatic differentiation tool with JAX, enabling automatic differentiation of external C++ code within JAX. The project uses LLVM's MLIR framework for intermediate representation and transformation of code. +# Enzyme-JAX Developer Notes ## Building the Project diff --git a/README.md b/README.md index b92219dd9..ae92f3574 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,6 @@ # Enzyme-JAX -Custom bindings for Enzyme automatic differentiation tool and interfacing with -JAX. Currently this is set up to allow you to automatically import, and -automatically differentiate (both jvp and vjp) external C++ code into JAX. As -Enzyme is language-agnostic, this can be extended for arbitrary programming +Enzyme-JAX is a C++ project whose original aim was to integrate the Enzyme automatic differentiation tool [1] with JAX, enabling automatic differentiation of external C++ code within JAX. It has since expanded to incorporate Polygeist's [2] high performance raising, parallelization, cross compilation workflow, as well as numerous tensor, linear algerba, and communication optimizations. The project uses LLVM's MLIR framework for intermediate representation and transformation of code. As Enzyme is language-agnostic, this can be extended for arbitrary programming languages (Julia, Swift, Fortran, Rust, and even Python)! You can use From 7a8850362b3df1358a37781b94c865dee21a5888 Mon Sep 17 00:00:00 2001 From: Roman Lee <31547765+romanlee@users.noreply.github.com> Date: Thu, 4 Dec 2025 15:38:57 -0800 Subject: [PATCH 4/8] Move references from devdocs to readme --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index ae92f3574..2cb257948 100644 --- a/README.md +++ b/README.md @@ -74,3 +74,8 @@ Enzyme-Jax exposes a bunch of different tensor rewrites as MLIR passes in `src/e ```bash bazel run :refresh_compile_commands ``` + +# References +[1] Moses, William, and Valentin Churavy. "Instead of rewriting foreign code for machine learning, automatically synthesize fast gradients." Advances in neural information processing systems 33 (2020): 12472-12485. + +[2] Moses, William S., et al. "Polygeist: Raising C to polyhedral MLIR." 2021 30th International Conference on Parallel Architectures and Compilation Techniques (PACT). IEEE, 2021. From 9f3429973f87d3f170a8913f511a19c8f25160d5 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Fri, 5 Dec 2025 11:02:21 -0600 Subject: [PATCH 5/8] Some more information on LIT tests --- DEVDOCS.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/DEVDOCS.md b/DEVDOCS.md index 6b589df47..185dc5e38 100644 --- a/DEVDOCS.md +++ b/DEVDOCS.md @@ -79,3 +79,14 @@ Run tests with: ```bash bazel test //test/... ``` +This runs the tests in + +Most of the Enzyme-JaX tests use [lit](https://llvm.org/docs/CommandGuide/lit.html) for testing. +These tests are stored in `test/lit_tests`. +A lit test contains one or more run directives at the top a file. +e.g. in `test/lit_tests/if.mlir`: +```mlir +// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s +``` +This instructs `lit` to run the `enzyme-hlo-opt` pass on `test/lit_tests/if.mlir`. +The output is fed to `FileCheck` which compares it against the expected result that is provided in comments in the file that start with `// CHECK`. From 6344e6a7c32a7459cc934fbe1f6d33cf502886e2 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Fri, 5 Dec 2025 22:49:09 -0600 Subject: [PATCH 6/8] mention primitives.py --- DEVDOCS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/DEVDOCS.md b/DEVDOCS.md index 185dc5e38..f74bc0c75 100644 --- a/DEVDOCS.md +++ b/DEVDOCS.md @@ -65,6 +65,7 @@ enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=and_pad_pad" --transform-in 3. Implement `matchAndRewrite()` method 4. Register in `EnzymeHLOOptPass::runOnOperation()` 5. Register as Transform operation in `TransformOps.td` +6. Add the pass to the appropriate pass list in `src/enzyme_ad/jax/primitives.py` ### Adding a New Dialect Operation From 20c31f0def49b7283999211087f7f5f83c750795 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Fri, 5 Dec 2025 22:52:16 -0600 Subject: [PATCH 7/8] enzymehloopt.cpp contains all stablehlo optimizations --- DEVDOCS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/DEVDOCS.md b/DEVDOCS.md index f74bc0c75..ede189cfa 100644 --- a/DEVDOCS.md +++ b/DEVDOCS.md @@ -41,6 +41,7 @@ MLIR passes implement transformations and optimizations. - Tablegen definitions in `src/enzyme_ad/jax/Passes/Passes.td` - **EnzymeHLOOpt.cpp** - Core optimization patterns for StableHLO and EnzymeXLA operations + This file contains (nearly) all the stablehlo tensor optimizations. #### 3. **Transform Operations** (`src/enzyme_ad/jax/TransformOps/`) In order to have more granular control over which pattern is applied, patterns are also registered as transform operations. From 416514481857ac23d1f187a4a02300804eed1827 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 9 Dec 2025 17:47:47 -0600 Subject: [PATCH 8/8] Adding a new lowering pass Co-authored-by: Roman Lee <31547765+romanlee@users.noreply.github.com> --- DEVDOCS.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/DEVDOCS.md b/DEVDOCS.md index ede189cfa..117cc4a95 100644 --- a/DEVDOCS.md +++ b/DEVDOCS.md @@ -68,6 +68,13 @@ enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=and_pad_pad" --transform-in 5. Register as Transform operation in `TransformOps.td` 6. Add the pass to the appropriate pass list in `src/enzyme_ad/jax/primitives.py` +### Adding a new lowering pass +1. Define the pass in `src/enzyme_ad/jax/Passes/Passes.td`, e.g. `LowerEnzymeXLALinalgPass` +2. Create a new `.cpp` file in `src/enzyme_ad/jax/Passes/`, e.g. `src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp`. In the new file... + 1. Inherit from `mlir::OpRewritePattern` and implement the `matchAndRewrite()` method. + 2. Inherit from the generated `PassBase` class and implement `runOnOperation` to register your pass. +3. Write lit tests for your pass, e.g. `test/lit_tests/linalg/*.mlir`. + ### Adding a New Dialect Operation 1. Define operation in `src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td`