Skip to content

Commit dfcd812

Browse files
committed
fix: fix sfmi bugs and
1 parent 1ee0844 commit dfcd812

File tree

4 files changed

+23
-19
lines changed

4 files changed

+23
-19
lines changed

Diff for: src/systems/callbacks.jl

+17-7
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ struct SymbolicDiscreteCallback <: AbstractCallback
426426
reinitializealg::SciMLBase.DAEInitializationAlgorithm
427427

428428
function SymbolicDiscreteCallback(
429-
condition::Union{Number, Vector{<:Number}}, affect = nothing;
429+
condition::Union{Symbolic{Bool}, Number, Vector{<:Number}}, affect = nothing;
430430
initialize = nothing, finalize = nothing,
431431
reinitializealg = nothing, kwargs...)
432432
c = is_timed_condition(condition) ? condition : value(scalarize(condition))
@@ -468,10 +468,17 @@ function is_timed_condition(condition::T) where {T}
468468
end
469469
end
470470

471-
to_cb_vector(cbs::Vector{<:AbstractCallback}) = cbs
472-
to_cb_vector(cbs::Vector) = Vector{AbstractCallback}(cbs)
473-
to_cb_vector(cbs::Nothing) = AbstractCallback[]
474-
to_cb_vector(cb::AbstractCallback) = [cb]
471+
to_cb_vector(cbs::Vector{<:AbstractCallback}; kwargs...) = cbs
472+
to_cb_vector(cbs::Union{Nothing, Vector{Nothing}}; kwargs...) = AbstractCallback[]
473+
to_cb_vector(cb::AbstractCallback; kwargs...) = [cb]
474+
function to_cb_vector(cbs; CB_TYPE = SymbolicContinuousCallback, kwargs...)
475+
if cbs isa Pair
476+
@show cbs
477+
[CB_TYPE(cbs; kwargs...)]
478+
else
479+
Vector{CB_TYPE}([CB_TYPE(cb; kwargs...) for cb in cbs])
480+
end
481+
end
475482

476483
############################################
477484
########## Namespacing Utilities ###########
@@ -906,10 +913,10 @@ function compile_equational_affect(
906913

907914
u_up, u_up! = build_function_wrapper(sys, (@view rhss[is_u]), dvs, _ps..., t;
908915
wrap_code = add_integrator_header(sys, integ, :u),
909-
expression = Val{false}, outputidxs = u_idxs, wrap_mtkparameters)
916+
expression = Val{false}, outputidxs = u_idxs, wrap_mtkparameters, cse = false)
910917
p_up, p_up! = build_function_wrapper(sys, (@view rhss[is_p]), dvs, _ps..., t;
911918
wrap_code = add_integrator_header(sys, integ, :p),
912-
expression = Val{false}, outputidxs = p_idxs, wrap_mtkparameters)
919+
expression = Val{false}, outputidxs = p_idxs, wrap_mtkparameters, cse = false)
913920

914921
return function explicit_affect!(integ)
915922
isempty(dvs_to_update) || u_up!(integ)
@@ -934,7 +941,10 @@ function compile_equational_affect(
934941
end
935942
affprob = ImplicitDiscreteProblem(affsys, u0, (integ.t, integ.t), pmap;
936943
build_initializeprob = false, check_length = false)
944+
@show pmap
945+
@show u0
937946
affsol = init(affprob, IDSolve())
947+
@show affsol
938948
(check_error(affsol) === ReturnCode.InitialFailure) &&
939949
throw(UnsolvableCallbackError(all_equations(aff)))
940950
for u in dvs_to_update

Diff for: src/systems/diffeqs/odesystem.jl

+2-4
Original file line numberDiff line numberDiff line change
@@ -320,10 +320,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
320320

321321
alg_eqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !is_diff_equation(eq),
322322
deqs)
323-
cont_callbacks = to_cb_vector(SymbolicContinuousCallback.(
324-
continuous_events; alg_eqs = alg_eqs, iv = iv, warn_no_algebraic = false))
325-
disc_callbacks = to_cb_vector(SymbolicDiscreteCallback.(
326-
discrete_events; alg_eqs = alg_eqs, iv = iv, warn_no_algebraic = false))
323+
cont_callbacks = to_cb_vector(continuous_events; CB_TYPE = SymbolicContinuousCallback, iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
324+
disc_callbacks = to_cb_vector(discrete_events; CB_TYPE = SymbolicDiscreteCallback, iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
327325

328326
if is_dde === nothing
329327
is_dde = _check_if_dde(deqs, iv′, systems)

Diff for: src/systems/diffeqs/sdesystem.jl

+2-4
Original file line numberDiff line numberDiff line change
@@ -272,10 +272,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
272272

273273
alg_eqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !is_diff_equation(eq),
274274
deqs)
275-
cont_callbacks = to_cb_vector(SymbolicContinuousCallback.(
276-
continuous_events; alg_eqs = alg_eqs, iv = iv, warn_no_algebraic = false))
277-
disc_callbacks = to_cb_vector(SymbolicDiscreteCallback.(
278-
discrete_events; alg_eqs = alg_eqs, iv = iv, warn_no_algebraic = false))
275+
cont_callbacks = to_cb_vector(continuous_events; CB_TYPE = SymbolicContinuousCallback, iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
276+
disc_callbacks = to_cb_vector(discrete_events; CB_TYPE = SymbolicDiscreteCallback, iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
279277

280278
if is_dde === nothing
281279
is_dde = _check_if_dde(deqs, iv′, systems)

Diff for: src/systems/jumps/jumpsystem.jl

+2-4
Original file line numberDiff line numberDiff line change
@@ -212,10 +212,8 @@ function JumpSystem(eqs, iv, unknowns, ps;
212212
end
213213
end
214214

215-
disc_callbacks = to_cb_vector(SymbolicDiscreteCallback.(
216-
discrete_events; iv = iv, warn_no_algebraic = false))
217-
cont_callbacks = to_cb_vector(SymbolicContinuousCallback.(
218-
continuous_events; iv = iv, warn_no_algebraic = false))
215+
cont_callbacks = to_cb_vector(continuous_events; CB_TYPE = SymbolicContinuousCallback, iv = iv, warn_no_algebraic = false)
216+
disc_callbacks = to_cb_vector(discrete_events; CB_TYPE = SymbolicDiscreteCallback, iv = iv, warn_no_algebraic = false)
219217

220218
JumpSystem{typeof(ap)}(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
221219
ap, iv′, us′, ps′, var_to_name, observed, name, description, systems,

0 commit comments

Comments
 (0)