diff --git a/Project.toml b/Project.toml index 87a7186..92c3199 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "1.3.0" +version = "1.3.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" @@ -21,6 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1.8.1" +ArrayInterface = "7.9" ChainRulesCore = "1.24.0" ComponentArrays = "0.15.14" DispatchDoctor = "0.4.12" diff --git a/src/LuxTestUtils.jl b/src/LuxTestUtils.jl index dfda396..795665c 100644 --- a/src/LuxTestUtils.jl +++ b/src/LuxTestUtils.jl @@ -1,5 +1,6 @@ module LuxTestUtils +using ArrayInterface: ArrayInterface using ComponentArrays: ComponentArray, getdata, getaxes using DispatchDoctor: allow_unstable using Functors: Functors diff --git a/src/autodiff.jl b/src/autodiff.jl index 7debc94..f46136f 100644 --- a/src/autodiff.jl +++ b/src/autodiff.jl @@ -172,10 +172,10 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[], @testset "$(nameof(typeof(backends[1])))() vs $(nameof(typeof(backend)))()" for backend in backends[2:end] local_test_expr = :([$(nameof(typeof(backend)))] - $(test_expr)) - result = if backend in skip_backends + result = if check_ad_backend_in(backend, skip_backends) Broken(:skipped, local_test_expr) elseif (soft_fail isa Bool && soft_fail) || - (soft_fail isa Vector && backend in soft_fail) + (soft_fail isa Vector && check_ad_backend_in(backend, soft_fail)) try ∂args = allow_unstable() do return gradient(f, backend, args...) @@ -189,7 +189,7 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[], catch Broken(:test, local_test_expr) end - elseif backend in broken_backends + elseif check_ad_backend_in(backend, broken_backends) try ∂args = allow_unstable() do return gradient(f, backend, args...) diff --git a/src/utils.jl b/src/utils.jl index 22f0749..4327504 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -123,3 +123,9 @@ function reorder_macro_kw_params(exs) end return Tuple(exs) end + +function check_ad_backend_in(backend, backends) + backends_type = map(ArrayInterface.parameterless_type ∘ typeof, backends) + backend_type = ArrayInterface.parameterless_type(typeof(backend)) + return backend_type in backends_type +end