Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 80 additions & 2 deletions builddeps/requirements_lock_3_11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,10 @@ decorator==5.2.1 \
--hash=sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360 \
--hash=sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a
# via gcsfs
diffrax==0.7.0 \
--hash=sha256:aa9645c40552f11a2d32042ef6b9fcd53c1f0f6bdbe32d37cb788669ca9910be \
--hash=sha256:f3bcc578cd92a9ca86fc6f5a54c1de76c1ba62f74de69b56002624bf205316f1
# via jaxley-mech
dinosaur==1.3.5 \
--hash=sha256:aa3830f66a7ceb5cb900689d9b0717100eea74ae4d04f206a9fa20408cee3dc9 \
--hash=sha256:fd75996d62104d5c602a4f2643a1154268e6cd6ed9fd1c295aab679c6fba60b3
Expand All @@ -405,6 +409,14 @@ einops==0.8.1 \
--hash=sha256:919387eb55330f5757c6bea9165c5ff5cfe63a642682ea788a6d472576d81737 \
--hash=sha256:de5d960a7a761225532e0f1959e5315ebeafc0cd43394732f103ca44b9837e84
# via jax-md
equinox==0.13.2 \
--hash=sha256:509ad744ff99b7c684d45230d6890f9e78eac1a556d7a06db1eff664a3cac74f \
--hash=sha256:bc1ee687e4841945d8b776664403839639a05e2f2c02c1da353ff3386e0e43b0
# via
# diffrax
# jaxley-mech
# lineax
# optimistix
etils[epath]==1.13.0 \
--hash=sha256:a5b60c71f95bcd2d43d4e9fb3dc3879120c1f60472bb5ce19f7a860b1d44f607 \
--hash=sha256:d9cd4f40fbe77ad6613b7348a18132cc511237b6c076dbb89105c0b520a4c6bb
Expand Down Expand Up @@ -910,21 +922,34 @@ jaraco-functools==4.4.0 \
--hash=sha256:9eec1e36f45c818d9bf307c8948eb03b2b56cd44087b3cdc989abca1f20b9176 \
--hash=sha256:da21933b0417b89515562656547a77b4931f98176eb173644c0d35032a33d6bb
# via cheroot
jax[cpu]==0.8.2 \
--hash=sha256:1a685ded06a8223a7b52e45e668e406049dbbead02873f2b5a4d881ba7b421ae \
--hash=sha256:d0478c5dc74406441efcd25731166a65ee782f13c352fa72dc7d734351909355
# via
# -r builddeps/requirements.in
# jaxley-mech
jax[cuda12]==0.8.2 \
--hash=sha256:1a685ded06a8223a7b52e45e668e406049dbbead02873f2b5a4d881ba7b421ae \
--hash=sha256:d0478c5dc74406441efcd25731166a65ee782f13c352fa72dc7d734351909355
# via
# -r builddeps/requirements.in
# chex
# diffrax
# dinosaur
# e3nn-jax
# equinox
# flax
# jax-md
# jaxley
# jaxley-mech
# jraph
# lineax
# neuralgcm
# optax
# optimistix
# orbax-checkpoint
# tree-math
# tridiax
jax-cuda12-pjrt==0.8.2 \
--hash=sha256:717a1b196a642409ce195ddf031c20bbeadcc886f55e49a1d3f4927373aeedae \
--hash=sha256:e3bab41ca7c48e4163db9e7efd271b3aa85f0fe45f5ed0708d6bbed93a59f977
Expand All @@ -947,6 +972,16 @@ jax-md==0.2.27 \
--hash=sha256:3506cf7c07b84d6c9cf09243097bef465c81122a23ca8cc78a3627c8b9d97322 \
--hash=sha256:efbefa5089a995a5c02405a4c930ba42f8eaf9322482998b5a422e45f631a0ab
# via -r builddeps/test-requirements.txt
jaxley==0.13.0 \
--hash=sha256:0d9247b340b402f974aad827e0cd79e32c5cd66d7295d95514792a108e15f00b \
--hash=sha256:277f135714f1370b7246754d64687357ec443e3a944f1a96633dfd4eaaafcc3e
# via
# -r builddeps/test-requirements.txt
# jaxley-mech
jaxley-mech==0.3.1 \
--hash=sha256:bd46cb2f02d1f76af56406ef83c464b6f9fc9742625cd88371a1923e14f601e8 \
--hash=sha256:cc5eda21c8521e32795526f9f85ca52941899449b0a491d3ffdb321f3f0c8cbd
# via -r builddeps/test-requirements.txt
jaxlib==0.8.2 \
--hash=sha256:023de6f3f56da2af7037970996500586331fdb50b530ecbb54b9666da633bd00 \
--hash=sha256:05b958f497e49824c432e734bb059723b7dfe69e2ad696a9f9c8ad82fff7c3f8 \
Expand Down Expand Up @@ -980,6 +1015,14 @@ jaxlib==0.8.2 \
# jraph
# neuralgcm
# optax
jaxtyping==0.3.5 \
--hash=sha256:8150ad5b72b62fa63f573d492a79e9e455f070abe3b260f7dc15270b3eb9bba6 \
--hash=sha256:862c39fa2e526274e82dc96ee8dbe9369dadb651ab1e05d95bd685acb4e2ef02
# via
# diffrax
# equinox
# lineax
# optimistix
jmp==0.0.4 \
--hash=sha256:5dfeb0fd7c7a9f72a70fff0aab9d0cbfae32a809c02f4037ff3485ceb33e1730 \
--hash=sha256:6aa7adbddf2bd574b28c7faf6e81a735eb11f53386447896909c6968dc36807d
Expand Down Expand Up @@ -1097,6 +1140,12 @@ kiwisolver==1.4.9 \
--hash=sha256:fb940820c63a9590d31d88b815e7a3aa5915cad3ce735ab45f0c730b39547de1 \
--hash=sha256:fc1795ac5cd0510207482c3d1d3ed781143383b8cfd36f5c645f3897ce066220
# via matplotlib
lineax==0.0.8 \
--hash=sha256:1bd21d6c41afda233769d1c1096329ee75181825c9136be65c92b41f6daa1ddb \
--hash=sha256:bb2778066b8882acc88ff569d8e415bc5aa387f751b14ae262c9f9607d7f25bb
# via
# diffrax
# optimistix
locket==1.0.0 \
--hash=sha256:5c0d4c052a8bbbf750e056a8e65ccd309086f4f0f18a2eac306a8dfa4112a632 \
--hash=sha256:b6c819a722f7b6bd955b80781788e4a66a55628b858d347536b7e81325a3a5e3
Expand Down Expand Up @@ -1252,7 +1301,10 @@ matplotlib==3.10.8 \
--hash=sha256:f97aeb209c3d2511443f8797e3e5a569aebb040d4f8bc79aa3ee78a8fb9e3dd8 \
--hash=sha256:f9b587c9c7274c1613a30afabf65a272114cd6cdbe67b3406f818c79d7ab2e2a \
--hash=sha256:fb061f596dad3a0f52b60dc6a5dec4a0c300dec41e058a7efe09256188d170b7
# via pymatgen
# via
# jaxley
# jaxley-mech
# pymatgen
mdurl==0.1.2 \
--hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \
--hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba
Expand Down Expand Up @@ -1546,7 +1598,9 @@ nest-asyncio==1.6.0 \
networkx==3.6.1 \
--hash=sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509 \
--hash=sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762
# via pymatgen
# via
# jaxley
# pymatgen
neuralgcm==1.2.2 \
--hash=sha256:24edbbb5d21e2d17a7475738f84602885eb011af3a23c33df293b2c5d10ac11c \
--hash=sha256:795297260a5aff05708e855fe8cb27db7cc0f514e9c34e373e4ba378732327e5
Expand Down Expand Up @@ -1640,6 +1694,8 @@ numpy==2.1.3 \
# flax
# jax
# jax-md
# jaxley
# jaxley-mech
# jaxlib
# jmp
# jraph
Expand All @@ -1657,6 +1713,7 @@ numpy==2.1.3 \
# spglib
# tensorstore
# treescope
# tridiax
# xarray
# xarray-tensorstore
# zarr
Expand Down Expand Up @@ -1742,6 +1799,10 @@ optax==0.2.6 \
# flax
# jax-md
# neuralgcm
optimistix==0.0.11 \
--hash=sha256:acb4fb23b598db03e376900fcb61aee8dd511de41411e849661c0ffe9e4cd1c6 \
--hash=sha256:cfce0de98e7e9fdbcc2ce6d49a9f82cd3166fd0eee29c0c7a1983f8aefd37757
# via diffrax
orbax-checkpoint==0.11.31 \
--hash=sha256:b00e39cd61cbd6c7c78b091ccac0ed1bbf3cf7788e761618e7070761195bfcc0 \
--hash=sha256:f021193a619782655798bc4a285f40612f6fe647ddeb303d1f49cdbc5645e319
Expand Down Expand Up @@ -1906,6 +1967,8 @@ pandas==2.3.3 \
--hash=sha256:f8bfc0e12dc78f777f323f55c58649591b2cd0c43534e8355c51d3fede5f4dee
# via
# dinosaur
# jaxley
# jaxley-mech
# neuralgcm
# pymatgen
# xarray
Expand Down Expand Up @@ -2743,18 +2806,26 @@ treescope==0.1.10 \
--hash=sha256:20f74656f34ab2d8716715013e8163a0da79bdc2554c16d5023172c50d27ea95 \
--hash=sha256:dde52f5314f4c29d22157a6fe4d3bd103f9cae02791c9e672eefa32c9aa1da51
# via flax
tridiax==0.2.1 \
--hash=sha256:311b0ed41671303197e219019fb9d22d6b31c841ddf5fdd1ec2601e09ed4e750 \
--hash=sha256:95a8c6d003cdd694487c99e5ba2c43d4fb4dfbe3a3df96e9ac2c80c1c4aaecd1
# via jaxley
typing-extensions==4.15.0 \
--hash=sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466 \
--hash=sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548
# via
# aiosignal
# chex
# diffrax
# equinox
# etils
# flax
# flexcache
# flexparser
# grpcio
# lineax
# numcodecs
# optimistix
# orbax-checkpoint
# pint
# spglib
Expand All @@ -2771,6 +2842,13 @@ urllib3==2.6.2 \
--hash=sha256:016f9c98bb7e98085cb2b4b17b87d2c702975664e4f060c6532e64d1c1a5e797 \
--hash=sha256:ec21cddfe7724fc7cb4ba4bea7aa8e2ef36f607a4bab81aa6ce42a13dc3f03dd
# via requests
wadler-lindig==0.1.7 \
--hash=sha256:81d14d3fe77d441acf3ebd7f4aefac20c74128bf460e84b512806dccf7b2cd55 \
--hash=sha256:e3ec83835570fd0a9509f969162aeb9c65618f998b1f42918cfc8d45122fe953
# via
# diffrax
# equinox
# jaxtyping
werkzeug==3.1.4 \
--hash=sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905 \
--hash=sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e
Expand Down
3 changes: 3 additions & 0 deletions builddeps/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ protobuf >= 6

jax-md; sys_platform == 'linux'

jaxley; sys_platform == 'linux'
jaxley_mech; sys_platform == 'linux'

# maxtext can't be installed concurrently, but installing it fixes
# https://github.com/wsmoses/maxtext/archive/bc50722be7d89e4003bd830b80e4ac968be658eb.tar.gz; python_version < "3.12"
# maxtext; python_version < "3.13"
Expand Down
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/Implementations/CHLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def IsInf : HLOInst<"IsInfOp">;
def IsNegInf : HLOInst<"IsNegInfOp">;
def IsPosInf : HLOInst<"IsPosInfOp">;
def Lgamma : HLOInst<"LgammaOp">;
def Square : HLOInst<"SquareOp">;

/// CHLO - broadcasting compare operation
def BroadcastCompare : HLOInst<"BroadcastCompareOp">;
Expand Down Expand Up @@ -142,3 +143,7 @@ def : HLODerivative<"SinhOp", (Op $x), [(Mul (DiffeRet), (Cosh $x))]>;
def : HLODerivative<"TanOp", (Op $x), [
(Div (DiffeRet), (Mul (Cos $x), (Cos $x)))
]>;

def : HLODerivative<"SquareOp", (Op $x), [
(Mul (DiffeRet), (Mul (HLOConstantFP<"2"> $x), $x))
]>;
20 changes: 20 additions & 0 deletions test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,25 @@ py_test(
deps = TEST_DEPS,
)

py_test(
name = "jaxley_test",
timeout = "eternal",
srcs = [
"jaxley_test.py",
"test_utils.py",
"xprof_utils.py",
],
imports = ["."],
tags = ["exclusive"],
deps = TEST_DEPS + select({
"@bazel_tools//src/conditions:linux_x86_64": [
"@pypi_jaxley//:pkg",
"@pypi_jaxley_mech//:pkg",
],
"//conditions:default": [],
}),
)

py_test(
name = "jaxmd",
timeout = "eternal",
Expand Down Expand Up @@ -192,6 +211,7 @@ test_suite(
name = "python_tests",
tests = [
":bench_vs_xla",
":jaxley_test",
":jaxmd",
":llama",
":neuralgcm_test",
Expand Down
Loading
Loading