Skip to content

Commit 105dd72

Browse files
fix: require explicitly specifying discrete variable derivatives
1 parent 4dbd08f commit 105dd72

File tree

3 files changed

+57
-6
lines changed

3 files changed

+57
-6
lines changed

src/structural_transformation/symbolics_tearing.jl

+14
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,20 @@ function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int; kwargs...)
6868
eq = 0 ~ fast_substitute(
6969
ModelingToolkit.derivative(
7070
eq.rhs - eq.lhs, get_iv(sys); throw_no_derivative = true), ts.param_derivative_map)
71+
72+
vs = ModelingToolkit.vars(eq.rhs)
73+
for v in vs
74+
# parameters with unknown derivatives have a value of `nothing` in the map,
75+
# so use `missing` as the default.
76+
get(ts.param_derivative_map, v, missing) === nothing || continue
77+
_original_eq = equations(ts)[ieq]
78+
error("""
79+
Encountered derivative of discrete variable `$(only(arguments(v)))` when \
80+
differentiating equation `$(_original_eq)`. This may indicate a model error or a \
81+
missing equation of the form `$v ~ ...` that defines this derivative.
82+
""")
83+
end
84+
7185
push!(equations(ts), eq)
7286
# Analyze the new equation and update the graph/solvable_graph
7387
# First, copy the previous incidence and add the derivative terms.

src/systems/systemstructure.jl

+21-5
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
207207
fullvars::Vector
208208
structure::SystemStructure
209209
extra_eqs::Vector
210-
param_derivative_map::Dict{BasicSymbolic, Real}
210+
param_derivative_map::Dict{BasicSymbolic, Any}
211211
end
212212

213213
TransformationState(sys::AbstractSystem) = TearingState(sys)
@@ -254,6 +254,11 @@ function Base.push!(ev::EquationsView, eq)
254254
push!(ev.ts.extra_eqs, eq)
255255
end
256256

257+
function is_time_dependent_parameter(p, iv)
258+
return iv !== nothing && isparameter(p) && iscall(p) &&
259+
(args = arguments(p); length(args)) == 1 && isequal(only(args), iv)
260+
end
261+
257262
function TearingState(sys; quick_cancel = false, check = true)
258263
sys = flatten(sys)
259264
ivs = independent_variables(sys)
@@ -265,7 +270,7 @@ function TearingState(sys; quick_cancel = false, check = true)
265270
var2idx = Dict{Any, Int}()
266271
symbolic_incidence = []
267272
fullvars = []
268-
param_derivative_map = Dict{BasicSymbolic, Real}()
273+
param_derivative_map = Dict{BasicSymbolic, Any}()
269274
var_counter = Ref(0)
270275
var_types = VariableType[]
271276
addvar! = let fullvars = fullvars, var_counter = var_counter, var_types = var_types
@@ -278,11 +283,17 @@ function TearingState(sys; quick_cancel = false, check = true)
278283

279284
vars = OrderedSet()
280285
varsvec = []
286+
eqs_to_retain = trues(length(eqs))
281287
for (i, eq′) in enumerate(eqs)
282288
if eq′.lhs isa Connection
283289
check ? error("$(nameof(sys)) has unexpanded `connect` statements") :
284290
return nothing
285291
end
292+
if iscall(eq′.lhs) && (op = operation(eq′.lhs)) isa Differential &&
293+
isequal(op.x, iv) && is_time_dependent_parameter(only(arguments(eq′.lhs)), iv)
294+
param_derivative_map[eq′.lhs] = eq′.rhs
295+
eqs_to_retain[i] = false
296+
end
286297
if _iszero(eq′.lhs)
287298
rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs
288299
eq = eq′
@@ -297,9 +308,11 @@ function TearingState(sys; quick_cancel = false, check = true)
297308
any(isequal(_var), ivs) && continue
298309
if isparameter(_var) ||
299310
(iscall(_var) && isparameter(operation(_var)) || isconstant(_var))
300-
if iv !== nothing && isparameter(_var) && iscall(_var) &&
301-
(args = arguments(_var); length(args)) == 1 && isequal(only(args), iv)
302-
param_derivative_map[Differential(iv)(_var)] = 0.0
311+
if is_time_dependent_parameter(_var, iv) &&
312+
!haskey(param_derivative_map, Differential(iv)(_var))
313+
# default to `nothing` since it is ignored during substitution,
314+
# so `D(_var)` is retained in the expression.
315+
param_derivative_map[Differential(iv)(_var)] = nothing
303316
end
304317
continue
305318
end
@@ -357,6 +370,9 @@ function TearingState(sys; quick_cancel = false, check = true)
357370
eqs[i] = eqs[i].lhs ~ rhs
358371
end
359372
end
373+
eqs = eqs[eqs_to_retain]
374+
neqs = length(eqs)
375+
symbolic_incidence = symbolic_incidence[eqs_to_retain]
360376

361377
### Handle discrete variables
362378
lowest_shift = Dict()

test/structural_transformation/utils.jl

+22-1
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,28 @@ end
302302
return ODESystem(eqs, t, vars, params; systems, name)
303303
end
304304

305-
@mtkbuild sys = FilteredInput()
305+
@component function FilteredInputFix(; name, x0 = 0, T = 0.1)
306+
params = @parameters begin
307+
k(t) = x0
308+
T = T
309+
end
310+
vars = @variables begin
311+
x(t) = k
312+
dx(t) = 0
313+
ddx(t)
314+
end
315+
systems = []
316+
eqs = [D(x) ~ dx
317+
D(dx) ~ ddx
318+
dx ~ (k - x) / T
319+
D(k) ~ 0]
320+
return ODESystem(eqs, t, vars, params; systems, name)
321+
end
322+
323+
@named sys = FilteredInput()
324+
@test_throws ["derivative of discrete variable", "k(t)"] structural_simplify(sys)
325+
326+
@mtkbuild sys = FilteredInputFix()
306327
vs = Set()
307328
for eq in equations(sys)
308329
ModelingToolkit.vars!(vs, eq)

0 commit comments

Comments
 (0)