Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 21 additions & 39 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ function find_callback_time(integrator, callback::ContinuousCallback, counter)
)
if event_occurred
if callback.condition === nothing
new_t = zero(typeof(integrator.t))
cb_t = integrator.t
else
if callback.interp_points != 0
top_t = ts[interp_index] # Top at the smallest
Expand All @@ -475,7 +475,7 @@ function find_callback_time(integrator, callback::ContinuousCallback, counter)
if callback.rootfind != SciMLBase.NoRootFind && !isdiscrete(integrator.alg)
zero_func(abst, p = nothing) = get_condition(integrator, callback, abst)
if zero_func(top_t) == 0
Θ = top_t
cb_t = top_t
else
if integrator.event_last_time == counter &&
abs(zero_func(bottom_t)) <= 100abs(integrator.last_event_error) &&
Expand All @@ -489,34 +489,25 @@ function find_callback_time(integrator, callback::ContinuousCallback, counter)
sign(zero_func(bottom_t)) * sign_top >= zero(sign_top) &&
error("Double callback crossing floating point reducer errored. Report this issue.")
end
Θ = find_root(zero_func, (bottom_t, top_t), callback.rootfind)
cb_t = find_root(zero_func, (bottom_t, top_t), callback.rootfind)
integrator.last_event_error = DiffEqBase.value(
ODE_DEFAULT_NORM(
zero_func(Θ), Θ
zero_func(cb_t), cb_t
)
)
end
#Θ = prevfloat(...)
# prevfloat guarantees that the new time is either 1 floating point
# numbers just before the event or directly at zero, but not after.
# If there's a barrier which is never supposed to be crossed,
# then this will ensure that
# The item never leaves the domain. Otherwise Roots.jl can return
# a float which is slightly after, making it out of the domain, causing
# havoc.
new_t = Θ - integrator.tprev
elseif interp_index != callback.interp_points && !isdiscrete(integrator.alg)
new_t = ts[interp_index] - integrator.tprev
cb_t = ts[interp_index]
else
# If no solve and no interpolants, just use endpoint
new_t = integrator.dt
cb_t = integrator.t
end
end
else
new_t = zero(typeof(integrator.t))
cb_t = integrator.t
end

return new_t, prev_sign, event_occurred, event_idx
return cb_t, prev_sign, event_occurred, event_idx
end

function find_callback_time(integrator, callback::VectorContinuousCallback, counter)
Expand All @@ -528,7 +519,7 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun
)
if event_occurred
if callback.condition === nothing
new_t = zero(typeof(integrator.t))
cb_t = integrator.t
min_event_idx = findfirst(isequal(1), event_idx)
else
if callback.interp_points != 0
Expand All @@ -539,7 +530,7 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun
bottom_t = integrator.tprev
end
if callback.rootfind != SciMLBase.NoRootFind && !isdiscrete(integrator.alg)
min_t = isforward(integrator) ? nextfloat(top_t) : prevfloat(top_t)
cb_t = isforward(integrator) ? nextfloat(top_t) : prevfloat(top_t)
min_event_idx = -1
for idx in 1:length(event_idx)
if ArrayInterface.allowed_getindex(event_idx, idx) != 0
Expand All @@ -553,7 +544,7 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun
)
end
if zero_func(top_t) == 0
Θ = top_t
cbi_t = top_t
else
if integrator.event_last_time == counter &&
integrator.vector_event_last_time == idx &&
Expand All @@ -570,49 +561,40 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun
error("Double callback crossing floating point reducer errored. Report this issue.")
end

Θ = find_root(zero_func, (bottom_t, top_t), callback.rootfind)
if integrator.tdir * Θ < integrator.tdir * min_t
cbi_t = find_root(zero_func, (bottom_t, top_t), callback.rootfind)
if integrator.tdir * cbi_t < integrator.tdir * cb_t
integrator.last_event_error = DiffEqBase.value(
ODE_DEFAULT_NORM(
zero_func(Θ), Θ
zero_func(cbi_t), cbi_t
)
)
end
end
if integrator.tdir * Θ < integrator.tdir * min_t
if integrator.tdir * cbi_t < integrator.tdir * cb_t
min_event_idx = idx
min_t = Θ
cb_t = cbi_t
end
end
end
#Θ = prevfloat(...)
# prevfloat guarantees that the new time is either 1 floating point
# numbers just before the event or directly at zero, but not after.
# If there's a barrier which is never supposed to be crossed,
# then this will ensure that
# The item never leaves the domain. Otherwise Roots.jl can return
# a float which is slightly after, making it out of the domain, causing
# havoc.
new_t = min_t - integrator.tprev
elseif interp_index != callback.interp_points && !isdiscrete(integrator.alg)
new_t = ts[interp_index] - integrator.tprev
cb_t = ts[interp_index]
min_event_idx = findfirst(isequal(1), event_idx)
else
# If no solve and no interpolants, just use endpoint
new_t = integrator.dt
cb_t = integrator.t
min_event_idx = findfirst(isequal(1), event_idx)
end
end
else
new_t = zero(typeof(integrator.t))
cb_t = integrator.t
min_event_idx = 1
end

if event_occurred && min_event_idx < 0
error("Callback handling failed. Please file an issue with code to reproduce.")
end

return new_t, ArrayInterface.allowed_getindex(prev_sign, min_event_idx),
return cb_t, ArrayInterface.allowed_getindex(prev_sign, min_event_idx),
event_occurred::Bool, min_event_idx::Int
end

Expand All @@ -632,7 +614,7 @@ function apply_callback!(
end

change_t_via_interpolation!(
integrator, integrator.tprev + cb_time, Val{:false}, callback.initializealg
integrator, cb_time, Val{:false}, callback.initializealg
)

# handle saveat
Expand Down
10 changes: 5 additions & 5 deletions test/downstream/community_callback_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,14 @@ first_t = findfirst(isequal(0.5), sol.t)
# https://github.com/SciML/DiffEqBase.jl/issues/1231
@testset "Successive callbacks in same integration step" begin
cb = ContinuousCallback(
(u, t, integrator) -> t - 0.0e-8,
(u, t, integrator) -> t - 0.0,
(integrator) -> push!(record, 0)
)

vcb = VectorContinuousCallback(
(out, u, t, integrator) -> out .= (t - 1.0e-8, t - 2.0e-8),
(out, u, t, integrator) -> out .= (t - 1.0e-8, t - 2.0e-8, t - 2.0e-7),
(integrator, event_index) -> push!(record, event_index),
2
3
)

f(u, p, t) = 1.0
Expand All @@ -269,12 +269,12 @@ first_t = findfirst(isequal(0.5), sol.t)
tspan = (-1.0, 1.0)
prob = ODEProblem(f, u0, tspan)
sol = solve(prob, Tsit5(), dt = 2.0, callback = CallbackSet(cb, vcb))
@test record == [0, 1, 2]
@test record == [0, 1, 2, 3]

# Backward propagation with successive events
record = []
tspan = (1.0, -1.0)
prob = ODEProblem(f, u0, tspan)
sol = solve(prob, Tsit5(), dt = 2.0, callback = CallbackSet(cb, vcb))
@test record == [2, 1, 0]
@test record == [3, 2, 1, 0]
end
Loading