Skip to content

[WIP] feat: use LinearProblem for linear SCCs in SCCNonlinearProblem #3760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 56 additions & 26 deletions src/problems/linearproblem.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,42 @@
struct LinearFunction{iip, I} <: SciMLBase.AbstractSciMLFunction{iip}
interface::I
A::AbstractMatrix
b::AbstractVector
end

function LinearFunction{iip}(
sys::System; expression = Val{false}, check_compatibility = true,
sparse = false, eval_expression = false, eval_module = @__MODULE__,
checkbounds = false, cse = true, kwargs...) where {iip}
check_complete(sys, LinearProblem)
check_compatibility && check_compatible_system(LinearProblem, sys)

A, b = calculate_A_b(sys; sparse)
update_A = generate_update_A(sys, A; expression, wrap_gfw = Val{true}, eval_expression,
eval_module, checkbounds, cse, kwargs...)
update_b = generate_update_b(sys, b; expression, wrap_gfw = Val{true}, eval_expression,
eval_module, checkbounds, cse, kwargs...)
observedfun = ObservedFunctionCache(
sys; steady_state = false, expression, eval_expression, eval_module, checkbounds,
cse)

if expression == Val{true}
symbolic_interface = quote
update_A = $update_A
update_b = $update_b
sys = $sys
observedfun = $observedfun
$(SciMLBase.SymbolicLinearInterface)(
update_A, update_b, sys, observedfun, nothing)
end
else
symbolic_interface = SciMLBase.SymbolicLinearInterface(
update_A, update_b, sys, observedfun, nothing)
end

return LinearFunction{iip, typeof(symbolic_interface)}(symbolic_interface, A, b)
end

function SciMLBase.LinearProblem(sys::System, op; kwargs...)
SciMLBase.LinearProblem{true}(sys, op; kwargs...)
end
Expand All @@ -14,8 +53,8 @@ function SciMLBase.LinearProblem{iip}(
check_complete(sys, LinearProblem)
check_compatibility && check_compatible_system(LinearProblem, sys)

_, u0, p = process_SciMLProblem(
EmptySciMLFunction{iip}, sys, op; check_length, expression,
f, u0, p = process_SciMLProblem(
LinearFunction{iip}, sys, op; check_length, expression,
build_initializeprob = false, symbolic_u0 = true, u0_constructor, u0_eltype,
kwargs...)

Expand All @@ -32,25 +71,21 @@ function SciMLBase.LinearProblem{iip}(
u0_eltype = something(u0_eltype, floatT)

u0_constructor = get_p_constructor(u0_constructor, u0Type, u0_eltype)
symbolic_interface = f.interface
A, b = get_A_b_from_LinearFunction(
sys, f, p; eval_expression, eval_module, expression, u0_constructor)

A, b = calculate_A_b(sys; sparse)
update_A = generate_update_A(sys, A; expression, wrap_gfw = Val{true}, eval_expression,
eval_module, checkbounds, cse, kwargs...)
update_b = generate_update_b(sys, b; expression, wrap_gfw = Val{true}, eval_expression,
eval_module, checkbounds, cse, kwargs...)
observedfun = ObservedFunctionCache(
sys; steady_state = false, expression, eval_expression, eval_module, checkbounds,
cse)
kwargs = (; u0, process_kwargs(sys; kwargs...)..., f = symbolic_interface)
args = (; A, b, p)

return maybe_codegen_scimlproblem(expression, LinearProblem{iip}, args; kwargs...)
end

function get_A_b_from_LinearFunction(
sys::System, f::LinearFunction, p; eval_expression = false,
eval_module = @__MODULE__, expression = Val{false}, u0_constructor = identity)
@unpack A, b, interface = f
if expression == Val{true}
symbolic_interface = quote
update_A = $update_A
update_b = $update_b
sys = $sys
observedfun = $observedfun
$(SciMLBase.SymbolicLinearInterface)(
update_A, update_b, sys, observedfun, nothing)
end
get_A = build_explicit_observed_function(
sys, A; param_only = true, eval_expression, eval_module)
if sparse
Expand All @@ -61,16 +96,11 @@ function SciMLBase.LinearProblem{iip}(
A = u0_constructor(get_A(p))
b = u0_constructor(get_b(p))
else
symbolic_interface = SciMLBase.SymbolicLinearInterface(
update_A, update_b, sys, observedfun, nothing)
A = u0_constructor(update_A(p))
b = u0_constructor(update_b(p))
A = u0_constructor(interface.update_A!(p))
b = u0_constructor(interface.update_b!(p))
end

kwargs = (; u0, process_kwargs(sys; kwargs...)..., f = symbolic_interface)
args = (; A, b, p)

return maybe_codegen_scimlproblem(expression, LinearProblem{iip}, args; kwargs...)
return A, b
end

# For remake
Expand Down
61 changes: 45 additions & 16 deletions src/problems/sccnonlinearproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ end

function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
exprs::Dict{TypeT, Vector{Any}}, solsyms, obseqs::Vector{Equation};
eval_expression = false, eval_module = @__MODULE__, cse = true)
eval_expression = false, eval_module = @__MODULE__, cse = true, sparse = false)
ps = parameters(sys; initial_parameters = true)
rps = reorder_parameters(sys, ps)
obs_assigns = [eq.lhs ← eq.rhs for eq in obseqs]
Expand Down Expand Up @@ -39,9 +39,22 @@ end
struct SCCNonlinearFunction{iip} end

function SCCNonlinearFunction{iip}(
sys::System, _eqs, _dvs, _obs, cachesyms; eval_expression = false,
sys::System, _eqs, _dvs, _obs, cachesyms, op; eval_expression = false,
eval_module = @__MODULE__, cse = true, kwargs...) where {iip}
ps = parameters(sys; initial_parameters = true)
subsys = System(
_eqs, _dvs, ps; observed = _obs, name = nameof(sys), defaults = defaults(sys))
@set! subsys.parameter_dependencies = parameter_dependencies(sys)
if get_index_cache(sys) !== nothing
@set! subsys.index_cache = subset_unknowns_observed(
get_index_cache(sys), sys, _dvs, getproperty.(_obs, (:lhs,)))
@set! subsys.complete = true
end
# generate linear problem instead
if isaffine(subsys)
return LinearFunction{iip}(
subsys; eval_expression, eval_module, cse, cachesyms, kwargs...)
end
rps = reorder_parameters(sys, ps)

obs_assignments = [eq.lhs ← eq.rhs for eq in _obs]
Expand All @@ -54,14 +67,6 @@ function SCCNonlinearFunction{iip}(
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
f = GeneratedFunctionWrapper{(2, 2, is_split(sys))}(f_oop, f_iip)

subsys = System(_eqs, _dvs, ps; observed = _obs,
parameter_dependencies = parameter_dependencies(sys), name = nameof(sys))
if get_index_cache(sys) !== nothing
@set! subsys.index_cache = subset_unknowns_observed(
get_index_cache(sys), sys, _dvs, getproperty.(_obs, (:lhs,)))
@set! subsys.complete = true
end

return NonlinearFunction{iip}(f; sys = subsys)
end

Expand All @@ -70,7 +75,7 @@ function SciMLBase.SCCNonlinearProblem(sys::System, args...; kwargs...)
end

function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = false,
eval_module = @__MODULE__, cse = true, kwargs...) where {iip}
eval_module = @__MODULE__, cse = true, u0_constructor = identity, kwargs...) where {iip}
if !iscomplete(sys) || get_tearing_state(sys) === nothing
error("A simplified `System` is required. Call `mtkcompile` on the system before creating an `SCCNonlinearProblem`.")
end
Expand Down Expand Up @@ -112,7 +117,8 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
obs = observed(sys)

_, u0, p = process_SciMLProblem(
EmptySciMLFunction{iip}, sys, op; eval_expression, eval_module, kwargs...)
EmptySciMLFunction{iip}, sys, op; eval_expression, eval_module, u0_constructor,
symbolic_u0 = true, kwargs...)

explicitfuns = []
nlfuns = []
Expand Down Expand Up @@ -223,7 +229,8 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
get(cachevars, T, [])
end)
f = SCCNonlinearFunction{iip}(
sys, _eqs, _dvs, _obs, cachebufsyms; eval_expression, eval_module, cse, kwargs...)
sys, _eqs, _dvs, _obs, cachebufsyms, op;
eval_expression, eval_module, cse, kwargs...)
push!(nlfuns, f)
end

Expand All @@ -240,11 +247,33 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
p = rebuild_with_caches(p, templates...)
end

u0_eltype = Union{}
for x in u0
symbolic_type(x) == NotSymbolic() || continue
u0_eltype = typeof(x)
break
end
if u0_eltype == Union{}
u0_eltype = Float64
end
subprobs = []
for (f, vscc) in zip(nlfuns, var_sccs)
for (i, (f, vscc)) in enumerate(zip(nlfuns, var_sccs))
_u0 = SymbolicUtils.Code.create_array(
typeof(u0), eltype(u0), Val(1), Val(length(vscc)), u0[vscc]...)
prob = NonlinearProblem(f, _u0, p)
symbolic_idxs = findall(x -> symbolic_type(x) != NotSymbolic(), _u0)
explicitfuns[i](p, subprobs)
if f isa LinearFunction
_u0 = isempty(symbolic_idxs) ? _u0 : zeros(u0_eltype, length(_u0))
_u0 = u0_eltype.(_u0)
symbolic_interface = f.interface
A, b = get_A_b_from_LinearFunction(
sys, f, p; eval_expression, eval_module, u0_constructor)
prob = LinearProblem(A, b, p; f = symbolic_interface, u0 = _u0)
else
isempty(symbolic_idxs) || throw(MissingGuessError(dvs[vscc], _u0))
_u0 = u0_eltype.(_u0)
prob = NonlinearProblem(f, _u0, p)
end
push!(subprobs, prob)
end

Expand All @@ -254,5 +283,5 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
@set! sys.eqs = new_eqs
@set! sys.index_cache = subset_unknowns_observed(
get_index_cache(sys), sys, new_dvs, getproperty.(obs, (:lhs,)))
return SCCNonlinearProblem(subprobs, explicitfuns, p, true; sys)
return SCCNonlinearProblem(Tuple(subprobs), Tuple(explicitfuns), p, true; sys)
end
8 changes: 4 additions & 4 deletions src/systems/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1187,10 +1187,10 @@ $GENERATE_X_KWARGS
All other keyword arguments are forwarded to [`build_function_wrapper`](@ref).
"""
function generate_update_A(sys::System, A::AbstractMatrix; expression = Val{true},
wrap_gfw = Val{false}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
wrap_gfw = Val{false}, eval_expression = false, eval_module = @__MODULE__, cachesyms = (), kwargs...)
ps = reorder_parameters(sys)

res = build_function_wrapper(sys, A, ps...; p_start = 1, expression = Val{true},
res = build_function_wrapper(sys, A, ps..., cachesyms...; p_start = 1, expression = Val{true},
similarto = typeof(A), kwargs...)
return maybe_compile_function(expression, wrap_gfw, (1, 1, is_split(sys)), res;
eval_expression, eval_module)
Expand All @@ -1209,10 +1209,10 @@ $GENERATE_X_KWARGS
All other keyword arguments are forwarded to [`build_function_wrapper`](@ref).
"""
function generate_update_b(sys::System, b::AbstractVector; expression = Val{true},
wrap_gfw = Val{false}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
wrap_gfw = Val{false}, eval_expression = false, eval_module = @__MODULE__, cachesyms = (), kwargs...)
ps = reorder_parameters(sys)

res = build_function_wrapper(sys, b, ps...; p_start = 1, expression = Val{true},
res = build_function_wrapper(sys, b, ps..., cachesyms...; p_start = 1, expression = Val{true},
similarto = typeof(b), kwargs...)
return maybe_compile_function(expression, wrap_gfw, (1, 1, is_split(sys)), res;
eval_expression, eval_module)
Expand Down
8 changes: 4 additions & 4 deletions test/scc_nonlinear_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
@test_throws ["not compatible"] SCCNonlinearProblem(_model, [])
model = mtkcompile(model)
prob = NonlinearProblem(model, [u => zeros(8)])
sccprob = SCCNonlinearProblem(model, [u => zeros(8)])
sccprob = SCCNonlinearProblem(model, collect(u[1:5]) .=> zeros(5))
sol1 = solve(prob, NewtonRaphson())
sol2 = solve(sccprob, NewtonRaphson())
@test SciMLBase.successful_retcode(sol1)
@test SciMLBase.successful_retcode(sol2)
@test sol1[u] ≈ sol2[u]
@test_broken SciMLBase.successful_retcode(sol2)
@test_broken sol1[u] ≈ sol2[u]

sccprob = SCCNonlinearProblem{false}(model, SA[u => zeros(8)])
sccprob = SCCNonlinearProblem{false}(model, SA[(collect(u[1:5]) .=> zeros(5))...])
for prob in sccprob.probs
@test prob.u0 isa SVector
@test !SciMLBase.isinplace(prob)
Expand Down
Loading