Skip to content

Commit

Permalink
Merge pull request #146 from JuliaDynamics/hw/fixinit
Browse files Browse the repository at this point in the history
fix small things in initialization
  • Loading branch information
hexaeder authored Sep 30, 2024
2 parents 70988ab + 1598400 commit 0e41b2f
Show file tree
Hide file tree
Showing 12 changed files with 317 additions and 109 deletions.
5 changes: 4 additions & 1 deletion docs/src/API.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,12 @@ coupling
### Component Metadata API
```@docs
metadata
get_metadata(::NetworkDynamics.ComponentFunction, ::Symbol)
has_metadata(::NetworkDynamics.ComponentFunction, ::Symbol)
get_metadata(::NetworkDynamics.ComponentFunction, ::Symbol)
set_metadata!(::NetworkDynamics.ComponentFunction, ::Symbol, ::Any)
has_graphelement
get_graphelement
set_graphelement!
```
### Per-Symbol Metadata API
```@docs
Expand Down
8 changes: 7 additions & 1 deletion docs/src/metadata.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@ Component metadata is a `Dict{Symbol,Any}` attached to each component to store v

To access the data, you can use the methods `has_metadata`, `get_metadata` and `set_metadata!` (see [Component Metadata API](@ref)).

Special uses: after [component wise initialization](@ref), the field `:init_residual` stores the residual vector of the nonlinear problem.
Special metadata:

- `:init_residual`: after [component wise initialization](@ref), this field stores the residual vector of the nonlinear problem.
- `:graphelement`: optional field to specialize the graphelement for each
component (`vidx`) for vertices, `(;src,dst)` named tuple of either vertex
names or vertex indices for edges. Has special accessors `has_/get_/set_graphelement`.


## Symbol Metadata
Each component stores symbol metadata. The symbol metadata is a `Dict{Symbol, Dict{Symbol, Any}}` which stores a metadate dict per symbol. Symbols are everything that appears in [`sym`](@ref), [`psym`](@ref), [`obssym`](@ref) and [`inputsym`](@ref).
Expand Down
29 changes: 18 additions & 11 deletions ext/ModelingToolkitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import NetworkDynamics: ODEVertex, StaticEdge

include("ModelingToolkitUtils.jl")

function ODEVertex(sys::ODESystem, inputs, outputs; verbose=false)
function ODEVertex(sys::ODESystem, inputs, outputs; verbose=false, name=getname(sys))
warn_events(sys)
inputs = inputs isa AbstractVector ? inputs : [inputs]
outputs = outputs isa AbstractVector ? outputs : [outputs]
Expand All @@ -37,11 +37,10 @@ function ODEVertex(sys::ODESystem, inputs, outputs; verbose=false)
obsf = gen.g_ip

mass_matrix = gen.mass_matrix
name = getname(sys)
ODEVertex(;f, sym, psym, depth, inputsym, obssym, obsf, mass_matrix, name)
end

function StaticEdge(sys::ODESystem, srcin, dstin, outputs, coupling; verbose=false)
function StaticEdge(sys::ODESystem, srcin, dstin, outputs, coupling; verbose=false, name=getname(sys))
warn_events(sys)
srcin = srcin isa AbstractVector ? srcin : [srcin]
dstin = dstin isa AbstractVector ? dstin : [dstin]
Expand All @@ -68,7 +67,6 @@ function StaticEdge(sys::ODESystem, srcin, dstin, outputs, coupling; verbose=fal
depth = coupling isa Fiducial ? Int(length(outputs)/2) : length(outputs)
obsf = gen.g_ip

name = getname(sys)
StaticEdge(;f, sym, psym, depth, inputsym_src, inputsym_dst, obssym, obsf, coupling, name)
end

Expand All @@ -85,22 +83,31 @@ function _get_metadata(sys, name)
end
return nt
end
if ModelingToolkit.hasdefault(sym)
def = ModelingToolkit.getdefault(sym)
alldefaults = defaults(sys)
if haskey(alldefaults, sym)
def = alldefaults[sym]
if def isa Symbolic
def = fixpoint_sub(def, defaults(sys))
def = fixpoint_sub(def, alldefaults)
end
def isa Symbolic && error("Could not resolve default $(ModelingToolkit.getdefault(sym)) for $name")
nt = (; nt..., default=def)
end
if ModelingToolkit.hasguess(sym)
guess = ModelingToolkit.getguess(sym)

# check for guess both in symbol metadata and in guesses of system
# fixes https://github.com/SciML/ModelingToolkit.jl/issues/3075
if ModelingToolkit.hasguess(sym) || haskey(ModelingToolkit.guesses(sys), sym)
guess = if ModelingToolkit.hasguess(sym)
ModelingToolkit.getguess(sym)
else
ModelingToolkit.guesses(sys)[sym]
end
if guess isa Symbolic
guess = fixpoint_sub(def, defaults(sys))
guess = fixpoint_sub(def, merge(defaults(sys), guesses(sys)))
end
guess isa Symbolic && error("Could not resolve guess $(ModelingToolkit.getguess(sym)) for $name")
nt = (; nt..., guess=guess)
end

if ModelingToolkit.hasbounds(sym)
nt = (; nt..., bounds=ModelingToolkit.getbounds(sym))
end
Expand All @@ -122,7 +129,7 @@ function generate_io_function(_sys, inputss::Tuple, outputs;
outputs = getproperty_symbolic.(Ref(_sys), outputs)

sys = if ModelingToolkit.iscomplete(_sys)
_sys
deepcopy(_sys)
else
_openinputs = setdiff(allinputs, Set(full_parameters(_sys)))
structural_simplify(_sys, (_openinputs, outputs); simplify=true)[1]
Expand Down
4 changes: 3 additions & 1 deletion src/NetworkDynamics.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module NetworkDynamics
using Graphs: Graphs, AbstractGraph, SimpleEdge, edges, vertices, ne, nv
using Graphs: Graphs, AbstractGraph, SimpleEdge, edges, vertices, ne, nv,
SimpleGraph, SimpleDiGraph, add_edge!, has_edge
using TimerOutputs: @timeit_debug, reset_timer!

using ArgCheck: @argcheck
Expand Down Expand Up @@ -43,6 +44,7 @@ export has_default, get_default, set_default!
export has_guess, get_guess, set_guess!
export has_init, get_init, set_init!
export has_bounds, get_bounds, set_bounds!
export has_graphelement, get_graphelement, set_graphelement!
include("component_functions.jl")

export Network
Expand Down
55 changes: 54 additions & 1 deletion src/component_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,17 @@ function _fill_defaults(T, kwargs)
_maybewrap!(dict, :obssym, Symbol)

symmetadata = get!(dict, :symmetadata, Dict{Symbol,Dict{Symbol,Any}}())
metadata = get!(dict, :metadata, Dict{Symbol,Any}())

metadata = try
convert(Dict{Symbol,Any}, get!(dict, :metadata, Dict{Symbol,Any}()))
catch e
throw(ArgumentError("Provided metadata keyword musst be a Dict{Symbol,Any}. Got $(repr(dict[:metadata]))."))
end

if haskey(dict, :graphelement)
ge = pop!(dict, :graphelement)
metadata[:graphelement] = ge
end

# sym & dim
haskey(dict, :dim) || haskey(dict, :sym) || throw(ArgumentError("Either `dim` or `sym` must be provided to construct $T."))
Expand Down Expand Up @@ -693,6 +703,20 @@ Returns if a `default` value if available, otherwise returns `init` value for sy
"""
get_default_or_init(c::ComponentFunction, sym::Symbol) = has_default(c, sym) ? get_default(c, sym) : get_init(c, sym)

#### default or guess
"""
has_default_or_guess(c::ComponentFunction, sym::Symbol)
Checks if a `default` or `guess` value is present for symbol `sym`.
"""
has_default_or_guess(c::ComponentFunction, sym::Symbol) = has_default(c, sym) || has_guess(c, sym)
"""
get_default_or_guess(c::ComponentFunction, sym::Symbol)
Returns if a `default` value if available, otherwise returns `guess` value for symbol `sym`.
"""
get_default_or_guess(c::ComponentFunction, sym::Symbol) = has_default(c, sym) ? get_default(c, sym) : get_guess(c, sym)


# TODO: legacy, only used within show methods
function def(c::ComponentFunction)::Vector{Union{Nothing,Float64}}
Expand Down Expand Up @@ -729,3 +753,32 @@ get_metadata(c::ComponentFunction, key::Symbol) = metadata(c)[key]
Sets the metadata `key` for the component to `value`.
"""
set_metadata!(c::ComponentFunction, key::Symbol, val) = setindex!(metadata(c), val, key)

#### graphelement field for edges and vertices
"""
has_graphelement(c)
Checks if the edge or vetex function function has the `graphelement` metadata.
"""
has_graphelement(c::EdgeFunction) = has_metadata(c, :graphelement)
has_graphelement(c::VertexFunction) = has_metadata(c, :graphelement)
"""
get_graphelement(c::EdgeFunction)::@NamedTuple{src::T, dst::T}
get_graphelement(c::VertexFunction)::Int
Retrieves the `graphelement` metadata for the component function. For edges this
returns a named tupe `(;src, dst)` where both are either integers (vertex index)
or symbols (vertex name).
"""
get_graphelement(c::EdgeFunction) = get_metadata(c, :graphelement)::@NamedTuple{src::T, dst::T} where {T<:Union{Int,Symbol}}
get_graphelement(c::VertexFunction) = get_metadata(c, :graphelement)::Int
"""
set_graphelement!(c::EdgeFunction, src, dst)
set_graphelement!(c::VertexFunction, vidx)
Sets the `graphelement` metadata for the edge function. For edges this takes two
arguments `src` and `dst` which are either integer (vertex index) or symbol
(vertex name). For vertices it takes a single integer `vidx`.
"""
set_graphelement!(c::EdgeFunction, nt::@NamedTuple{src::T, dst::T}) where {T<:Union{Int,Symbol}} = set_metadata!(c, :graphelement, nt)
set_graphelement!(c::VertexFunction, vidx::Int) = set_metadata!(c, :graphelement, vidx)
90 changes: 90 additions & 0 deletions src/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,28 @@ function Network(g::AbstractGraph,
@argcheck length(_vertexf) == nv(g)
@argcheck length(_edgef) == ne(g)

# check if graphelement is set correctly, warn otherwise
for (i, v) in pairs(_vertexf)
if has_graphelement(v)
if get_graphelement(v) != i
@warn "Vertex function $v has wrong `:graphelement` $(get_graphelement(v)) != $i. \
Using this constructor the provided `:graphelement` is ignored!"
end
end
end
if any(has_graphelement, _edgef)
vnamedict = _unique_name_dict(_vertexf)
for (iteredge, ef) in zip(edges(g), _edgef)
if has_graphelement(ef)
ge = get_graphelement(ef)
if iteredge != _resolve_ge_to_edge(ge, vnamedict)
@warn "Edge function $ef has wrong `:graphelement` $(get_graphelement(ef)) != $iteredge. \
Using this constructor the provided `:graphelement` is ignored!"
end
end
end
end

verbose &&
println("Create dynamic network with $(nv(g)) vertices and $(ne(g)) edges:")
@argcheck execution isa ExecutionStyle "Execution type $execution not supported (choose from $(subtypes(ExecutionStyle)))"
Expand Down Expand Up @@ -96,6 +118,74 @@ function Network(g::AbstractGraph,
return nw
end

function Network(vertexfs, edgefs; kwargs...)
@argcheck all(has_graphelement, vertexfs) "All vertex functions must have assigned `graphelement` to implicitly construct graph!"
@argcheck all(has_graphelement, edgefs) "All edge functions must have assigned `graphelement` to implicitly construct graph!"

vidxs = get_graphelement.(vertexfs)
allunique(vidxs) || throw(ArgumentError("All vertex functions must have unique `graphelement`!"))
sort(vidxs) == 1:length(vidxs) || throw(ArgumentError("Vertex functions must have `graphelement` in range 1:length(vertexfs)!"))

vdict = Dict(vidxs .=> vertexfs)

vnamedict = _unique_name_dict(vertexfs)

simpleedges = map(edgefs) do e
ge = get_graphelement(e)
_resolve_ge_to_edge(ge, vnamedict)
end
allunique(simpleedges) || throw(ArgumentError("Some edge functions have the same `graphelement`!"))
edict = Dict(simpleedges .=> edgefs)

# if all src < dst then we can use SimpleGraph, else digraph
g = if all(e -> e.src < e.dst, simpleedges)
SimpleGraph(length(vertexfs))
else
SimpleDiGraph(length(vertexfs))
end
for edge in simpleedges
if g isa SimpleDiGraph && has_edge(g, edge.dst, edge.src)
@warn "Edges $(edge.src) -> $(edge.dst) and $(edge.dst) -> $(edge.src) are both present in the graph!"
end
r = add_edge!(g, edge)
r || error("Could not add edge $(edge) to graph $(g)!")
end

vfs_ordered = [vdict[k] for k in vertices(g)]
efs_ordered = [edict[k] for k in edges(g)]

Network(g, vfs_ordered, efs_ordered; kwargs...)
end

function _unique_name_dict(cfs::AbstractVector{<:ComponentFunction})
# find all names to resolve
names = getproperty.(cfs, :name)
dict = Dict(names .=> get_graphelement.(cfs))
# delete all names which occure multiple times
for i in eachindex(names)
if names[i] @views names[i+1:end]
delete!(dict, names[i])
end
end
dict
end
# resolve the graphelement ge (named tuple) to simple edge with potential lookup in vertex name dict dict
function _resolve_ge_to_edge(ge, vnamedict)
src = if ge.src isa Symbol
haskey(vnamedict, ge.src) || throw(ArgumentError("Edge function has unknown or non-unique source vertex name $(ge.src)"))
vnamedict[ge.src]
else
ge.src
end
dst = if ge.dst isa Symbol
haskey(vnamedict, ge.dst) || throw(ArgumentError("Edge function has unknown or non-unique source vertex name $(ge.dst)"))
vnamedict[ge.dst]
else
ge.dst
end
SimpleEdge(src, dst)
end

function VertexBatch(im::IndexManager, idxs::Vector{Int}; verbose)
components = @view im.vertexf[idxs]

Expand Down
19 changes: 14 additions & 5 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,20 @@ function initialization_problem(cf::T; t=NaN, verbose=true) where {T<:Union{ODEV
ufix = Float64[has_default(cf, s) ? get_default(cf, s) : NaN for s in sym(cf)]
pfix = Float64[has_default(cf, s) ? get_default(cf, s) : NaN for s in psym(cf)]

hasinputsym(cf) || throw(ArgumentError("Vertex musst have `inputsym` with default values!"))
input = if T <: EdgeFunction
(;src=Float64[get_default(cf, s) for s in inputsym(cf).src], dst=Float64[get_default(cf, s) for s in inputsym(cf).dst])
else
Float64[get_default(cf, s) for s in inputsym(cf)]
hasinputsym(cf) || throw(ArgumentError("Component function musst have `inputsym` with default values!"))

input= try
if T <: EdgeFunction
(;src=Float64[get_default(cf, s) for s in inputsym(cf).src], dst=Float64[get_default(cf, s) for s in inputsym(cf).dst])
else
Float64[get_default(cf, s) for s in inputsym(cf)]
end
catch e
if e isa KeyError
throw(ArgumentError("Component function musst have `inputsym` with default values!"))
else
rethrow(e)
end
end

freesym = vcat(sym(cf)[ufree_m], psym(cf)[pfree_m])
Expand Down
4 changes: 2 additions & 2 deletions test/AD_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ end
# jacobian(fp, AutoReverseDiff(), pflat(p0))
# jacobian(fp, AutoFiniteDiff(), pflat(p0))

scenarios = [JacobianScenario(fx; x=x0, y=fx(x0), nb_args=1, place=:inplace, jac=jacobian(fx, AutoFiniteDiff(), x0)),
JacobianScenario(fp; x=pflat(p0), y=fp(pflat(p0)), nb_args=1, place=:inplace, jac=jacobian(fp, AutoFiniteDiff(), pflat(p0)))]
scenarios = [Scenario{:jacobian, :in}(fx, x0; res1=jacobian(fx, AutoFiniteDiff(), x0)) ,
Scenario{:jacobian, :in}(fp, pflat(p0); res1=jacobian(fp, AutoFiniteDiff(), pflat(p0)))]
backends = [AutoForwardDiff(), AutoReverseDiff()]
test_differentiation(
backends, # the backends you want to compare
Expand Down
4 changes: 4 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[sources]
NetworkDynamics = {path = ".."}

[compat]
DifferentiationInterface = "0.6"
DifferentiationInterfaceTest = "0.7"
33 changes: 33 additions & 0 deletions test/construction_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,39 @@ using Graphs
@test_throws ArgumentError nd(_du, _u, rand(pdim(nd)+1), 0.0)
end

@testset "graphless constructor" begin
@test_throws ArgumentError ODEVertex(x->x^1, 2, 0; metadata="foba")
v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v1)
@test has_graphelement(v1) && get_graphelement(v1) == 1
v2 = ODEVertex(x->x^2, 2, 0; name=:v2)
set_graphelement!(v2, 3)
v3 = ODEVertex(x->x^3, 2, 0; name=:v3)
set_graphelement!(v3, 2)

e1 = StaticEdge(nothing, 0, Symmetric(); graphelement=(;src=1,dst=2))
@test get_graphelement(e1) == (;src=1,dst=2)
e2 = StaticEdge(nothing, 0, Symmetric())
set_graphelement!(e2, (;src=:v3,dst=:v2))
e3 = StaticEdge(nothing, 0, Symmetric())

@test_throws ArgumentError Network([v1,v2,v3], [e1,e2,e3])
set_graphelement!(e3, (;src=3,dst=1))

nw = Network([v1,v2,v3], [e1,e2,e3])
@test nw.im.vertexf == [v1, v3, v2]
g = SimpleDiGraph(3)
add_edge!(g, 1, 2)
add_edge!(g, 2, 3)
add_edge!(g, 3, 1)
@test nw.im.g == g

set_graphelement!(e3, (;src=1,dst=2))
@test_throws ArgumentError Network([v1,v2,v3], [e1,e2,e3])

set_graphelement!(e3, (;src=2,dst=1))
Network([v1,v2,v3], [e1,e2,e3]) # throws waring about 1->2 and 2->1 beeing present
end

@testset "Vertex batch" begin
using NetworkDynamics: BatchStride, VertexBatch, parameter_range
vb = VertexBatch{ODEVertex, typeof(sum), Vector{Int}}([1, 2, 3, 4], # vertices
Expand Down
Loading

0 comments on commit 0e41b2f

Please sign in to comment.