diff --git a/Project.toml b/Project.toml
index 73cc681..ebbe4ae 100644
--- a/Project.toml
+++ b/Project.toml
@@ -30,6 +30,7 @@ ForwardDiff = "0.10.36"
 Functors = "0.4.11"
 JET = "0.9.6"
 MLDataDevices = "1.0.0"
+ReTestItems = "1.24.0"
 ReverseDiff = "1.15.3"
 Test = "1.10"
 Tracker = "0.2.34"
@@ -37,10 +38,8 @@ Zygote = "0.6.70"
 julia = "1.10"
 
 [extras]
-Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
-Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
-ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
 ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Aqua", "Documenter", "ExplicitImports", "ReTestItems"]
+test = ["ReTestItems", "Test"]
diff --git a/src/autodiff.jl b/src/autodiff.jl
index 455d8d6..6e2b66d 100644
--- a/src/autodiff.jl
+++ b/src/autodiff.jl
@@ -25,7 +25,7 @@ function gradient(f::F, ad::AutoEnzyme{<:Enzyme.ReverseMode}, args...) where {F}
     Enzyme.autodiff(ad.mode, f, Enzyme.Active, args_activity...)
     return Tuple(map(enumerate(args)) do (i, x)
         needs_gradient(x) && return args_activity[i].dval
-        return CRC.ZeroTangent()
+        return CRC.NoTangent()
     end)
 end
 
@@ -78,6 +78,35 @@ function gradient(f::F, grad_fn::GFN, args...) where {F, GFN <: Function}
 end
 
 # Main Functionality to Test Gradient Correctness
+"""
+    test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs...)
+
+Test the gradients of `f` with respect to `args` using the specified backends.
+
+## Arguments
+
+  - `f`: The function to test the gradients of.
+  - `args`: The arguments to test the gradients of. Only `AbstractArray`s are considered
+    for gradient computation. Gradients wrt all other arguments are assumed to be
+    `NoTangent()`.
+
+## Keyword Arguments
+
+  - `skip_backends`: A list of backends to skip.
+  - `broken_backends`: A list of backends to treat as broken.
+  - `kwargs`: Additional keyword arguments to pass to `check_approx`.
+
+## Example
+
+```julia
+julia> f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z)
+
+julia> x = (; t=rand(10), x=(z=[2.0],))
+
+julia> test_gradients(f, 1.0, x, nothing)
+
+```
+"""
 function test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs...)
     on_gpu = get_device_type(args) isa AbstractGPUDevice
     total_length = mapreduce(__length, +, Functors.fleaves(args); init=0)
@@ -102,14 +131,14 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs
     # Test the gradients
     ∂args_gt = gradient(f, backends[1], args...) # Should be Zygote in most cases
 
-    @assert backends[1] ∉ broken_backends "first backend cannot be broken"
+    @assert backends[1]∉broken_backends "first backend cannot be broken"
 
     @testset "gradtest($(f))" begin
-        @testset "$(backends[1]) vs $(backend)" for backend in backends[2:end]
+        @testset "$(nameof(typeof(backends[1])))() vs $(nameof(typeof(backend)))()" for backend in backends[2:end]
             broken = backend in broken_backends
             @test begin
                 ∂args = allow_unstable() do
-                    gradient(f, backend, args...)
+                    return gradient(f, backend, args...)
                 end
                 check_approx(∂args, ∂args_gt; kwargs...)
             end broken=broken
diff --git a/src/jet.jl b/src/jet.jl
index db6f769..23963bd 100644
--- a/src/jet.jl
+++ b/src/jet.jl
@@ -7,7 +7,7 @@ const JET_TARGET_MODULES = Ref{Union{Nothing, Vector{String}}}(nothing)
 This sets `target_modules` for all JET tests when using [`@jet`](@ref).
 """
 function jet_target_modules!(list::Vector{String}; force::Bool=false)
-    if JET_TARGET_MODULES[] !== nothing && !force
+    if JET_TARGET_MODULES[] === nothing || (force && JET_TARGET_MODULES[] !== nothing)
         JET_TARGET_MODULES[] = list
         @info "JET_TARGET_MODULES set to $list"
         return list
diff --git a/test/unit_tests.jl b/test/unit_tests.jl
index e69de29..f435a4d 100644
--- a/test/unit_tests.jl
+++ b/test/unit_tests.jl
@@ -0,0 +1,13 @@
+@testitem "@jet" begin
+    LuxTestUtils.jet_target_modules!(["LuxTestUtils"])
+
+    @jet sum([1, 2, 3]) target_modules=(Base, Core)
+end
+
+@testitem "test_gradients" begin
+    f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z)
+
+    x = (; t=rand(10), x=(z=[2.0],))
+
+    test_gradients(f, 1.0, x, nothing)
+end