Skip to content

Commit f6b6c00

Browse files
ChrisRackauckasoscardssmith
authored andcommitted
Setup NonlinearSolveAlg with jacobian reuse
1 parent 4abda1b commit f6b6c00

File tree

8 files changed

+33
-13
lines changed

8 files changed

+33
-13
lines changed

lib/OrdinaryDiffEqCore/src/misc_utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ function get_differential_vars(f, u)
133133
end
134134

135135
isnewton(::Any) = false
136+
isnonlinearsolve(::Any) = false
136137

137138
function _bool_to_ADType(::Val{true}, ::Val{CS}, _) where {CS}
138139
Base.depwarn(

lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplici
3838
OrdinaryDiffEqAdaptiveExponentialAlgorithm, @unpack,
3939
AbstractNLSolver, nlsolve_f, issplit,
4040
concrete_jac, unwrap_alg, OrdinaryDiffEqCache, _vec, standardtag,
41-
isnewton, _unwrap_val,
41+
isnewton, isnonlinearsolve, _unwrap_val,
4242
set_new_W!, set_W_γdt!, alg_difftype, unwrap_cache, diffdir,
4343
get_W, isfirstcall, isfirststage, isJcurrent,
4444
get_new_W_γdt_cutoff,

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ function do_newJW(integrator, alg, nlsolver, repeat_step)::NTuple{2, Bool}
468468
return true, true
469469
end
470470
# TODO: add `isJcurrent` support for Rosenbrock solvers
471-
if !isnewton(nlsolver)
471+
if !isnewton(nlsolver) && !isnonlinearsolve(nlsolver)
472472
isfreshJ = !(integrator.alg isa CompositeAlgorithm) &&
473473
(integrator.iter > 1 && errorfail && !integrator.u_modified)
474474
return !isfreshJ, true

lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,12 @@ using OrdinaryDiffEqCore: resize_nlsolver!, _initialize_dae!,
5050

5151
import OrdinaryDiffEqCore: _initialize_dae!, isnewton, get_W, isfirstcall, isfirststage,
5252
isJcurrent, get_new_W_γdt_cutoff, resize_nlsolver!, apply_step!,
53-
postamble!
53+
postamble!, isnonlinearsolve
5454

5555
import OrdinaryDiffEqDifferentiation: update_W!, is_always_new, build_uf, build_J_W,
5656
WOperator, StaticWOperator, wrapprecs,
5757
build_jac_config, dolinsolve, alg_autodiff,
58-
resize_jac_config!
58+
resize_jac_config!, do_newJW
5959

6060
import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA,
6161
StaticMatrix

lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,15 @@ end
9696
@unpack z, tmp, ztmp, γ, α, cache, method = nlsolver
9797
@unpack tstep, invγdt = cache
9898

99+
new_jac, new_W = do_newJW(integrator, integrator.alg, nlsolver, false)
100+
if is_always_new(nlsolver) || new_jac || new_W
101+
recompute_jacobian = true
102+
else
103+
recompute_jacobian = false
104+
end
105+
99106
nlcache = nlsolver.cache.cache
100-
step!(nlcache)
107+
step!(nlcache; recompute_jacobian)
101108
nlsolver.ztmp = nlcache.u
102109

103110
ustep = compute_ustep(tmp, γ, z, method)
@@ -118,9 +125,16 @@ end
118125
@unpack z, tmp, ztmp, γ, α, cache, method = nlsolver
119126
@unpack tstep, invγdt, atmp, ustep = cache
120127

121-
nlstep_data = integrator.f.nlstep_data
128+
new_jac, new_W = do_newJW(integrator, integrator.alg, nlsolver, false)
129+
if is_always_new(nlsolver) || new_jac || new_W
130+
recompute_jacobian = true
131+
else
132+
recompute_jacobian = false
133+
end
134+
122135
nlcache = nlsolver.cache.cache
123-
step!(nlcache)
136+
nlstep_data = integrator.f.nlstep_data
137+
step!(nlcache; recompute_jacobian)
124138

125139
if nlstep_data !== nothing
126140
nlstepsol = SciMLBase.build_solution(

lib/OrdinaryDiffEqNonlinearSolve/src/nlsolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ function nlsolve!(nlsolver::NL, integrator::SciMLBase.DEIntegrator,
105105
# don't trust θ for non-adaptive on first iter because the solver doesn't provide feedback
106106
# for us to know whether our previous nlsolve converged sufficiently well
107107
check_η_convergence = (iter > 1 ||
108-
(isnewton(nlsolver) && isadaptive(integrator.alg)))
108+
((isnewton(nlsolver) || isnonlinearsolve(nlsolver)) && isadaptive(integrator.alg)))
109109
if (iter == 1 && ndz < 1e-5) ||
110110
(check_η_convergence && η >= zero(η) && η * ndz < κ)
111111
nlsolver.status = Convergence
@@ -114,7 +114,7 @@ function nlsolve!(nlsolver::NL, integrator::SciMLBase.DEIntegrator,
114114
end
115115
end
116116

117-
if isnewton(nlsolver) && nlsolver.status == Divergence &&
117+
if (isnewton(nlsolver) || isnonlinearsolve(nlsolver)) && nlsolver.status == Divergence &&
118118
!isJcurrent(nlsolver, integrator)
119119
nlsolver.status = TryAgain
120120
nlsolver.nfails += 1

lib/OrdinaryDiffEqNonlinearSolve/src/type.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,4 +218,5 @@ mutable struct NonlinearSolveCache{uType, tType, rateType, tType2, P, C} <:
218218
invγdt::tType2
219219
prob::P
220220
cache::C
221+
new_W::Bool
221222
end

lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ isnewton(nlsolver::AbstractNLSolver) = isnewton(nlsolver.cache)
1414
isnewton(::AbstractNLSolverCache) = false
1515
isnewton(::Union{NLNewtonCache, NLNewtonConstantCache}) = true
1616

17+
isnonlinearsolve(nlsolver::AbstractNLSolver) = isnonlinearsolve(nlsolver.cache)
18+
isnonlinearsolve(::AbstractNLSolverCache) = false
19+
isnonlinearsolve(::NonlinearSolveCache) = true
20+
1721
is_always_new(nlsolver::AbstractNLSolver) = is_always_new(nlsolver.alg)
1822
check_div(nlsolver::AbstractNLSolver) = check_div(nlsolver.alg)
1923
check_div(alg) = isdefined(alg, :check_div) ? alg.check_div : true
@@ -32,9 +36,9 @@ getnfails(_) = 0
3236
getnfails(nlsolver::AbstractNLSolver) = nlsolver.nfails
3337

3438
set_new_W!(nlsolver::AbstractNLSolver, val::Bool)::Bool = set_new_W!(nlsolver.cache, val)
35-
set_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache}, val::Bool)::Bool = nlcache.new_W = val
39+
set_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache, NonlinearSolveCache}, val::Bool)::Bool = nlcache.new_W = val
3640
get_new_W!(nlsolver::AbstractNLSolver)::Bool = get_new_W!(nlsolver.cache)
37-
get_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache})::Bool = nlcache.new_W
41+
get_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache, NonlinearSolveCache})::Bool = nlcache.new_W
3842
get_new_W!(::AbstractNLSolverCache)::Bool = true
3943

4044
get_W(nlsolver::AbstractNLSolver) = get_W(nlsolver.cache)
@@ -239,7 +243,7 @@ function build_nlsolver(
239243
NonlinearProblem(NonlinearFunction{true}(nlf), ztmp, nlp_params)
240244
end
241245
cache = init(prob, nlalg.alg)
242-
nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, invγdt, prob, cache)
246+
nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, invγdt, prob, cache, true)
243247
else
244248
nlcache = NLNewtonCache(ustep, tstep, k, atmp, dz, J, W, true,
245249
true, true, tType(dt), du1, uf, jac_config,
@@ -327,7 +331,7 @@ function build_nlsolver(
327331
prob = NonlinearProblem(NonlinearFunction{false}(nlf), copy(ztmp), nlp_params)
328332
cache = init(prob, nlalg.alg)
329333
nlcache = NonlinearSolveCache(
330-
nothing, tstep, nothing, nothing, invγdt, prob, cache)
334+
nothing, tstep, nothing, nothing, invγdt, prob, cache, true)
331335
else
332336
nlcache = NLNewtonConstantCache(tstep, J, W, true, true, true, tType(dt), uf,
333337
invγdt, tType(nlalg.new_W_dt_cutoff), t)

0 commit comments

Comments
 (0)