From f186d85b8870027e02f3d6b1f2583320725b662a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 5 Feb 2024 19:26:39 -0500 Subject: [PATCH] Fix CA in FiniteDifferences --- Project.toml | 2 +- src/LuxTestUtils.jl | 23 +++++++++++++++++------ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index b398e32..d92bf94 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.14" +version = "0.1.15" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" diff --git a/src/LuxTestUtils.jl b/src/LuxTestUtils.jl index 9a29c1f..77ed892 100644 --- a/src/LuxTestUtils.jl +++ b/src/LuxTestUtils.jl @@ -269,7 +269,8 @@ function test_gradients_expr(__module__, __source__, f, args...; skip=skip_reverse_diff) reverse_diff_broken = $reverse_diff_broken && !skip_reverse_diff - arr_len = length.(filter(Base.Fix2(isa, AbstractArray) ∘ __correct_arguments, + arr_len = length.(filter(Base.Fix2(isa, AbstractArray) ∘ + Base.Fix1(__correct_arguments, identity), tuple($(esc.(args)...)))) large_arrays = any(x -> x ≥ $large_array_length, arr_len) || sum(arr_len) ≥ $max_total_array_size @@ -333,8 +334,8 @@ end __test_broken(test_type, orig_expr, source) = Test.Broken(test_type, orig_expr) -__correct_arguments(x::AbstractArray) = x -function __correct_arguments(x::NamedTuple) +__correct_arguments(f::F, x::AbstractArray) where {F} = x +function __correct_arguments(f::F, x::NamedTuple) where {F} cpu_dev = cpu_device() gpu_dev = gpu_device() xc = cpu_dev(x) @@ -343,7 +344,7 @@ function __correct_arguments(x::NamedTuple) typeof(xc) == typeof(x) && return ca return gpu_dev(ca) end -__correct_arguments(x) = x +__correct_arguments(f::F, x) where {F} = x __uncorrect_arguments(x::ComponentArray, ::NamedTuple, z::ComponentArray) = NamedTuple(x) function __uncorrect_arguments(x::AbstractArray, nt::NamedTuple, z::ComponentArray) @@ -351,11 +352,11 @@ function __uncorrect_arguments(x::AbstractArray, nt::NamedTuple, z::ComponentArr end __uncorrect_arguments(x, y, z) = x -function __gradient(gradient_function, f, args...; skip::Bool) +function __gradient(gradient_function::F, f, args...; skip::Bool) where {F} if skip return ntuple(_ -> GradientComputationSkipped(), length(args)) else - corrected_args = map(__correct_arguments, args) + corrected_args = map(Base.Fix1(__correct_arguments, gradient_function), args) aa_inputs = [map(Base.Fix2(isa, AbstractArray), corrected_args)...] __aa_input_idx = cumsum(aa_inputs) if sum(aa_inputs) == length(args) @@ -392,6 +393,16 @@ function _finitedifferences_gradient(f, args...) args...)) end +function __correct_arguments(::typeof(_finitedifferences_gradient), x::NamedTuple) + cpu_dev = cpu_device() + gpu_dev = gpu_device() + xc = cpu_dev(x) + ca = ComponentArray(xc) + # Hacky check to see if there are any non-CPU arrays in the NamedTuple + typeof(xc) == typeof(x) && return x + return gpu_dev(x) +end + function __fdiff_compatible_function(f, ::Val{N}) where {N} N == 1 && return f inputs = ntuple(i -> Symbol("x.input_$i"), N)