From 2b1cd1dd3907f845148ae3f3f49883f63320cef2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 10 Oct 2024 11:05:59 -0400 Subject: [PATCH 1/2] fix: allow u0 to be duals --- Project.toml | 2 +- src/internal/forward_diff.jl | 50 +++++++++++++++++++----------------- 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/Project.toml b/Project.toml index 1651aab2a..419a0ee09 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.15.1" +version = "3.15.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/internal/forward_diff.jl b/src/internal/forward_diff.jl index 190c80645..5da5798ff 100644 --- a/src/internal/forward_diff.jl +++ b/src/internal/forward_diff.jl @@ -2,16 +2,33 @@ import SimpleNonlinearSolve: __nlsolve_ad, __nlsolve_dual_soln, __nlsolve_∂f_∂p, __nlsolve_∂f_∂u -function SciMLBase.solve( - prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, - alg::Union{Nothing, AbstractNonlinearAlgorithm}, - args...; - kwargs...) where {T, V, P, iip} - sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) - dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) - return SciMLBase.build_solution( - prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) +for (uType, pType) in [ + (Union{Number, <:AbstractArray}, Union{<:Dual, <:AbstractArray{<:Dual}}), + (Union{<:Dual, <:AbstractArray{<:Dual}}, Union{<:Dual, <:AbstractArray{<:Dual}}), + (Union{<:Dual, <:AbstractArray{<:Dual}}, Any), +] + @eval begin + function SciMLBase.solve( + prob::NonlinearProblem{<:$(uType), iip, <:$(pType)}, + alg::Union{Nothing, AbstractNonlinearAlgorithm}, + args...; kwargs...) where {iip} + sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) + dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) + end + + function SciMLBase.init( + prob::NonlinearProblem{<:$(uType), iip, <:$(pType)}, + alg::Union{Nothing, AbstractNonlinearAlgorithm}, + args...; kwargs...) where {iip} + p = __value(prob.p) + newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)) + end + end end @concrete mutable struct NonlinearSolveForwardDiffCache @@ -35,19 +52,6 @@ function reinit_cache!(cache::NonlinearSolveForwardDiffCache; return cache end -function SciMLBase.init( - prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, - alg::Union{Nothing, AbstractNonlinearAlgorithm}, - args...; - kwargs...) where {T, V, P, iip} - p = __value(prob.p) - newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...) - cache = init(newprob, alg, args...; kwargs...) - return NonlinearSolveForwardDiffCache( - cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)) -end - function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache) sol = solve!(cache.cache) prob = cache.prob From b8c4c353f0132d647a6a9415e467bb92d2aba386 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 10 Oct 2024 13:06:43 -0400 Subject: [PATCH 2/2] fix: capture more forwarddiff types --- src/internal/forward_diff.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/internal/forward_diff.jl b/src/internal/forward_diff.jl index 5da5798ff..bf0abbc77 100644 --- a/src/internal/forward_diff.jl +++ b/src/internal/forward_diff.jl @@ -3,9 +3,9 @@ import SimpleNonlinearSolve: __nlsolve_ad, __nlsolve_dual_soln, __nlsolve_∂f_ __nlsolve_∂f_∂u for (uType, pType) in [ - (Union{Number, <:AbstractArray}, Union{<:Dual, <:AbstractArray{<:Dual}}), + (Union{<:Number, <:AbstractArray}, Union{<:Dual, <:AbstractArray{<:Dual}}), (Union{<:Dual, <:AbstractArray{<:Dual}}, Union{<:Dual, <:AbstractArray{<:Dual}}), - (Union{<:Dual, <:AbstractArray{<:Dual}}, Any), + (Union{<:Dual, <:AbstractArray{<:Dual}}, Union{<:Number, <:AbstractArray}) ] @eval begin function SciMLBase.solve(