diff --git a/src/callbacks.jl b/src/callbacks.jl index 7ad746bfb..f512de88f 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -463,7 +463,7 @@ function find_callback_time(integrator, callback::ContinuousCallback, counter) ) if event_occurred if callback.condition === nothing - new_t = zero(typeof(integrator.t)) + cb_t = integrator.t else if callback.interp_points != 0 top_t = ts[interp_index] # Top at the smallest @@ -475,7 +475,7 @@ function find_callback_time(integrator, callback::ContinuousCallback, counter) if callback.rootfind != SciMLBase.NoRootFind && !isdiscrete(integrator.alg) zero_func(abst, p = nothing) = get_condition(integrator, callback, abst) if zero_func(top_t) == 0 - Θ = top_t + cb_t = top_t else if integrator.event_last_time == counter && abs(zero_func(bottom_t)) <= 100abs(integrator.last_event_error) && @@ -489,34 +489,25 @@ function find_callback_time(integrator, callback::ContinuousCallback, counter) sign(zero_func(bottom_t)) * sign_top >= zero(sign_top) && error("Double callback crossing floating point reducer errored. Report this issue.") end - Θ = find_root(zero_func, (bottom_t, top_t), callback.rootfind) + cb_t = find_root(zero_func, (bottom_t, top_t), callback.rootfind) integrator.last_event_error = DiffEqBase.value( ODE_DEFAULT_NORM( - zero_func(Θ), Θ + zero_func(cb_t), cb_t ) ) end - #Θ = prevfloat(...) - # prevfloat guarantees that the new time is either 1 floating point - # numbers just before the event or directly at zero, but not after. - # If there's a barrier which is never supposed to be crossed, - # then this will ensure that - # The item never leaves the domain. Otherwise Roots.jl can return - # a float which is slightly after, making it out of the domain, causing - # havoc. - new_t = Θ - integrator.tprev elseif interp_index != callback.interp_points && !isdiscrete(integrator.alg) - new_t = ts[interp_index] - integrator.tprev + cb_t = ts[interp_index] else # If no solve and no interpolants, just use endpoint - new_t = integrator.dt + cb_t = integrator.t end end else - new_t = zero(typeof(integrator.t)) + cb_t = integrator.t end - return new_t, prev_sign, event_occurred, event_idx + return cb_t, prev_sign, event_occurred, event_idx end function find_callback_time(integrator, callback::VectorContinuousCallback, counter) @@ -528,7 +519,7 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun ) if event_occurred if callback.condition === nothing - new_t = zero(typeof(integrator.t)) + cb_t = integrator.t min_event_idx = findfirst(isequal(1), event_idx) else if callback.interp_points != 0 @@ -539,7 +530,7 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun bottom_t = integrator.tprev end if callback.rootfind != SciMLBase.NoRootFind && !isdiscrete(integrator.alg) - min_t = isforward(integrator) ? nextfloat(top_t) : prevfloat(top_t) + cb_t = isforward(integrator) ? nextfloat(top_t) : prevfloat(top_t) min_event_idx = -1 for idx in 1:length(event_idx) if ArrayInterface.allowed_getindex(event_idx, idx) != 0 @@ -553,7 +544,7 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun ) end if zero_func(top_t) == 0 - Θ = top_t + cbi_t = top_t else if integrator.event_last_time == counter && integrator.vector_event_last_time == idx && @@ -570,41 +561,32 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun error("Double callback crossing floating point reducer errored. Report this issue.") end - Θ = find_root(zero_func, (bottom_t, top_t), callback.rootfind) - if integrator.tdir * Θ < integrator.tdir * min_t + cbi_t = find_root(zero_func, (bottom_t, top_t), callback.rootfind) + if integrator.tdir * cbi_t < integrator.tdir * cb_t integrator.last_event_error = DiffEqBase.value( ODE_DEFAULT_NORM( - zero_func(Θ), Θ + zero_func(cbi_t), cbi_t ) ) end end - if integrator.tdir * Θ < integrator.tdir * min_t + if integrator.tdir * cbi_t < integrator.tdir * cb_t min_event_idx = idx - min_t = Θ + cb_t = cbi_t end end end - #Θ = prevfloat(...) - # prevfloat guarantees that the new time is either 1 floating point - # numbers just before the event or directly at zero, but not after. - # If there's a barrier which is never supposed to be crossed, - # then this will ensure that - # The item never leaves the domain. Otherwise Roots.jl can return - # a float which is slightly after, making it out of the domain, causing - # havoc. - new_t = min_t - integrator.tprev elseif interp_index != callback.interp_points && !isdiscrete(integrator.alg) - new_t = ts[interp_index] - integrator.tprev + cb_t = ts[interp_index] min_event_idx = findfirst(isequal(1), event_idx) else # If no solve and no interpolants, just use endpoint - new_t = integrator.dt + cb_t = integrator.t min_event_idx = findfirst(isequal(1), event_idx) end end else - new_t = zero(typeof(integrator.t)) + cb_t = integrator.t min_event_idx = 1 end @@ -612,7 +594,7 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun error("Callback handling failed. Please file an issue with code to reproduce.") end - return new_t, ArrayInterface.allowed_getindex(prev_sign, min_event_idx), + return cb_t, ArrayInterface.allowed_getindex(prev_sign, min_event_idx), event_occurred::Bool, min_event_idx::Int end @@ -632,7 +614,7 @@ function apply_callback!( end change_t_via_interpolation!( - integrator, integrator.tprev + cb_time, Val{:false}, callback.initializealg + integrator, cb_time, Val{:false}, callback.initializealg ) # handle saveat diff --git a/test/downstream/community_callback_tests.jl b/test/downstream/community_callback_tests.jl index 9b725dda4..ba6169225 100644 --- a/test/downstream/community_callback_tests.jl +++ b/test/downstream/community_callback_tests.jl @@ -251,14 +251,14 @@ first_t = findfirst(isequal(0.5), sol.t) # https://github.com/SciML/DiffEqBase.jl/issues/1231 @testset "Successive callbacks in same integration step" begin cb = ContinuousCallback( - (u, t, integrator) -> t - 0.0e-8, + (u, t, integrator) -> t - 0.0, (integrator) -> push!(record, 0) ) vcb = VectorContinuousCallback( - (out, u, t, integrator) -> out .= (t - 1.0e-8, t - 2.0e-8), + (out, u, t, integrator) -> out .= (t - 1.0e-8, t - 2.0e-8, t - 2.0e-7), (integrator, event_index) -> push!(record, event_index), - 2 + 3 ) f(u, p, t) = 1.0 @@ -269,12 +269,12 @@ first_t = findfirst(isequal(0.5), sol.t) tspan = (-1.0, 1.0) prob = ODEProblem(f, u0, tspan) sol = solve(prob, Tsit5(), dt = 2.0, callback = CallbackSet(cb, vcb)) - @test record == [0, 1, 2] + @test record == [0, 1, 2, 3] # Backward propagation with successive events record = [] tspan = (1.0, -1.0) prob = ODEProblem(f, u0, tspan) sol = solve(prob, Tsit5(), dt = 2.0, callback = CallbackSet(cb, vcb)) - @test record == [2, 1, 0] + @test record == [3, 2, 1, 0] end