This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

test: add some simple tests
avik-pal committed Jul 28, 2024
1 parent 08fe500 commit e8e62aa
Showing 4 changed files with 50 additions and 9 deletions.
7 changes: 3 additions & 4 deletions Project.toml
@@ -30,17 +30,16 @@ ForwardDiff = "0.10.36"
Functors = "0.4.11"
JET = "0.9.6"
MLDataDevices = "1.0.0"
ReTestItems = "1.24.0"
ReverseDiff = "1.15.3"
Test = "1.10"
Tracker = "0.2.34"
Zygote = "0.6.70"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Documenter", "ExplicitImports", "ReTestItems"]
test = ["ReTestItems", "Test"]
37 changes: 33 additions & 4 deletions src/autodiff.jl
@@ -25,7 +25,7 @@ function gradient(f::F, ad::AutoEnzyme{<:Enzyme.ReverseMode}, args...) where {F}
Enzyme.autodiff(ad.mode, f, Enzyme.Active, args_activity...)
return Tuple(map(enumerate(args)) do (i, x)
needs_gradient(x) && return args_activity[i].dval
- return CRC.ZeroTangent()
+ return CRC.NoTangent()
end)
end

@@ -78,6 +78,35 @@ function gradient(f::F, grad_fn::GFN, args...) where {F, GFN <: Function}
end

# Main Functionality to Test Gradient Correctness
"""
test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs...)
Test the gradients of `f` with respect to `args` using the specified backends.
## Arguments
- `f`: The function to test the gradients of.
- `args`: The arguments to test the gradients of. Only `AbstractArray`s are considered
for gradient computation. Gradients wrt all other arguments are assumed to be
`NoTangent()`.
## Keyword Arguments
- `skip_backends`: A list of backends to skip.
- `broken_backends`: A list of backends to treat as broken.
- `kwargs`: Additional keyword arguments to pass to `check_approx`.
## Example
```julia
julia> f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z)
julia> x = (; t=rand(10), x=(z=[2.0],))
julia> test_gradients(f, 1.0, x, nothing)
```
"""
function test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs...)
on_gpu = get_device_type(args) isa AbstractGPUDevice
total_length = mapreduce(__length, +, Functors.fleaves(args); init=0)
@@ -102,14 +131,14 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs
# Test the gradients
∂args_gt = gradient(f, backends[1], args...) # Should be Zygote in most cases

@assert backends[1] ∉ broken_backends "first backend cannot be broken"

@testset "gradtest($(f))" begin
@testset "$(backends[1]) vs $(backend)" for backend in backends[2:end]
@testset "$(nameof(typeof(backends[1])))() vs $(nameof(typeof(backend)))()" for backend in backends[2:end]
broken = backend in broken_backends
@test begin
∂args = allow_unstable() do
- gradient(f, backend, args...)
+ return gradient(f, backend, args...)
end
check_approx(∂args, ∂args_gt; kwargs...)
end broken=broken
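For orientation, the sketch below exercises the `skip_backends` and `broken_backends` keywords documented in the docstring above. The backend instances (`AutoEnzyme()`, `AutoTracker()`) come from ADTypes.jl and are chosen only for illustration; this commit does not prescribe which backends to pass.

```julia
using LuxTestUtils
using ADTypes: AutoEnzyme, AutoTracker  # assumed source of the backend types

f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z)
x = (; t=rand(10), x=(z=[2.0],))

# Skip Enzyme entirely; report Tracker results as `broken` instead of failing outright.
test_gradients(f, 1.0, x, nothing;
    skip_backends=[AutoEnzyme()], broken_backends=[AutoTracker()])
```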
2 changes: 1 addition & 1 deletion src/jet.jl
@@ -7,7 +7,7 @@ const JET_TARGET_MODULES = Ref{Union{Nothing, Vector{String}}}(nothing)
This sets `target_modules` for all JET tests when using [`@jet`](@ref).
"""
function jet_target_modules!(list::Vector{String}; force::Bool=false)
- if JET_TARGET_MODULES[] !== nothing && !force
+ if JET_TARGET_MODULES[] === nothing || (force && JET_TARGET_MODULES[] !== nothing)
JET_TARGET_MODULES[] = list
@info "JET_TARGET_MODULES set to $list"
return list
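The reworked guard means the global list is only written when it is still unset, unless `force=true` is passed. A small illustration of that behavior follows; the package names are hypothetical.

```julia
using LuxTestUtils

LuxTestUtils.jet_target_modules!(["LuxTestUtils"])              # unset, so the list is stored
LuxTestUtils.jet_target_modules!(["SomeOtherPkg"])              # already set and force=false: guard blocks the overwrite
LuxTestUtils.jet_target_modules!(["SomeOtherPkg"]; force=true)  # force=true overrides the stored list
```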
13 changes: 13 additions & 0 deletions test/unit_tests.jl
@@ -0,0 +1,13 @@
@testitem "@jet" begin
LuxTestUtils.jet_target_modules!(["LuxTestUtils"])

@jet sum([1, 2, 3]) target_modules=(Base, Core)
end

@testitem "test_gradients" begin
f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z)

x = (; t=rand(10), x=(z=[2.0],))

test_gradients(f, 1.0, x, nothing)
end
