-
-
Notifications
You must be signed in to change notification settings - Fork 48
/
Copy pathNonlinearSolveNLsolveExt.jl
50 lines (40 loc) · 2.11 KB
/
NonlinearSolveNLsolveExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
module NonlinearSolveNLsolveExt

using NonlinearSolve: NonlinearSolve, NLsolveJL, TraceMinimal
using NLsolve: NLsolve, OnceDifferentiable, nlsolve
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode

# Solve `prob` with `NLsolve.nlsolve`, using the options stored in `alg::NLsolveJL`.
#
# Keyword handling:
# - `abstol` is passed through `NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0))`
#   (presumably to fill in a default when it is `nothing`) and forwarded as NLsolve's
#   `ftol`.
# - `maxiters` is forwarded as NLsolve's `iterations`.
# - `alias_u0` is forwarded to `NonlinearSolve.__construct_extension_f`.
# - `termination_condition` is only handed to
#   `NonlinearSolve.__test_termination_condition`; it is never passed to NLsolve.
# - Tracing is enabled when either the `Val` keyword or the matching `alg` field is
#   true. An extended trace is requested when `trace_level` is not a `TraceMinimal` or
#   when `alg.extended_trace` is set.
# - Extra positional `args` and any other keyword arguments (`kwargs`) are ignored.
#
# Returns the result of `SciMLBase.build_solution`:
# - `retcode` is `ReturnCode.Success` when NLsolve reports `x_converged` or
#   `f_converged`, and `ReturnCode.Failure` otherwise.
# - The raw NLsolve result is attached as `original`.
function SciMLBase.__solve(
        prob::NonlinearProblem, alg::NLsolveJL, args...; abstol = nothing,
        maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing,
        store_trace::Val{StT} = Val(false), show_trace::Val{ShT} = Val(false),
        trace_level = TraceMinimal(), kwargs...) where {StT, ShT}
    NonlinearSolve.__test_termination_condition(termination_condition, :NLsolveJL)

    # In-place residual `f!(du, u)` plus initial guess and residual buffer. These are
    # presumably already vectorized, which would make the `vec` calls below no-ops
    # (TODO confirm).
    f!, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)

    if prob.f.jac === nothing && alg.autodiff isa Symbol
        # No user Jacobian and a Symbol autodiff choice: let NLsolve build the
        # derivative itself.
        df = OnceDifferentiable(f!, u0, resid; alg.autodiff)
    else
        jac! = NonlinearSolve.__construct_extension_jac(
            prob, alg, u0, resid; alg.autodiff)
        if prob.f.jac_prototype === nothing
            # J = ∂resid/∂u: one row per residual entry, one column per unknown.
            J = similar(u0, promote_type(eltype(u0), eltype(resid)),
                length(resid), length(u0))
        else
            # Preserve the structure (e.g. sparsity) of the user-supplied prototype.
            J = zero(prob.f.jac_prototype)
        end
        df = OnceDifferentiable(f!, jac!, vec(u0), vec(resid), J)
    end

    abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0))
    show_trace = ShT || alg.show_trace
    store_trace = StT || alg.store_trace
    extended_trace = !(trace_level isa TraceMinimal) || alg.extended_trace

    original = nlsolve(df, vec(u0); ftol = abstol, iterations = maxiters, alg.method,
        store_trace, extended_trace, alg.linesearch, alg.linsolve,
        alg.factor, alg.autoscale, alg.m, alg.beta, show_trace)

    # Re-evaluate the residual at the returned point so `resid` corresponds to `u`.
    f!(vec(resid), original.zero)

    # Scalar problems get scalars back; array problems get `u` reshaped to
    # `size(prob.u0)`.
    u = prob.u0 isa Number ? original.zero[1] : reshape(original.zero, size(prob.u0))
    resid = prob.u0 isa Number ? resid[1] : resid

    retcode = original.x_converged || original.f_converged ? ReturnCode.Success :
              ReturnCode.Failure

    # The single `g_calls` count from NLsolve is reused for the three count slots
    # between `f_calls` and `iterations`. This is an approximation; separate counts
    # are not read here.
    stats = SciMLBase.NLStats(original.f_calls, original.g_calls, original.g_calls,
        original.g_calls, original.iterations)
    return SciMLBase.build_solution(prob, alg, u, resid; retcode, original, stats)
end

end