11using SparseArrays
2+ using ModelingToolkit: iscall, operation
23
34"""
45$(TYPEDSIGNATURES)
@@ -15,7 +16,7 @@ function get_deqs(sys::ODESystem)::Tuple{Vector{Equation}, Vector{Equation}}
1516 deqs = Equation[]
1617 others = Equation[]
1718 for eq in eqs
18- if istree (eq. lhs) && operation (eq. lhs) isa Differential
19+ if iscall (eq. lhs) && operation (eq. lhs) isa Differential
1920 push! (deqs, eq)
2021 else
2122 push! (others, eq)
2728"""
2829$(SIGNATURES)
2930
30- Given a vector of expressions `exprs` and variables `vars`,
31- returns a sparse coefficient matrix `A`, constant terms `c` and nonlinear terms `n`,
32- such that `exprs = A * vars + c + n`,
33- where the constant terms do not contain any variables in `vars`.
31+ Returns `true` is `expr` contains variables in `dvs` only and does not contain `iv`.
32+
33+ """
34+ function only_dvs (expr, dvs, iv)
35+ if isequal (expr, iv)
36+ return false
37+ elseif expr in dvs
38+ return true
39+ elseif SymbolicUtils. iscall (expr)
40+ args = arguments (expr)
41+ for arg in args
42+ if only_dvs (arg, dvs, iv)
43+ return true
44+ end
45+ end
46+ end
47+ return false
48+ end
49+
50+ """
51+ $(SIGNATURES)
52+
53+ Given a vector of expressions `exprs`, variables `vars` and a single variable `iv`,
54+ where `vars(iv)` are dependent variables of `iv`,
55+ returns a sparse coefficient matrix `A`, other terms `g` and nonlinear terms `F`,
56+ such that `exprs = A * vars(iv) + g(iv) + F(vars(iv))`,
57+ where the nonlinear terms are functions of `vars` only and do not contain `iv`.
3458
3559Variables in `vars` must be unique.
3660"""
37- function linear_terms (exprs:: AbstractVector , vars)
61+ function separate_terms (exprs:: AbstractVector , vars, iv )
3862 vars = Symbolics. unwrap .(vars)
3963 exprs = Symbolics. unwrap .(exprs)
64+ # expand is helpful for separating terms but is harmful for generating efficient runtime functions
65+ exprs = expand .(exprs)
4066 linear_I = Int[] # row idx for sparse matrix
4167 linear_J = Int[] # col idx for sparse matrix
4268 linear_V = Float64[] # values
@@ -55,48 +81,48 @@ function linear_terms(exprs::AbstractVector, vars)
5581 nothing
5682 end
5783
58- const_terms = similar (exprs, Num) # create a vector of the same size
59- const_terms .= 0 # manually set to Int 0 because `Num` doesn't have a corresponding zero
84+ other_terms = similar (exprs, Num) # create a vector of the same size
85+ other_terms .= 0 # manually set to Int 0 because `Num` doesn't have a corresponding zero
6086 nonlinear_terms = similar (exprs, Num)
6187 nonlinear_terms .= 0
6288
63- # check if the expr is a constant or nolinear term about vars
89+ # check if the expr is a nolinear term about vars only
6490 # and add it to the corresponding collection
65- @inline function const_nonlinear (i, expr)
66- # expr is nonlinear if it contains any variable in vars
67- if Symbolics . has_vars (expr, vars)
91+ @inline function other_nonlinear (i, expr)
92+ # expr is nonlinear if it contains vars only
93+ if only_dvs (expr, vars, iv )
6894 nonlinear_terms[i] += expr
69- else # expr is constant if it doesn't have vars
70- const_terms [i] += expr
95+ else
96+ other_terms [i] += expr
7197 end
7298 nothing
7399 end
74100
75101 for (i, expr) in enumerate (exprs)
76102 if expr isa Number # just a number, e.g. Int, Float64
77- const_terms [i] = expr
103+ other_terms [i] = expr
78104 elseif expr in vars # expr is a variables in vars
79105 push_sparse_coeff (i, expr, 1 )
80106 elseif SymbolicUtils. ismul (expr) && length (expr. dict) == 1
81107 base, exp = first (expr. dict)
82108 if base in vars && exp == 1 # a var with a coeff
83109 push_sparse_coeff (i, base, expr. coeff)
84110 else
85- const_nonlinear (i, expr)
111+ other_nonlinear (i, expr)
86112 end
87113 elseif SymbolicUtils. isadd (expr)
88- const_terms [i] += expr. coeff
114+ other_terms [i] += expr. coeff
89115 for (term, coeff) in expr. dict
90116 if term in vars
91117 push_sparse_coeff (i, term, coeff)
92118 else
93- const_nonlinear (i, term * coeff)
119+ other_nonlinear (i, term * coeff)
94120 end
95121 end
96122 else
97- const_nonlinear (i, expr)
123+ other_nonlinear (i, expr)
98124 end
99125 end
100126 linear = sparse (linear_I, linear_J, linear_V, length (exprs), length (vars))
101- return linear, const_terms , nonlinear_terms
102- end
127+ return linear, other_terms , nonlinear_terms
128+ end
0 commit comments