Skip to content

Commit

Permalink
Merge pull request #151 from JuliaDynamics/hw/fix_observables
Browse files Browse the repository at this point in the history
fix reconstruction of p dependent observables
  • Loading branch information
hexaeder authored Oct 10, 2024
2 parents 8fcd42d + e49c304 commit a65b3e3
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
fail-fast: false
matrix:
version:
# - 'lts' TODO: replace with new LTS (1.10 probably)
- 'lts'
- '1'
- 'pre'
os:
Expand Down
7 changes: 6 additions & 1 deletion benchmark/run_benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ BMPATH = joinpath(NDPATH, "benchmark")
@info "Activate Benchmark environment"
import Pkg;
Pkg.activate(BMPATH);
Pkg.develop(; path=NDPATH);
if VERSION < v"1.11.0-0"
Pkg.develop(; path=NDPATH);
end
Pkg.precompile();

using PkgBenchmark
Expand Down Expand Up @@ -96,6 +98,9 @@ isdirty = with(LibGit2.isdirty, GitRepo(ndpath_tmp))
if isdirty
@info "Dirty directory, add everything to new commit!"
@assert realpath(pwd()) == realpath(ndpath_tmp) "Julia is in $(pwd()) not it $ndpath_tmp"
run(`git status`)
run(`git config --global user.email "[email protected]"`)
run(`git config --global user.name "Benchmark Bot"`)
run(`git checkout -b $(randstring(15))`)
run(`git add -A`)
run(`git commit -m "tmp commit for benchmarking"`)
Expand Down
5 changes: 4 additions & 1 deletion docs/examples/cascading_failure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using Graphs
using OrdinaryDiffEqTsit5
using DiffEqCallbacks
using Plots
using Test #hide
import SymbolicIndexingInterface as SII

#=
Expand Down Expand Up @@ -157,7 +158,9 @@ prob = ODEProblem(swing_network, uflat(u0), (0,6), copy(pflat(p));
callback=CallbackSet(trip_cb, trip_first_cb))
Main.test_execution_styles(prob) # testing all ex styles #src
sol = solve(prob, Tsit5());

## we want to test the reconstruction of the observables # hide
@test all(!iszero, sol(sol.t; idxs=eidxs(sol,:,:P))[begin]) # hide
@test all(iszero, sol(sol.t; idxs=eidxs(sol,:,:P))[end][[1:5...,7]]) # hide
nothing #hide

# Through the magic of symbolic indexing we can plot the power flows on all lines:
Expand Down
17 changes: 11 additions & 6 deletions src/symbolicindexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,11 @@ end
####
#### Timeseries parameter indexing
####
const DEFAULT_PARA_TS_IDX = 1
SII.is_timeseries_parameter(nw::Network, sni) = SII.is_parameter(nw::Network, sni)
function SII.timeseries_parameter_index(nw::Network, sni)
# NOTE: ALL parameters are lumped in timeseries with idx 1
SII.ParameterTimeseriesIndex.(1, SII.parameter_index.(nw, sni))
SII.ParameterTimeseriesIndex.(DEFAULT_PARA_TS_IDX, SII.parameter_index.(nw, sni))
end

function SII.get_all_timeseries_indexes(nw::Network, sym)
Expand All @@ -302,8 +303,10 @@ function SII.get_all_timeseries_indexes(nw::Network, sym)
# else
# return Set()
# end
if SII.is_timeseries_parameter(nw, sym)
return Set{Union{Int, SII.ContinuousTimeseries}}([SII.timeseries_parameter_index(nw, sym).timeseries_idx])
if !iszero(pdim(nw)) && SII.is_timeseries_parameter(nw, sym)
return Set{Union{Int, SII.ContinuousTimeseries}}([DEFAULT_PARA_TS_IDX])
elseif !iszero(pdim(nw)) && SII.is_observed(nw, sym)
return Set{Union{Int, SII.ContinuousTimeseries}}([SII.ContinuousTimeseries(), DEFAULT_PARA_TS_IDX])
else
return Set{Union{Int, SII.ContinuousTimeseries}}([SII.ContinuousTimeseries()])
end
Expand All @@ -316,7 +319,7 @@ end
function SII.with_updated_parameter_timeseries_values(nw::Network, params, args::Pair...)
@assert length(args) == 1 "Did not expect more than 1 timeseries here, please report issue."
tsidx, p = args[1]
@assert tsidx == 1 "Did not expect the passed timeseries to have other index then 1, please report issue."
@assert tsidx == DEFAULT_PARA_TS_IDX "Did not expect the passed timeseries to have other index then 1, please report issue."
params .= p
end

Expand All @@ -328,7 +331,7 @@ function SciMLBase.create_parameter_timeseries_collection(nw::Network, p::Abstra
end

function SciMLBase.get_saveable_values(nw::Network, p::AbstractVector, timeseries_idx)
@assert timeseries_idx == 1 # nothing else makes sense
@assert timeseries_idx == DEFAULT_PARA_TS_IDX # nothing else makes sense
copy(p)
end
"""
Expand All @@ -339,7 +342,7 @@ if the parameter values have changed. This will store a timeseries of said param
solution object, thus alowing us to recosntruct observables which depend on time-dependet variables.
"""
function save_parameters!(integrator::SciMLBase.DEIntegrator)
SciMLBase.save_discretes!(integrator, 1)
SciMLBase.save_discretes!(integrator, DEFAULT_PARA_TS_IDX)
end

####
Expand All @@ -352,6 +355,8 @@ function SII.is_observed(nw::Network, sni)
# if has colon check if all are observed OR variables and return true
# the observed function will handle the whole thing then
all(s -> SII.is_variable(nw, s) || SII.is_observed(nw, s), sni)
elseif sni isa AbstractVector
any(SII.is_observed.(Ref(nw), sni))
else
_is_observed(nw, sni)
end
Expand Down
11 changes: 7 additions & 4 deletions test/symbolicindexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,13 @@ SII.all_variable_symbols(nw)
@test filter(s->SII.is_parameter(nw,s), SII.all_symbols(nw)) == SII.parameter_symbols(nw)
@test filter(s->SII.is_variable(nw,s), SII.all_symbols(nw)) == SII.variable_symbols(nw)

sol[EIndex(1,:P)]
sol[EIndex(2,:P)]
sol[EIndex(3,:P)]
@test sol[EIndex(1:3,:P)] == sol[[EIndex(1,:P),EIndex(2,:P),EIndex(3,:P)]]
# sol[obs] does not work, because obs has two timeseries: Continous and Discrete
# it is unclear, whether it should return onlye discre values or for all sol.t
@test_broken sol[EIndex(1,:P)]
@test_broken sol[EIndex(2,:P)]
@test_broken sol[EIndex(3,:P)]
@test_broken sol[EIndex(1:3,:P)] == sol[[EIndex(1,:P),EIndex(2,:P),EIndex(3,:P)]]
@test sol(sol.t, idxs=EIndex(1:3,:P)) == sol(sol.t, idxs=[EIndex(1,:P),EIndex(2,:P),EIndex(3,:P)])

####
#### more complex problem
Expand Down

0 comments on commit a65b3e3

Please sign in to comment.