From 8b1a2e6db209b8be693600522ef646159d36e49f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 9 Jan 2026 11:00:13 -0600 Subject: [PATCH 1/4] feat: simplify trivial scatter op --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 93f076e6c3..e7c372bdf0 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -59,6 +59,7 @@ #include #include #include +#include #include #include #include From 02501ace1b5ae70a2583e2898913848953f25e0a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 9 Jan 2026 14:29:43 -0600 Subject: [PATCH 2/4] feat: scatter iota simplifications --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index e7c372bdf0..56f8b061e9 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -76,6 +76,8 @@ using namespace mlir; using namespace mlir::enzyme; using namespace mlir::stablehlo; +static int64_t scatterRegionToFunctionCounter = 0; + // Check if any of the pad sizes are negative bool anyPadSizesNegative(stablehlo::PadOp pad) { for (auto &&[low, high, inner] : From 44cadbeb9c0088703272d937788074ee3cf9ca43 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 9 Jan 2026 15:11:59 -0600 Subject: [PATCH 3/4] test: add scatter index simplify test --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 56f8b061e9..e7c372bdf0 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -76,8 +76,6 @@ using namespace mlir; using namespace mlir::enzyme; using namespace mlir::stablehlo; -static int64_t scatterRegionToFunctionCounter = 0; - // Check if any of the pad sizes are negative bool anyPadSizesNegative(stablehlo::PadOp pad) { for (auto &&[low, high, inner] : From 9ba4982f6471b70f673910001026eba32fb14517 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 6 Jan 2026 18:36:32 -0600 Subject: [PATCH 4/4] test: add jaxley benchmark --- builddeps/requirements_lock_3_11.txt | 82 ++++- builddeps/test-requirements.txt | 3 + .../jax/Implementations/CHLODerivatives.td | 5 + src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 1 - test/BUILD | 20 ++ test/jaxley_test.py | 280 ++++++++++++++++++ test/jaxmd.py | 1 - test/lit_tests/diffrules/chlo/square.mlir | 25 ++ test/test_utils.py | 76 ++++- test/xprof_utils.py | 8 +- 10 files changed, 480 insertions(+), 21 deletions(-) create mode 100644 test/jaxley_test.py create mode 100644 test/lit_tests/diffrules/chlo/square.mlir diff --git a/builddeps/requirements_lock_3_11.txt b/builddeps/requirements_lock_3_11.txt index c69af35634..5d4dc7569c 100644 --- a/builddeps/requirements_lock_3_11.txt +++ b/builddeps/requirements_lock_3_11.txt @@ -382,6 +382,10 @@ decorator==5.2.1 \ --hash=sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360 \ --hash=sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a # via gcsfs +diffrax==0.7.0 \ + --hash=sha256:aa9645c40552f11a2d32042ef6b9fcd53c1f0f6bdbe32d37cb788669ca9910be \ + --hash=sha256:f3bcc578cd92a9ca86fc6f5a54c1de76c1ba62f74de69b56002624bf205316f1 + # via jaxley-mech dinosaur==1.3.5 \ --hash=sha256:aa3830f66a7ceb5cb900689d9b0717100eea74ae4d04f206a9fa20408cee3dc9 \ --hash=sha256:fd75996d62104d5c602a4f2643a1154268e6cd6ed9fd1c295aab679c6fba60b3 @@ -405,6 +409,14 @@ einops==0.8.1 \ --hash=sha256:919387eb55330f5757c6bea9165c5ff5cfe63a642682ea788a6d472576d81737 \ --hash=sha256:de5d960a7a761225532e0f1959e5315ebeafc0cd43394732f103ca44b9837e84 # via jax-md +equinox==0.13.2 \ + --hash=sha256:509ad744ff99b7c684d45230d6890f9e78eac1a556d7a06db1eff664a3cac74f \ + --hash=sha256:bc1ee687e4841945d8b776664403839639a05e2f2c02c1da353ff3386e0e43b0 + # via + # diffrax + # jaxley-mech + # lineax + # optimistix etils[epath]==1.13.0 \ --hash=sha256:a5b60c71f95bcd2d43d4e9fb3dc3879120c1f60472bb5ce19f7a860b1d44f607 \ --hash=sha256:d9cd4f40fbe77ad6613b7348a18132cc511237b6c076dbb89105c0b520a4c6bb @@ -910,21 +922,34 @@ jaraco-functools==4.4.0 \ --hash=sha256:9eec1e36f45c818d9bf307c8948eb03b2b56cd44087b3cdc989abca1f20b9176 \ --hash=sha256:da21933b0417b89515562656547a77b4931f98176eb173644c0d35032a33d6bb # via cheroot +jax[cpu]==0.8.2 \ + --hash=sha256:1a685ded06a8223a7b52e45e668e406049dbbead02873f2b5a4d881ba7b421ae \ + --hash=sha256:d0478c5dc74406441efcd25731166a65ee782f13c352fa72dc7d734351909355 + # via + # -r builddeps/requirements.in + # jaxley-mech jax[cuda12]==0.8.2 \ --hash=sha256:1a685ded06a8223a7b52e45e668e406049dbbead02873f2b5a4d881ba7b421ae \ --hash=sha256:d0478c5dc74406441efcd25731166a65ee782f13c352fa72dc7d734351909355 # via # -r builddeps/requirements.in # chex + # diffrax # dinosaur # e3nn-jax + # equinox # flax # jax-md + # jaxley + # jaxley-mech # jraph + # lineax # neuralgcm # optax + # optimistix # orbax-checkpoint # tree-math + # tridiax jax-cuda12-pjrt==0.8.2 \ --hash=sha256:717a1b196a642409ce195ddf031c20bbeadcc886f55e49a1d3f4927373aeedae \ --hash=sha256:e3bab41ca7c48e4163db9e7efd271b3aa85f0fe45f5ed0708d6bbed93a59f977 @@ -947,6 +972,16 @@ jax-md==0.2.27 \ --hash=sha256:3506cf7c07b84d6c9cf09243097bef465c81122a23ca8cc78a3627c8b9d97322 \ --hash=sha256:efbefa5089a995a5c02405a4c930ba42f8eaf9322482998b5a422e45f631a0ab # via -r builddeps/test-requirements.txt +jaxley==0.13.0 \ + --hash=sha256:0d9247b340b402f974aad827e0cd79e32c5cd66d7295d95514792a108e15f00b \ + --hash=sha256:277f135714f1370b7246754d64687357ec443e3a944f1a96633dfd4eaaafcc3e + # via + # -r builddeps/test-requirements.txt + # jaxley-mech +jaxley-mech==0.3.1 \ + --hash=sha256:bd46cb2f02d1f76af56406ef83c464b6f9fc9742625cd88371a1923e14f601e8 \ + --hash=sha256:cc5eda21c8521e32795526f9f85ca52941899449b0a491d3ffdb321f3f0c8cbd + # via -r builddeps/test-requirements.txt jaxlib==0.8.2 \ --hash=sha256:023de6f3f56da2af7037970996500586331fdb50b530ecbb54b9666da633bd00 \ --hash=sha256:05b958f497e49824c432e734bb059723b7dfe69e2ad696a9f9c8ad82fff7c3f8 \ @@ -980,6 +1015,14 @@ jaxlib==0.8.2 \ # jraph # neuralgcm # optax +jaxtyping==0.3.5 \ + --hash=sha256:8150ad5b72b62fa63f573d492a79e9e455f070abe3b260f7dc15270b3eb9bba6 \ + --hash=sha256:862c39fa2e526274e82dc96ee8dbe9369dadb651ab1e05d95bd685acb4e2ef02 + # via + # diffrax + # equinox + # lineax + # optimistix jmp==0.0.4 \ --hash=sha256:5dfeb0fd7c7a9f72a70fff0aab9d0cbfae32a809c02f4037ff3485ceb33e1730 \ --hash=sha256:6aa7adbddf2bd574b28c7faf6e81a735eb11f53386447896909c6968dc36807d @@ -1097,6 +1140,12 @@ kiwisolver==1.4.9 \ --hash=sha256:fb940820c63a9590d31d88b815e7a3aa5915cad3ce735ab45f0c730b39547de1 \ --hash=sha256:fc1795ac5cd0510207482c3d1d3ed781143383b8cfd36f5c645f3897ce066220 # via matplotlib +lineax==0.0.8 \ + --hash=sha256:1bd21d6c41afda233769d1c1096329ee75181825c9136be65c92b41f6daa1ddb \ + --hash=sha256:bb2778066b8882acc88ff569d8e415bc5aa387f751b14ae262c9f9607d7f25bb + # via + # diffrax + # optimistix locket==1.0.0 \ --hash=sha256:5c0d4c052a8bbbf750e056a8e65ccd309086f4f0f18a2eac306a8dfa4112a632 \ --hash=sha256:b6c819a722f7b6bd955b80781788e4a66a55628b858d347536b7e81325a3a5e3 @@ -1252,7 +1301,10 @@ matplotlib==3.10.8 \ --hash=sha256:f97aeb209c3d2511443f8797e3e5a569aebb040d4f8bc79aa3ee78a8fb9e3dd8 \ --hash=sha256:f9b587c9c7274c1613a30afabf65a272114cd6cdbe67b3406f818c79d7ab2e2a \ --hash=sha256:fb061f596dad3a0f52b60dc6a5dec4a0c300dec41e058a7efe09256188d170b7 - # via pymatgen + # via + # jaxley + # jaxley-mech + # pymatgen mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba @@ -1546,7 +1598,9 @@ nest-asyncio==1.6.0 \ networkx==3.6.1 \ --hash=sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509 \ --hash=sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762 - # via pymatgen + # via + # jaxley + # pymatgen neuralgcm==1.2.2 \ --hash=sha256:24edbbb5d21e2d17a7475738f84602885eb011af3a23c33df293b2c5d10ac11c \ --hash=sha256:795297260a5aff05708e855fe8cb27db7cc0f514e9c34e373e4ba378732327e5 @@ -1640,6 +1694,8 @@ numpy==2.1.3 \ # flax # jax # jax-md + # jaxley + # jaxley-mech # jaxlib # jmp # jraph @@ -1657,6 +1713,7 @@ numpy==2.1.3 \ # spglib # tensorstore # treescope + # tridiax # xarray # xarray-tensorstore # zarr @@ -1742,6 +1799,10 @@ optax==0.2.6 \ # flax # jax-md # neuralgcm +optimistix==0.0.11 \ + --hash=sha256:acb4fb23b598db03e376900fcb61aee8dd511de41411e849661c0ffe9e4cd1c6 \ + --hash=sha256:cfce0de98e7e9fdbcc2ce6d49a9f82cd3166fd0eee29c0c7a1983f8aefd37757 + # via diffrax orbax-checkpoint==0.11.31 \ --hash=sha256:b00e39cd61cbd6c7c78b091ccac0ed1bbf3cf7788e761618e7070761195bfcc0 \ --hash=sha256:f021193a619782655798bc4a285f40612f6fe647ddeb303d1f49cdbc5645e319 @@ -1906,6 +1967,8 @@ pandas==2.3.3 \ --hash=sha256:f8bfc0e12dc78f777f323f55c58649591b2cd0c43534e8355c51d3fede5f4dee # via # dinosaur + # jaxley + # jaxley-mech # neuralgcm # pymatgen # xarray @@ -2743,18 +2806,26 @@ treescope==0.1.10 \ --hash=sha256:20f74656f34ab2d8716715013e8163a0da79bdc2554c16d5023172c50d27ea95 \ --hash=sha256:dde52f5314f4c29d22157a6fe4d3bd103f9cae02791c9e672eefa32c9aa1da51 # via flax +tridiax==0.2.1 \ + --hash=sha256:311b0ed41671303197e219019fb9d22d6b31c841ddf5fdd1ec2601e09ed4e750 \ + --hash=sha256:95a8c6d003cdd694487c99e5ba2c43d4fb4dfbe3a3df96e9ac2c80c1c4aaecd1 + # via jaxley typing-extensions==4.15.0 \ --hash=sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466 \ --hash=sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548 # via # aiosignal # chex + # diffrax + # equinox # etils # flax # flexcache # flexparser # grpcio + # lineax # numcodecs + # optimistix # orbax-checkpoint # pint # spglib @@ -2771,6 +2842,13 @@ urllib3==2.6.2 \ --hash=sha256:016f9c98bb7e98085cb2b4b17b87d2c702975664e4f060c6532e64d1c1a5e797 \ --hash=sha256:ec21cddfe7724fc7cb4ba4bea7aa8e2ef36f607a4bab81aa6ce42a13dc3f03dd # via requests +wadler-lindig==0.1.7 \ + --hash=sha256:81d14d3fe77d441acf3ebd7f4aefac20c74128bf460e84b512806dccf7b2cd55 \ + --hash=sha256:e3ec83835570fd0a9509f969162aeb9c65618f998b1f42918cfc8d45122fe953 + # via + # diffrax + # equinox + # jaxtyping werkzeug==3.1.4 \ --hash=sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905 \ --hash=sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e diff --git a/builddeps/test-requirements.txt b/builddeps/test-requirements.txt index 674ae88836..321791f13b 100644 --- a/builddeps/test-requirements.txt +++ b/builddeps/test-requirements.txt @@ -5,6 +5,9 @@ protobuf >= 6 jax-md; sys_platform == 'linux' +jaxley; sys_platform == 'linux' +jaxley_mech; sys_platform == 'linux' + # maxtext can't be installed concurrently, but installing it fixes # https://github.com/wsmoses/maxtext/archive/bc50722be7d89e4003bd830b80e4ac968be658eb.tar.gz; python_version < "3.12" # maxtext; python_version < "3.13" diff --git a/src/enzyme_ad/jax/Implementations/CHLODerivatives.td b/src/enzyme_ad/jax/Implementations/CHLODerivatives.td index a4ae043a13..fdc8bff5d3 100644 --- a/src/enzyme_ad/jax/Implementations/CHLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/CHLODerivatives.td @@ -80,6 +80,7 @@ def IsInf : HLOInst<"IsInfOp">; def IsNegInf : HLOInst<"IsNegInfOp">; def IsPosInf : HLOInst<"IsPosInfOp">; def Lgamma : HLOInst<"LgammaOp">; +def Square : HLOInst<"SquareOp">; /// CHLO - broadcasting compare operation def BroadcastCompare : HLOInst<"BroadcastCompareOp">; @@ -142,3 +143,7 @@ def : HLODerivative<"SinhOp", (Op $x), [(Mul (DiffeRet), (Cosh $x))]>; def : HLODerivative<"TanOp", (Op $x), [ (Div (DiffeRet), (Mul (Cos $x), (Cos $x))) ]>; + +def : HLODerivative<"SquareOp", (Op $x), [ + (Mul (DiffeRet), (Mul (HLOConstantFP<"2"> $x), $x)) +]>; diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index e7c372bdf0..93f076e6c3 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -59,7 +59,6 @@ #include #include #include -#include #include #include #include diff --git a/test/BUILD b/test/BUILD index 301657e62e..9827e2b7df 100644 --- a/test/BUILD +++ b/test/BUILD @@ -122,6 +122,25 @@ py_test( deps = TEST_DEPS, ) +py_test( + name = "jaxley_test", + timeout = "eternal", + srcs = [ + "jaxley_test.py", + "test_utils.py", + "xprof_utils.py", + ], + imports = ["."], + tags = ["exclusive"], + deps = TEST_DEPS + select({ + "@bazel_tools//src/conditions:linux_x86_64": [ + "@pypi_jaxley//:pkg", + "@pypi_jaxley_mech//:pkg", + ], + "//conditions:default": [], + }), +) + py_test( name = "jaxmd", timeout = "eternal", @@ -192,6 +211,7 @@ test_suite( name = "python_tests", tests = [ ":bench_vs_xla", + ":jaxley_test", ":jaxmd", ":llama", ":neuralgcm_test", diff --git a/test/jaxley_test.py b/test/jaxley_test.py new file mode 100644 index 0000000000..f9ff8c665b --- /dev/null +++ b/test/jaxley_test.py @@ -0,0 +1,280 @@ +from absl.testing import absltest + +from test_utils import EnzymeJaxTest, pipelines + + +# Based on https://jaxley.readthedocs.io/en/latest/examples/00_l5pc_gradient_descent.html#defining-the-model +class Jaxley(EnzymeJaxTest): + def setUp(self): + import jax.numpy as jnp + import numpy as np + + import jaxley as jx + from jaxley.channels import Leak + from jaxley_mech.channels import l5pc + from jaxley.morphology import distance_direct + + import tempfile + import requests + + response = requests.get( + "https://raw.githubusercontent.com/jaxleyverse/jaxley/refs/heads/main/tests/swc_files/morph_l5pc_with_axon.swc" + ) + response.raise_for_status() + + tmpfile = tempfile.NamedTemporaryFile(delete=False) + with open(tmpfile.name, "wb") as f: + f.write(response.content) + + cell = jx.read_swc(tmpfile.name, ncomp=1) + + cell.set("axial_resistivity", 100.0) + cell.apical.set("capacitance", 2.0) + + # Run the d_lambda rule. + frequency = 100.0 + d_lambda = 0.1 # Larger -> more coarse-grained. + + for branch in cell.branches: + diameter = 2 * branch.nodes["radius"].to_numpy()[0] + c_m = branch.nodes["capacitance"].to_numpy()[0] + r_a = branch.nodes["axial_resistivity"].to_numpy()[0] + length = branch.nodes["length"].to_numpy()[0] + + lambda_f = 1e5 * np.sqrt(diameter / (4 * np.pi * frequency * c_m * r_a)) + ncomp = int((length / (d_lambda * lambda_f) + 0.9) / 2) * 2 + 1 + branch.set_ncomp(ncomp, initialize=False) + cell.initialize() + + ########## APICAL ########## + cell.apical.insert(l5pc.NaTs2T()) + cell.apical.insert(l5pc.SKv3_1()) + cell.apical.insert(l5pc.M()) + cell.apical.insert(l5pc.H()) + + ########## SOMA ########## + cell.soma.insert(l5pc.NaTs2T()) + cell.soma.insert(l5pc.SKv3_1()) + cell.soma.insert(l5pc.SKE2()) + ca_dynamics = l5pc.CaNernstReversal() + ca_dynamics.channel_constants["T"] = 307.15 + cell.soma.insert(ca_dynamics) + cell.soma.insert(l5pc.CaPump()) + cell.soma.insert(l5pc.CaHVA()) + cell.soma.insert(l5pc.CaLVA()) + + ########## BASAL ########## + cell.basal.insert(l5pc.H()) + + # ########## AXON ########## + cell.insert(l5pc.CaNernstReversal()) + cell.axon.insert(l5pc.NaTaT()) + cell.axon.insert(l5pc.NapEt2()) + cell.axon.insert(l5pc.KTst()) + cell.axon.insert(l5pc.CaPump()) + cell.axon.insert(l5pc.SKE2()) + cell.axon.insert(l5pc.CaHVA()) + cell.axon.insert(l5pc.KPst()) + cell.axon.insert(l5pc.SKv3_1()) + cell.axon.insert(l5pc.CaLVA()) + + ########## WHOLE CELL ########## + cell.insert(Leak()) + + dt = 0.025 # ms + t_max = 100.0 # ms + time_vec = np.arange(0, t_max + 2 * dt, dt) + + cell.delete_stimuli() + cell.delete_recordings() + + i_delay = 5.0 # ms + i_dur = 90.0 # ms + i_amp = 1.8 # nA + current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max) + cell.soma.branch(0).loc(0.5).stimulate(current) + cell.soma.branch(0).loc(0.5).record() + + cell.set("v", -72.0) + cell.init_states() + + cell.set("CaCon_i", 5e-05) + cell.set("CaCon_e", 2.0) + + cell.apical.set("NaTs2T_gNaTs2T", 0.026145) + cell.apical.set("SKv3_1_gSKv3_1", 0.004226) + cell.apical.set("M_gM", 0.000143) + cell.soma.set("NaTs2T_gNaTs2T", 0.983955) + cell.soma.set("SKv3_1_gSKv3_1", 0.303472) + cell.soma.set("SKE2_gSKE2", 0.008407) + cell.soma.set("CaPump_gamma", 0.000609) + cell.soma.set("CaPump_decay", 210.485291) + cell.soma.set("CaHVA_gCaHVA", 0.000994) + cell.soma.set("CaLVA_gCaLVA", 0.000333) + + cell.axon.set("NaTaT_gNaTaT", 3.137968) + cell.axon.set("NapEt2_gNapEt2", 0.006827) + cell.axon.set("KTst_gKTst", 0.089259) + cell.axon.set("CaPump_gamma", 0.00291) + cell.axon.set("CaPump_decay", 287.19873) + cell.axon.set("SKE2_gSKE2", 0.007104) + cell.axon.set("CaHVA_gCaHVA", 0.00099) + cell.axon.set("KPst_gKPst", 0.973538) + cell.axon.set("SKv3_1_gSKv3_1", 1.021945) + cell.axon.set("CaLVA_gCaLVA", 0.008752) + + # The H-conductance depends on the distance from the soma. + cell.compute_compartment_centers() + direct_dists = distance_direct(cell.soma.branch(0).comp(0), cell) + cell.nodes["dist_from_soma"] = direct_dists + gH_conductance = ( + -0.8696 + 2.087 * np.exp(cell.basal.nodes["dist_from_soma"] * 0.0031) + ) * 8e-5 + cell.basal.set("H_gH", gH_conductance) + + cell.set("Leak_gLeak", 3e-05) + cell.set("Leak_eLeak", -75.0) + + cell.set("eNa", 50.0) + cell.set("eK", -85.0) + + x_o = jx.integrate(cell)[0] + + bounds = {} + bounds["apical_NaTs2T_gNaTs2T"] = [0, 0.04] + bounds["apical_SKv3_1_gSKv3_1"] = [0, 0.04] + bounds["apical_M_gM"] = [0, 0.001] + bounds["somatic_NaTs2T_gNaTs2T"] = [0.0, 1.0] + bounds["somatic_SKv3_1_gSKv3_1"] = [0.25, 1] + bounds["somatic_SKE2_gSKE2"] = [0, 0.1] + bounds["somatic_CaPump_gamma"] = [0.0005, 0.01] + bounds["somatic_CaPump_decay"] = [20, 1_000] + bounds["somatic_CaHVA_gCaHVA"] = [0, 0.001] + bounds["somatic_CaLVA_gCaLVA"] = [0, 0.01] + bounds["axonal_NaTaT_gNaTaT"] = [0.0, 4.0] + bounds["axonal_NapEt2_gNapEt2"] = [0.0, 0.01] + bounds["axonal_KPst_gKPst"] = [0.0, 1.0] + bounds["axonal_KTst_gKTst"] = [0.0, 0.1] + bounds["axonal_SKE2_gSKE2"] = [0.0, 0.1] + bounds["axonal_SKv3_1_gSKv3_1"] = [0.0, 2.0] + bounds["axonal_CaHVA_gCaHVA"] = [0, 0.001] + bounds["axonal_CaLVA_gCaLVA"] = [0, 0.01] + bounds["axonal_CaPump_gamma"] = [0.0005, 0.05] + bounds["axonal_CaPump_decay"] = [20, 1_000] + + # Extract the lower and upper bounds as an array, for convenience later. + lower_bounds = jnp.asarray(list(bounds.values()))[:, 0] + upper_bounds = jnp.asarray(list(bounds.values()))[:, 1] + + from jaxley.optimize.transforms import SigmoidTransform + + transform = SigmoidTransform( + lower=lower_bounds, + upper=upper_bounds, + ) + + # For checkpointing. + checkpoint_levels = 2 + checkpoints = [ + int(np.ceil(len(time_vec) ** (1 / checkpoint_levels))) + for _ in range(checkpoint_levels) + ] + + def simulate(params): + # Set apical parameters. + pstate = None + pstate = cell.apical.data_set("NaTs2T_gNaTs2T", params[0], pstate) + pstate = cell.apical.data_set("SKv3_1_gSKv3_1", params[1], pstate) + pstate = cell.apical.data_set("M_gM", params[2], pstate) + + # Set somatic parameters. + pstate = cell.soma.data_set("NaTs2T_gNaTs2T", params[3], pstate) + pstate = cell.soma.data_set("SKv3_1_gSKv3_1", params[4], pstate) + pstate = cell.soma.data_set("SKE2_gSKE2", params[5], pstate) + pstate = cell.soma.data_set("CaPump_gamma", params[6], pstate) + pstate = cell.soma.data_set("CaPump_decay", params[7], pstate) + pstate = cell.soma.data_set("CaHVA_gCaHVA", params[8], pstate) + pstate = cell.soma.data_set("CaLVA_gCaLVA", params[9], pstate) + + # Set axonal parameters. + pstate = cell.axon.data_set("NaTaT_gNaTaT", params[10], pstate) + pstate = cell.axon.data_set("NapEt2_gNapEt2", params[11], pstate) + pstate = cell.axon.data_set("KPst_gKPst", params[12], pstate) + pstate = cell.axon.data_set("KTst_gKTst", params[13], pstate) + pstate = cell.axon.data_set("SKE2_gSKE2", params[14], pstate) + pstate = cell.axon.data_set("SKv3_1_gSKv3_1", params[15], pstate) + pstate = cell.axon.data_set("CaHVA_gCaHVA", params[16], pstate) + pstate = cell.axon.data_set("CaLVA_gCaLVA", params[17], pstate) + pstate = cell.axon.data_set("CaPump_gamma", params[18], pstate) + pstate = cell.axon.data_set("CaPump_decay", params[19], pstate) + + # Return [0] because the result of `jx.integrate` is of shape (1, time). Here, we + # get rid of the batch dimension. + return jx.integrate( + cell, param_state=pstate, checkpoint_lengths=checkpoints + )[0] + + def sample_randomly(): + return jnp.asarray( + np.random.rand(len(upper_bounds)) * (upper_bounds - lower_bounds) + + lower_bounds + ) + + window1 = jnp.arange(200, 2000) # Unit: time steps. + window2 = jnp.arange(2000, 3800) + + def summary_stats(v): + mean1 = jnp.mean(v[window1]) + std1 = jnp.std(v[window1]) + mean2 = jnp.mean(v[window2]) + std2 = jnp.std(v[window2]) + return jnp.asarray([mean1, std1, mean2, std2]) + + x_standard_deviation = jnp.asarray( + [2.0, 1.0, 2.0, 1.0] + ) # Large values downweigh the respective summary statistic. + + # Compute the summary statistics of the observation. + x_o_ss = summary_stats(x_o) + + def loss_fn(opt_params): + params = transform.forward(opt_params) + v = simulate(params) + ss = summary_stats(v) + return jnp.mean(jnp.abs((ss - x_o_ss) / x_standard_deviation)) + + _ = np.random.seed(0) + initial_params = sample_randomly() + opt_params = transform.inverse(initial_params) + + self.fn = loss_fn + self.name = "jaxley_l5pc" + + self.AllPipelines = pipelines(noscattergather=True) + + self.ins = [opt_params] + self.dins = [opt_params.copy()] + self.douts = loss_fn(opt_params).copy() + + self.atol = 5e-3 + self.rtol = 1e-3 + + # TODO: investigate. running inside xprof segfaults. + self.repeat = 2 + self.use_xprof = False + + # currently missing some scatter derivative rule + self.revfilter = lambda _: [] + + +if __name__ == "__main__": + import platform + + if platform.system() != "Darwin" and platform.machine() == "x86_64": + from test_utils import fix_paths + + fix_paths() + import jax + + jax.config.update("jax_enable_x64", True) + absltest.main() diff --git a/test/jaxmd.py b/test/jaxmd.py index 9ab003bcd9..8e67fc3864 100644 --- a/test/jaxmd.py +++ b/test/jaxmd.py @@ -109,7 +109,6 @@ def forward( import platform # Deps not available on macos - # PostRev introduces numerical error -- need to investigate if platform.system() != "Darwin" and platform.machine() == "x86_64": from test_utils import fix_paths diff --git a/test/lit_tests/diffrules/chlo/square.mlir b/test/lit_tests/diffrules/chlo/square.mlir new file mode 100644 index 0000000000..06a27ba552 --- /dev/null +++ b/test/lit_tests/diffrules/chlo/square.mlir @@ -0,0 +1,25 @@ +// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= retTys=enzyme_dup argTys=enzyme_dup mode=ForwardMode" | FileCheck %s --check-prefix=FORWARD +// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= retTys=enzyme_active argTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops | FileCheck %s --check-prefix=REVERSE + +func.func @main(%x : tensor<2xf32>) -> tensor<2xf32> { + %y = chlo.square %x : tensor<2xf32> -> tensor<2xf32> + func.return %y : tensor<2xf32> +} + +// FORWARD: func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { +// FORWARD-NEXT: %0 = chlo.constant dense<2.000000e+00> : tensor<2xf32> +// FORWARD-NEXT: %1 = stablehlo.multiply %0, %arg0 : tensor<2xf32> +// FORWARD-NEXT: %2 = stablehlo.multiply %arg1, %1 : tensor<2xf32> +// FORWARD-NEXT: %3 = chlo.square %arg0 : tensor<2xf32> -> tensor<2xf32> +// FORWARD-NEXT: return %3, %2 : tensor<2xf32>, tensor<2xf32> +// FORWARD-NEXT: } + +// REVERSE: func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { +// REVERSE-NEXT: %0 = chlo.constant dense<2.000000e+00> : tensor<2xf32> +// REVERSE-NEXT: %cst = arith.constant dense<0.000000e+00> : tensor<2xf32> +// REVERSE-NEXT: %1 = arith.addf %arg1, %cst : tensor<2xf32> +// REVERSE-NEXT: %2 = stablehlo.multiply %0, %arg0 : tensor<2xf32> +// REVERSE-NEXT: %3 = stablehlo.multiply %1, %2 : tensor<2xf32> +// REVERSE-NEXT: %4 = arith.addf %3, %cst : tensor<2xf32> +// REVERSE-NEXT: return %4 : tensor<2xf32> +// REVERSE-NEXT: } diff --git a/test/test_utils.py b/test/test_utils.py index acfdf86142..e77c217f01 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -413,20 +413,50 @@ def get_pipeline(name: str): ) elif name == "IDefOpt": return ("IDefOpt", JaXPipeline(full_optimization_pass_pipeline()), CurBackends) + elif name == "NoScatterGatherOpts": + return ( + "NoScatterGatherOpts", + JaXPipeline( + full_optimization_pass_pipeline( + enable_scatter_gather_optimization_passes=False + ) + ), + CurBackends, + ) -def pipelines(): +def pipelines( + *, + jaxpipe=True, + jax=True, + hloopt=True, + partopt=True, + ipartopt=True, + defopt=True, + idefopt=True, + noscattergather=False, +): setup_backends() - return [ - get_pipeline("JaxPipe"), - get_pipeline("Jax"), - get_pipeline("HLOOpt"), - get_pipeline("PartOpt"), - get_pipeline("IPartOpt"), - get_pipeline("DefOpt"), - get_pipeline("IDefOpt"), - ] + pipelines = [] + if jaxpipe: + pipelines.append(get_pipeline("JaxPipe")) + if jax: + pipelines.append(get_pipeline("Jax")) + if hloopt: + pipelines.append(get_pipeline("HLOOpt")) + if partopt: + pipelines.append(get_pipeline("PartOpt")) + if ipartopt: + pipelines.append(get_pipeline("IPartOpt")) + if defopt: + pipelines.append(get_pipeline("DefOpt")) + if idefopt: + pipelines.append(get_pipeline("IDefOpt")) + if noscattergather: + pipelines.append(get_pipeline("NoScatterGatherOpts")) + + return pipelines def no_newxla(x): @@ -552,6 +582,7 @@ def __init__(self, *args, **kwargs): self.mlirad_rev = True self.results = [] self.skip_test_assert = False + self.use_xprof = True def pretty_print_table(self, name, pname, backend, key, time): print_str = "{:<20}\t{:<20}\t{:<15}\t{:<10}\t{:<15.8f}".format( @@ -648,7 +679,10 @@ def harness(self, name, in_fn, ins, dins, douts): recursive_check(self, ao, primres, "Primal " + pname) runtime = profile_compiled_function( - rfn_enzyme, ins_backend, nrepeat=self.repeat + rfn_enzyme, + ins_backend, + nrepeat=self.repeat, + use_xprof=self.use_xprof, )["avg_time_s"] self.pretty_print_table(name, pname, backend, "Primal", runtime) @@ -695,7 +729,10 @@ def harness(self, name, in_fn, ins, dins, douts): recursive_check(self, tangents, fwdres, "Forward " + pname) runtime = profile_compiled_function( - fwd_enzyme, all_ins, nrepeat=self.repeat + fwd_enzyme, + all_ins, + nrepeat=self.repeat, + use_xprof=self.use_xprof, )["avg_time_s"] self.pretty_print_table( name, pname, backend, "Forward", runtime @@ -750,7 +787,10 @@ def harness(self, name, in_fn, ins, dins, douts): recursive_check(self, grads, revres, "Reverse " + pname) runtime = profile_compiled_function( - rev_enzyme, all_ins, nrepeat=self.repeat + rev_enzyme, + all_ins, + nrepeat=self.repeat, + use_xprof=self.use_xprof, )["avg_time_s"] self.pretty_print_table( name, pname, backend, "PreRev", runtime @@ -795,7 +835,10 @@ def harness(self, name, in_fn, ins, dins, douts): recursive_check(self, grads, revres) runtime = profile_compiled_function( - rev_enzyme, all_ins, nrepeat=self.repeat + rev_enzyme, + all_ins, + nrepeat=self.repeat, + use_xprof=self.use_xprof, )["avg_time_s"] self.pretty_print_table( name, pname, backend, "PostRev", runtime @@ -847,7 +890,10 @@ def harness(self, name, in_fn, ins, dins, douts): recursive_check(self, grads, revres) runtime = profile_compiled_function( - rev_enzyme, all_ins, nrepeat=self.repeat + rev_enzyme, + all_ins, + nrepeat=self.repeat, + use_xprof=self.use_xprof, )["avg_time_s"] self.pretty_print_table( name, pname, backend, "BothRev", runtime diff --git a/test/xprof_utils.py b/test/xprof_utils.py index a93a511969..b5206c60c3 100644 --- a/test/xprof_utils.py +++ b/test/xprof_utils.py @@ -23,6 +23,7 @@ def profile_function( nrepeat: int = 1, warmup: int = 1, trace_dir: str | None = None, + use_xprof: bool = True, ) -> dict[str, Any]: """ Profile a JAX function and return timing data. @@ -34,6 +35,7 @@ def profile_function( nrepeat: Number of times to run the function during profiling warmup: Number of warmup runs before profiling (to ensure compilation is done) trace_dir: Directory to save traces. If None, uses a temporary directory. + use_xprof: Whether to use xprof for profiling Returns: A dictionary containing: @@ -56,7 +58,7 @@ def profile_function( for _ in range(warmup): jax.block_until_ready(compiled_fn(*args, **kwargs)) - profile_compiled_function(compiled_fn, args, kwargs, nrepeat, trace_dir) + profile_compiled_function(compiled_fn, args, kwargs, nrepeat, trace_dir, use_xprof) def profile_compiled_function( @@ -65,6 +67,7 @@ def profile_compiled_function( kwargs: dict | None = None, nrepeat: int = 1, trace_dir: str | None = None, + use_xprof: bool = True, ) -> dict[str, Any]: """ Profile a JAX function and return timing data. @@ -75,6 +78,7 @@ def profile_compiled_function( kwargs: Keyword arguments to pass to the function nrepeat: Number of times to run the function during profiling trace_dir: Directory to save traces. If None, uses a temporary directory. + use_xprof: Whether to use xprof for profiling Returns: A dictionary containing: @@ -91,7 +95,7 @@ def profile_compiled_function( if kwargs is None: kwargs = {} - if not XPROF_AVAILABLE: + if not XPROF_AVAILABLE or not use_xprof: warnings.warn("xprof not found, falling back to timeit for profiling.") # Fallback to timeit times = []