diff --git a/lib/SimpleNonlinearSolve/Project.toml b/lib/SimpleNonlinearSolve/Project.toml index bedcd42c3..e5e58699c 100644 --- a/lib/SimpleNonlinearSolve/Project.toml +++ b/lib/SimpleNonlinearSolve/Project.toml @@ -20,6 +20,7 @@ NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" diff --git a/lib/SimpleNonlinearSolve/src/halley.jl b/lib/SimpleNonlinearSolve/src/halley.jl index 773f4b569..1e00c2234 100644 --- a/lib/SimpleNonlinearSolve/src/halley.jl +++ b/lib/SimpleNonlinearSolve/src/halley.jl @@ -74,7 +74,7 @@ function SciMLBase.__solve( end aᵢ = J_fact \ NLBUtils.safe_vec(fx) - hvvp = Utils.compute_hvvp(prob, autodiff, fx_cache, x, aᵢ) + hvvp = Utils.compute_hvvp(prob, autodiff, fx_cache, NLBUtils.safe_vec(x), aᵢ) bᵢ = J_fact \ NLBUtils.safe_vec(hvvp) cᵢ_ = NLBUtils.safe_vec(cᵢ) diff --git a/lib/SimpleNonlinearSolve/src/utils.jl b/lib/SimpleNonlinearSolve/src/utils.jl index 1090ac5f1..19173a9a5 100644 --- a/lib/SimpleNonlinearSolve/src/utils.jl +++ b/lib/SimpleNonlinearSolve/src/utils.jl @@ -166,7 +166,8 @@ function compute_hvvp(prob, autodiff, fx, x, dir) jvp_fn = if SciMLBase.isinplace(prob) @closure (u, p) -> begin du = NLBUtils.safe_similar(fx, promote_type(eltype(fx), eltype(u))) - return only(DI.pushforward(prob.f, du, autodiff, u, (dir,), Constant(p))) + return only(DI.pushforward( + prob.f, NLBUtils.safe_vec(du), autodiff, u, (dir,), Constant(p))) end else @closure (u, p) -> only(DI.pushforward(prob.f, autodiff, u, (dir,), Constant(p)))