diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 19983a9c..3dcbb905 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -1,6 +1,6 @@ using Pkg -using BenchmarkTools +using Chairmarks using Graphs using NetworkDynamics using Serialization diff --git a/src/NetworkDynamics.jl b/src/NetworkDynamics.jl index 38e4c3e0..d5e2df16 100644 --- a/src/NetworkDynamics.jl +++ b/src/NetworkDynamics.jl @@ -1,7 +1,7 @@ module NetworkDynamics using Graphs: Graphs, AbstractGraph, SimpleEdge, edges, vertices, ne, nv, SimpleGraph, SimpleDiGraph, add_edge!, has_edge -using TimerOutputs: @timeit_debug, reset_timer! +using TimerOutputs: TimerOutputs, @timeit_debug, @timeit using ArgCheck: @argcheck using PreallocationTools: PreallocationTools, DiffCache, get_tmp @@ -39,12 +39,6 @@ export ODEVertex, StaticVertex, StaticEdge, ODEEdge export Symmetric, AntiSymmetric, Directed, Fiducial export dim, sym, pdim, psym, obssym, depth, hasinputsym, inputsym, coupling export metadata, symmetadata -export has_metadata, get_metadata, set_metadata! -export has_default, get_default, set_default! -export has_guess, get_guess, set_guess! -export has_init, get_init, set_init! -export has_bounds, get_bounds, set_bounds! -export has_graphelement, get_graphelement, set_graphelement! include("component_functions.jl") export Network @@ -64,6 +58,14 @@ export vidxs, eidxs, vpidxs, epidxs export save_parameters! include("symbolicindexing.jl") +export has_metadata, get_metadata, set_metadata! +export has_default, get_default, set_default! +export has_guess, get_guess, set_guess! +export has_init, get_init, set_init! +export has_bounds, get_bounds, set_bounds! +export has_graphelement, get_graphelement, set_graphelement! +include("metadata.jl") + using NonlinearSolve: AbstractNonlinearSolveAlgorithm, NonlinearFunction using NonlinearSolve: NonlinearLeastSquaresProblem, NonlinearProblem using SteadyStateDiffEq: SteadyStateProblem, SteadyStateDiffEqAlgorithm, SSRootfind diff --git a/src/aggregators.jl b/src/aggregators.jl index bd31e809..51911374 100644 --- a/src/aggregators.jl +++ b/src/aggregators.jl @@ -315,6 +315,14 @@ function aggregate!(a::SparseAggregator, aggbuf, data) nothing end +# functions to retrieve the constructor of an aggregator for remake of network +get_aggr_constructor(a::NaiveAggregator) = NaiveAggregator(a.f) +get_aggr_constructor(a::KAAggregator) = KAAggregator(a.f) +get_aggr_constructor(a::SequentialAggregator) = SequentialAggregator(a.f) +get_aggr_constructor(a::PolyesterAggregator) = PolyesterAggregator(a.f) +get_aggr_constructor(a::ThreadedAggregator) = ThreadedAggregator(a.f) +get_aggr_constructor(a::SparseAggregator) = SparseAggregator(+) + iscudacompatible(::Type{<:Aggregator}) = false iscudacompatible(::Type{<:KAAggregator}) = true iscudacompatible(::Type{<:SparseAggregator}) = true diff --git a/src/component_functions.jl b/src/component_functions.jl index c9befd5d..b9d3b1fb 100644 --- a/src/component_functions.jl +++ b/src/component_functions.jl @@ -190,6 +190,7 @@ ODEVertex(; kwargs...) = _construct_comp(ODEVertex, kwargs) ODEVertex(f; kwargs...) = ODEVertex(;f, kwargs...) ODEVertex(f, dim; kwargs...) = ODEVertex(;f, _dimsym(dim)..., kwargs...) ODEVertex(f, dim, pdim; kwargs...) = ODEVertex(;f, _dimsym(dim, pdim)..., kwargs...) +ODEVertex(v::ODEVertex; kwargs...) = _reconstruct_comp(ODEVertex, v, kwargs) struct StaticVertex{F,OF} <: VertexFunction @CommonFields @@ -198,6 +199,7 @@ StaticVertex(; kwargs...) = _construct_comp(StaticVertex, kwargs) StaticVertex(f; kwargs...) = StaticVertex(;f, kwargs...) StaticVertex(f, dim; kwargs...) = StaticVertex(;f, _dimsym(dim)..., kwargs...) StaticVertex(f, dim, pdim; kwargs...) = StaticVertex(;f, _dimsym(dim, pdim)..., kwargs...) +StaticVertex(v::StaticVertex; kwargs...) = _reconstruct_comp(StaticVertex, v, kwargs) function ODEVertex(sv::StaticVertex) d = Dict{Symbol,Any}() for prop in propertynames(sv) @@ -224,6 +226,7 @@ StaticEdge(; kwargs...) = _construct_comp(StaticEdge, kwargs) StaticEdge(f; kwargs...) = StaticEdge(;f, kwargs...) StaticEdge(f, dim, coupling; kwargs...) = StaticEdge(;f, _dimsym(dim)..., coupling, kwargs...) StaticEdge(f, dim, pdim, coupling; kwargs...) = StaticEdge(;f, _dimsym(dim, pdim)..., coupling, kwargs...) +StaticEdge(e::StaticEdge; kwargs...) = _reconstruct_comp(StaticEdge, e, kwargs) struct ODEEdge{C,F,OF,MM} <: EdgeFunction{C} @CommonFields @@ -234,6 +237,7 @@ ODEEdge(; kwargs...) = _construct_comp(ODEEdge, kwargs) ODEEdge(f; kwargs...) = ODEEdge(;f, kwargs...) ODEEdge(f, dim, coupling; kwargs...) = ODEEdge(;f, _dimsym(dim)..., coupling, kwargs...) ODEEdge(f, dim, pdim, coupling; kwargs...) = ODEEdge(;f, _dimsym(dim, pdim)..., coupling, kwargs...) +ODEEdge(e::ODEEdge; kwargs...) = _reconstruct_comp(ODEEdge, e, kwargs) statetype(::T) where {T<:ComponentFunction} = statetype(T) statetype(::Type{<:ODEVertex}) = Dynamic() @@ -313,6 +317,18 @@ function _construct_comp(::Type{T}, kwargs) where {T} return c end +function _reconstruct_comp(::Type{T}, cf::ComponentFunction, kwargs) where {T} + fields = fieldnames(T) + dict = Dict{Symbol, Any}() + for f in fields + dict[f] = getproperty(cf, f) + end + for (k, v) in kwargs + dict[k] = v + end + _construct_comp(T, dict) +end + """ _fill_defaults(T, kwargs) @@ -585,229 +601,30 @@ _valid_signature(::Type{<:ODEEdge}, f) = _takes_n_vectors(f, 5) #(du, u, src, ds _takes_n_vectors(f, n) = hasmethod(f, (Tuple(Vector{Float64} for i in 1:n)..., Float64)) - -#### -#### per sym metadata -#### -""" - has_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) - -Checks if symbol metadata `key` is present for symbol `sym`. -""" -function has_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) - md = symmetadata(c) - haskey(md, sym) && haskey(md[sym], key) -end -""" - get_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) - -Retrievs the metadata `key` for symbol `sym`. -""" -get_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) = symmetadata(c)[sym][key] - -""" - set_metadata!(c::ComponentFunction, sym::Symbol, key::Symbol, value) - set_metadata!(c::ComponentFunction, sym::Symbol, pair) - -Sets the metadata `key` for symbol `sym` to `value`. -""" -function set_metadata!(c::ComponentFunction, sym::Symbol, key::Symbol, value) - d = get!(symmetadata(c), sym, Dict{Symbol,Any}()) - d[key] = value -end -set_metadata!(c::ComponentFunction, sym::Symbol, pair::Pair) = set_metadata!(c, sym, pair.first, pair.second) - -#### default -""" - has_default(c::ComponentFunction, sym::Symbol) - -Checks if a `default` value is present for symbol `sym`. -""" -has_default(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :default) -""" - get_default(c::ComponentFunction, sym::Symbol) - -Returns the `default` value for symbol `sym`. -""" -get_default(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :default) -""" - set_default!(c::ComponentFunction, sym::Symbol, value) - -Sets the `default` value for symbol `sym` to `value`. -""" -set_default!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :default, value) - -#### guess -""" - has_guess(c::ComponentFunction, sym::Symbol) - -Checks if a `guess` value is present for symbol `sym`. -""" -has_guess(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :guess) -""" - get_guess(c::ComponentFunction, sym::Symbol) - -Returns the `guess` value for symbol `sym`. -""" -get_guess(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :guess) -""" - set_guess!(c::ComponentFunction, sym::Symbol, value) - -Sets the `guess` value for symbol `sym` to `value`. -""" -set_guess!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :guess, value) - -#### init -""" - has_init(c::ComponentFunction, sym::Symbol) - -Checks if a `init` value is present for symbol `sym`. -""" -has_init(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :init) -""" - get_init(c::ComponentFunction, sym::Symbol) - -Returns the `init` value for symbol `sym`. -""" -get_init(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :init) -""" - set_init!(c::ComponentFunction, sym::Symbol, value) - -Sets the `init` value for symbol `sym` to `value`. -""" -set_init!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :init, value) - -#### bounds -""" - has_bounds(c::ComponentFunction, sym::Symbol) - -Checks if a `bounds` value is present for symbol `sym`. """ -has_bounds(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :bounds) -""" - get_bounds(c::ComponentFunction, sym::Symbol) - -Returns the `bounds` value for symbol `sym`. -""" -get_bounds(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :bounds) -""" - set_bounds!(c::ComponentFunction, sym::Symbol, value) - -Sets the `bounds` value for symbol `sym` to `value`. -""" -set_bounds!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :bounds, value) - + copy(c::NetworkDynamics.ComponentFunction) -#### default or init +Shallow copy of the component function. Creates a deepcopy of `metadata` and `symmetadata` +but references the same objects everywhere else. """ - has_default_or_init(c::ComponentFunction, sym::Symbol) +@generated function Base.copy(c::ComponentFunction) + fields = fieldnames(c) + # fields to copy + cfields = (:metadata, :symmetadata) + # normal fields + nfields = setdiff(fields, cfields) + assign = Expr(:block, + (:($(field) = c.$field) for field in nfields)..., + (:($(field) = deepcopy(c.$field)) for field in cfields)...) + construct = Expr(:call, c, [:($field) for field in fields]...) -Checks if a `default` or `init` value is present for symbol `sym`. -""" -has_default_or_init(c::ComponentFunction, sym::Symbol) = has_default(c, sym) || has_init(c, sym) -""" - get_default_or_init(c::ComponentFunction, sym::Symbol) - -Returns if a `default` value if available, otherwise returns `init` value for symbol `sym`. -""" -get_default_or_init(c::ComponentFunction, sym::Symbol) = has_default(c, sym) ? get_default(c, sym) : get_init(c, sym) - -#### default or guess -""" - has_default_or_guess(c::ComponentFunction, sym::Symbol) - -Checks if a `default` or `guess` value is present for symbol `sym`. -""" -has_default_or_guess(c::ComponentFunction, sym::Symbol) = has_default(c, sym) || has_guess(c, sym) -""" - get_default_or_guess(c::ComponentFunction, sym::Symbol) - -Returns if a `default` value if available, otherwise returns `guess` value for symbol `sym`. -""" -get_default_or_guess(c::ComponentFunction, sym::Symbol) = has_default(c, sym) ? get_default(c, sym) : get_guess(c, sym) - - -# TODO: legacy, only used within show methods -function def(c::ComponentFunction)::Vector{Union{Nothing,Float64}} - map(c.sym) do s - has_default_or_init(c, s) ? get_default_or_init(c, s) : nothing - end -end -function guess(c::ComponentFunction)::Vector{Union{Nothing,Float64}} - map(c.sym) do s - has_guess(c, s) ? get_guess(c, s) : nothing - end -end -function pdef(c::ComponentFunction)::Vector{Union{Nothing,Float64}} - map(c.psym) do s - has_default_or_init(c, s) ? get_default_or_init(c, s) : nothing - end -end -function pguess(c::ComponentFunction)::Vector{Union{Nothing,Float64}} - map(c.psym) do s - has_guess(c, s) ? get_guess(c, s) : nothing + quote + $assign + $construct end end -#### -#### Component metadata -#### -""" - has_metadata(c::ComponentFunction, key::Symbol) - -Checks if metadata `key` is present for the component. -""" -function has_metadata(c::ComponentFunction, key) - haskey(metadata(c), key) -end -""" - get_metadata(c::ComponentFunction, key::Symbol) - -Retrieves the metadata `key` for the component. -""" -get_metadata(c::ComponentFunction, key::Symbol) = metadata(c)[key] -""" - set_metadata!(c::ComponentFunction, key::Symbol, value) - -Sets the metadata `key` for the component to `value`. -""" -set_metadata!(c::ComponentFunction, key::Symbol, val) = setindex!(metadata(c), val, key) - -#### graphelement field for edges and vertices -""" - has_graphelement(c) - -Checks if the edge or vetex function function has the `graphelement` metadata. -""" -has_graphelement(c::ComponentFunction) = has_metadata(c, :graphelement) -""" - get_graphelement(c::EdgeFunction)::@NamedTuple{src::T, dst::T} - get_graphelement(c::VertexFunction)::Int - -Retrieves the `graphelement` metadata for the component function. For edges this -returns a named tupe `(;src, dst)` where both are either integers (vertex index) -or symbols (vertex name). -""" -get_graphelement(c::EdgeFunction) = get_metadata(c, :graphelement)::@NamedTuple{src::T, dst::T} where {T<:Union{Int,Symbol}} -get_graphelement(c::VertexFunction) = get_metadata(c, :graphelement)::Int -""" - set_graphelement!(c::EdgeFunction, src, dst) - set_graphelement!(c::VertexFunction, vidx) - -Sets the `graphelement` metadata for the edge function. For edges this takes two -arguments `src` and `dst` which are either integer (vertex index) or symbol -(vertex name). For vertices it takes a single integer `vidx`. -""" -set_graphelement!(c::EdgeFunction, nt::@NamedTuple{src::T, dst::T}) where {T<:Union{Int,Symbol}} = set_metadata!(c, :graphelement, nt) -set_graphelement!(c::VertexFunction, vidx::Int) = set_metadata!(c, :graphelement, vidx) - - -function get_defaults(c::ComponentFunction, syms) - [has_default(c, sym) ? get_default(c, sym) : nothing for sym in syms] -end -function get_guesses(c::ComponentFunction, syms) - [has_guess(c, sym) ? get_guess(c, sym) : nothing for sym in syms] -end -function get_defaults_or_inits(c::ComponentFunction, syms) - [has_default_or_init(c, sym) ? get_default_or_init(c, sym) : nothing for sym in syms] +Base.hash(cf::ComponentFunction, h::UInt) = hash_fields(cf, h) +function Base.:(==)(cf1::ComponentFunction, cf2::ComponentFunction) + typeof(cf1) == typeof(cf2) && equal_fields(cf1, cf2) end diff --git a/src/construction.jl b/src/construction.jl index ef9ef630..4b36a608 100644 --- a/src/construction.jl +++ b/src/construction.jl @@ -6,39 +6,29 @@ function Network(g::AbstractGraph, vdepth=:auto, aggregator=execution isa SequentialExecution ? SequentialAggregator(+) : PolyesterAggregator(+), check_graphelement=true, + dealias=false, verbose=false) - reset_timer!() + # TimerOutputs.reset_timer!() @timeit_debug "Construct Network" begin # collect all vertex/edgf to vector - _vertexf = vertexf isa Vector ? vertexf : [vertexf for _ in vertices(g)] - _edgef = edgef isa Vector ? edgef : [edgef for _ in edges(g)] + all_same_v = vertexf isa VertexFunction + all_same_e = edgef isa EdgeFunction + maybecopy = dealias ? copy : identity + _vertexf = all_same_v ? [maybecopy(vertexf) for _ in vertices(g)] : vertexf + _edgef = all_same_e ? [maybecopy(edgef) for _ in edges(g)] : edgef + @argcheck _vertexf isa Vector{<:VertexFunction} "Expected VertexFuncions, got $(eltype(_vertexf))" @argcheck _edgef isa Vector{<:EdgeFunction} "Expected EdgeFuncions, got $(eltype(_vertexf))" @argcheck length(_vertexf) == nv(g) @argcheck length(_edgef) == ne(g) - # check if graphelement is set correctly, warn otherwise - if check_graphelement - for (i, v) in pairs(_vertexf) - if has_graphelement(v) - if get_graphelement(v) != i - @warn "Vertex function $v has wrong `:graphelement` $(get_graphelement(v)) != $i. \ - Using this constructor the provided `:graphelement` is ignored!" - end - end - end - if any(has_graphelement, _edgef) - vnamedict = _unique_name_dict(_vertexf) - for (iteredge, ef) in zip(edges(g), _edgef) - if has_graphelement(ef) - ge = get_graphelement(ef) - if iteredge != _resolve_ge_to_edge(ge, vnamedict) - @warn "Edge function $ef has wrong `:graphelement` $(get_graphelement(ef)) != $iteredge. \ - Using this constructor the provided `:graphelement` is ignored!" - end - end - end - end + # check if components alias eachother copy if necessary + # allready dealiase if provided as single functions + if dealias && !all_same_v + dealias!(_vertexf) + end + if dealias && !all_same_e + dealias!(_edgef) end verbose && @@ -64,7 +54,48 @@ function Network(g::AbstractGraph, end # create index manager - im = IndexManager(g, dynstates, edepth, vdepth, _vertexf, _edgef) + @timeit_debug "Construct Index manager" begin + valias = dealias ? :none : (all_same_v ? :all : :some) + ealias = dealias ? :none : (all_same_e ? :all : :some) + im = IndexManager(g, dynstates, edepth, vdepth, _vertexf, _edgef; valias, ealias) + end + + + # check graph_element metadata and attach if necessary + if check_graphelement + @timeit_debug "Check graph element" begin + if !all_same_v + for (i, vf) in pairs(_vertexf) + if has_graphelement(vf) + if get_graphelement(vf) != i + @warn "Vertex function $(vf.name) is placed at node index $i bus has \ + `graphelement` $(get_graphelement(vf)) stored in metadata. \ + The wrong data will be ignored! Use `check_graphelement=false` tu supress this warning." + end + end + end + elseif has_graphelement(vertexf) + @warn "Provided vertex function has assigned `graphelement` metadata. \ + but is used at every vertex. The `graphelement` will be ignored." + end + if !all_same_e + for (iteredge, ef) in zip(im.edgevec, _edgef) + if has_graphelement(ef) + ge = get_graphelement(ef) + src = get(im.unique_vnames, ge.src, ge.src) + dst = get(im.unique_vnames, ge.dst, ge.dst) + if iteredge.src != src || iteredge.dst != dst + @warn "Edge function $(ef.name) at $(iteredge.src) => $(iteredge.dst) has wrong `:graphelement` $src => $dst). \ + The wrong data will be ignored! Use `check_graphelement=false` tu supress this warning." + end + end + end + elseif has_graphelement(edgef) + @warn "Provided edge function has assigned `graphelement` metadata. \ + but is used for all edges. The `graphelement` will be ignored." + end + end + end # batch identical edge and vertex functions @timeit_debug "batch identical vertexes" begin @@ -129,7 +160,7 @@ function Network(g::AbstractGraph, ) end - # print_timer() + # TimerOutputs.print_timer() return nw end @@ -143,11 +174,17 @@ function Network(vertexfs, edgefs; kwargs...) vdict = Dict(vidxs .=> vertexfs) - vnamedict = _unique_name_dict(vertexfs) + # find unique maapings from name => graphelement + vnamedict = unique_mappings(getproperty.(vertexfs, :name), get_graphelement.(vertexfs)) simpleedges = map(edgefs) do e ge = get_graphelement(e) - _resolve_ge_to_edge(ge, vnamedict) + src = get(vnamedict, ge.src, ge.src) + dst = get(vnamedict, ge.dst, ge.dst) + if src isa Symbol || dst isa Symbol + throw(ArgumentError("Edge graphelement $src => $dst continas non-unique or unknown vertex names!")) + end + SimpleEdge(src, dst) end allunique(simpleedges) || throw(ArgumentError("Some edge functions have the same `graphelement`!")) edict = Dict(simpleedges .=> edgefs) @@ -169,36 +206,44 @@ function Network(vertexfs, edgefs; kwargs...) vfs_ordered = [vdict[k] for k in vertices(g)] efs_ordered = [edict[k] for k in edges(g)] - Network(g, vfs_ordered, efs_ordered; kwargs...) + Network(g, vfs_ordered, efs_ordered; check_graphelement=false, kwargs...) end -function _unique_name_dict(cfs::AbstractVector{<:ComponentFunction}) - # find all names to resolve - names = getproperty.(cfs, :name) - dict = Dict(cf.name => get_graphelement(cf) for cf in cfs if has_graphelement(cf)) - # delete all names which occure multiple times - for i in eachindex(names) - if names[i] ∈ @views names[i+1:end] - delete!(dict, names[i]) - end - end - dict +""" + dealias!(cfs::Vector{<:ComponentFunction}) + +Checks if any component functions reference the same metadtata/symmetada fields and +creates copies of them if necessary. +""" +function dealias!(cfs::Vector{<:ComponentFunction}; warn=false) + ag = aliasgroups(cfs) + + isempty(ag) && return cfs # nothign to do + + copyidxs = reduce(vcat, values(ag)) + + cfs[copyidxs] = copy.(cfs[copyidxs]) end -# resolve the graphelement ge (named tuple) to simple edge with potential lookup in vertex name dict dict -function _resolve_ge_to_edge(ge, vnamedict) - src = if ge.src isa Symbol - haskey(vnamedict, ge.src) || throw(ArgumentError("Edge function has unknown or non-unique source vertex name $(ge.src)")) - vnamedict[ge.src] - else - ge.src - end - dst = if ge.dst isa Symbol - haskey(vnamedict, ge.dst) || throw(ArgumentError("Edge function has unknown or non-unique source vertex name $(ge.dst)")) - vnamedict[ge.dst] - else - ge.dst + + +""" + aliasgroups(cfs::Vector{<:ComponentFunction}) + +Returns a dict `cf => idxs` which contains all the component functions +which appear multiple times with all their indices. +""" +function aliasgroups(cfs::Vector{T}) where {T<:ComponentFunction} + d = IdDict{T, Vector{Int}}() + + for (i, cf) in pairs(cfs) + if haskey(d, cf) # c allready present + push!(d[cf], i) + else + d[cf] = [i] + end end - SimpleEdge(src, dst) + + filter!(x -> length(x.second) > 1, d) end function VertexBatch(im::IndexManager, idxs::Vector{Int}; verbose) @@ -340,11 +385,31 @@ function _check_massmatrix(c) end """ - Network(nw::Network; kwargs...) + Network(nw::Network; g, vertexf, edgef, kwargs...) Rebuild the Network with same graph and vertex/edge functions but possibly different kwargs. -# FIXME : needs to take all Network kw arguments into acount! """ -function Network(nw::Network; kwargs...) - Network(nw.im.g, nw.im.vertexf, nw.im.edgef; kwargs...) +function Network(nw::Network; + g = nw.im.g, + vertexf = copy.(nw.im.vertexf), + edgef = copy.(nw.im.edgef), + kwargs...) + + _kwargs = Dict(:execution => executionstyle(nw), + :edepth => :auto, + :vdepth => :auto, + :aggregator => get_aggr_constructor(nw.layer.aggregator), + :check_graphelement => true, + :dealias => false, + :verbose => false) + for (k, v) in kwargs + _kwargs[k] = v + end + + # check, that we actually provide all of the arguments + # mainly so we don't forget to add it here if we introduce new kw arg to main constructor + m = only(methods(Network, [typeof(g), typeof(vertexf), typeof(edgef)])) + @assert keys(_kwargs) == Set(Base.kwarg_decl(m)) + + Network(g, vertexf, edgef; _kwargs...) end diff --git a/src/metadata.jl b/src/metadata.jl new file mode 100644 index 00000000..b23d2c30 --- /dev/null +++ b/src/metadata.jl @@ -0,0 +1,223 @@ +#### +#### per sym metadata +#### +""" + has_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) + +Checks if symbol metadata `key` is present for symbol `sym`. +""" +function has_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) + md = symmetadata(c) + haskey(md, sym) && haskey(md[sym], key) +end +""" + get_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) + +Retrievs the metadata `key` for symbol `sym`. +""" +get_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) = symmetadata(c)[sym][key] + +""" + set_metadata!(c::ComponentFunction, sym::Symbol, key::Symbol, value) + set_metadata!(c::ComponentFunction, sym::Symbol, pair) + +Sets the metadata `key` for symbol `sym` to `value`. +""" +function set_metadata!(c::ComponentFunction, sym::Symbol, key::Symbol, value) + d = get!(symmetadata(c), sym, Dict{Symbol,Any}()) + d[key] = value +end +set_metadata!(c::ComponentFunction, sym::Symbol, pair::Pair) = set_metadata!(c, sym, pair.first, pair.second) + +# generate default methods for some per-symbol metadata fields +for md in [:default, :guess, :init, :bounds] + fname_has = Symbol(:has_, md) + fname_get = Symbol(:get_, md) + fname_set = Symbol(:set_, md, :!) + @eval begin + """ + has_$($(QuoteNode(md)))(c::ComponentFunction, sym::Symbol) + + Checks if a `$($(QuoteNode(md)))` value is present for symbol `sym`. + + See also [`get_$($(QuoteNode(md)))`](@ref), [`set_$($(QuoteNode(md)))`](@ref). + """ + $fname_has(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, $(QuoteNode(md))) + + """ + get_$($(QuoteNode(md)))(c::ComponentFunction, sym::Symbol) + + Returns the `$($(QuoteNode(md)))` value for symbol `sym`. + + See also [`has_$($(QuoteNode(md)))`](@ref), [`set_$($(QuoteNode(md)))`](@ref). + """ + $fname_get(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, $(QuoteNode(md))) + + + """ + set_$($(QuoteNode(md)))(c::ComponentFunction, sym::Symbol, value) + + Sets the `$($(QuoteNode(md)))` value for symbol `sym` to `value`. + + See also [`has_$($(QuoteNode(md)))`](@ref), [`get_$($(QuoteNode(md)))`](@ref). + """ + $fname_set(c::ComponentFunction, sym::Symbol, val) = set_metadata!(c, sym, $(QuoteNode(md)), val) + end +end +set_graphelement!(c::EdgeFunction, p::Pair) = set_graphelement!(c, (;src=p.first, dst=p.second)) + + +#### default or init +""" + has_default_or_init(c::ComponentFunction, sym::Symbol) + +Checks if a `default` or `init` value is present for symbol `sym`. +""" +has_default_or_init(c::ComponentFunction, sym::Symbol) = has_default(c, sym) || has_init(c, sym) +""" + get_default_or_init(c::ComponentFunction, sym::Symbol) + +Returns if a `default` value if available, otherwise returns `init` value for symbol `sym`. +""" +get_default_or_init(c::ComponentFunction, sym::Symbol) = has_default(c, sym) ? get_default(c, sym) : get_init(c, sym) + +#### default or guess +""" + has_default_or_guess(c::ComponentFunction, sym::Symbol) + +Checks if a `default` or `guess` value is present for symbol `sym`. +""" +has_default_or_guess(c::ComponentFunction, sym::Symbol) = has_default(c, sym) || has_guess(c, sym) +""" + get_default_or_guess(c::ComponentFunction, sym::Symbol) + +Returns if a `default` value if available, otherwise returns `guess` value for symbol `sym`. +""" +get_default_or_guess(c::ComponentFunction, sym::Symbol) = has_default(c, sym) ? get_default(c, sym) : get_guess(c, sym) + + +# TODO: legacy, only used within show methods +function def(c::ComponentFunction)::Vector{Union{Nothing,Float64}} + map(c.sym) do s + has_default_or_init(c, s) ? get_default_or_init(c, s) : nothing + end +end +function guess(c::ComponentFunction)::Vector{Union{Nothing,Float64}} + map(c.sym) do s + has_guess(c, s) ? get_guess(c, s) : nothing + end +end +function pdef(c::ComponentFunction)::Vector{Union{Nothing,Float64}} + map(c.psym) do s + has_default_or_init(c, s) ? get_default_or_init(c, s) : nothing + end +end +function pguess(c::ComponentFunction)::Vector{Union{Nothing,Float64}} + map(c.psym) do s + has_guess(c, s) ? get_guess(c, s) : nothing + end +end + +#### +#### Component metadata +#### +""" + has_metadata(c::ComponentFunction, key::Symbol) + +Checks if metadata `key` is present for the component. +""" +function has_metadata(c::ComponentFunction, key) + haskey(metadata(c), key) +end +""" + get_metadata(c::ComponentFunction, key::Symbol) + +Retrieves the metadata `key` for the component. +""" +get_metadata(c::ComponentFunction, key::Symbol) = metadata(c)[key] +""" + set_metadata!(c::ComponentFunction, key::Symbol, value) + +Sets the metadata `key` for the component to `value`. +""" +set_metadata!(c::ComponentFunction, key::Symbol, val) = setindex!(metadata(c), val, key) + +#### graphelement field for edges and vertices +""" + has_graphelement(c) + +Checks if the edge or vetex function function has the `graphelement` metadata. +""" +has_graphelement(c::ComponentFunction) = has_metadata(c, :graphelement) +""" + get_graphelement(c::EdgeFunction)::@NamedTuple{src::T, dst::T} + get_graphelement(c::VertexFunction)::Int + +Retrieves the `graphelement` metadata for the component function. For edges this +returns a named tupe `(;src, dst)` where both are either integers (vertex index) +or symbols (vertex name). +""" +get_graphelement(c::EdgeFunction) = get_metadata(c, :graphelement)::@NamedTuple{src::T, dst::T} where {T<:Union{Int,Symbol}} +get_graphelement(c::VertexFunction) = get_metadata(c, :graphelement)::Int +""" + set_graphelement!(c::EdgeFunction, src, dst) + set_graphelement!(c::VertexFunction, vidx) + +Sets the `graphelement` metadata for the edge function. For edges this takes two +arguments `src` and `dst` which are either integer (vertex index) or symbol +(vertex name). For vertices it takes a single integer `vidx`. +""" +set_graphelement!(c::EdgeFunction, nt::@NamedTuple{src::T, dst::T}) where {T<:Union{Int,Symbol}} = set_metadata!(c, :graphelement, nt) +set_graphelement!(c::VertexFunction, vidx::Int) = set_metadata!(c, :graphelement, vidx) + + +function get_defaults(c::ComponentFunction, syms) + [has_default(c, sym) ? get_default(c, sym) : nothing for sym in syms] +end +function get_guesses(c::ComponentFunction, syms) + [has_guess(c, sym) ? get_guess(c, sym) : nothing for sym in syms] +end +function get_defaults_or_inits(c::ComponentFunction, syms) + [has_default_or_init(c, sym) ? get_default_or_init(c, sym) : nothing for sym in syms] +end + + +#### +#### Metadata Accessors through Network +#### +function aliased_changed(nw::Network; warn=true) + vchanged = _has_changed_hash(nw.im.aliased_vertexfs) + echanged = _has_changed_hash(nw.im.aliased_edgefs) + changed = vchanged || echanged + if changed && warn + s = if vchanged && echanged + "vertices and edges" + elseif vchanged + "vertices" + else + "edges" + end + @warn """ + The metadata of at least one of your aliased $s changed! Proceed with caution! + + Some edgef/vertexf provided to to the `Network` constructor alias eachother. + Which means, the Network object references the same component function in + multiple places. Thus, metadata changes (such as changing of default values or + component initialization) will be reflected in multiple components. To prevent + this use the `dealias=true` keyword or manualy `copy` edge/vertex functions + before creating the network. + """ + end + changed +end +function _has_changed_hash(aliased_cfs) + isempty(aliased_cfs) && return false + changed = false + for (k, v) in aliased_cfs + if hash(k) != v.hash + changed = true + break + end + end + changed +end diff --git a/src/network_structure.jl b/src/network_structure.jl index 2112444b..f5e310d1 100644 --- a/src/network_structure.jl +++ b/src/network_structure.jl @@ -22,15 +22,42 @@ mutable struct IndexManager{G} lastidx_gbuf::Int vertexf::Vector{VertexFunction} edgef::Vector{EdgeFunction} - function IndexManager(g, dyn_states, edepth, vdepth, vertexf, edgef) + aliased_vertexfs::IdDict{VertexFunction, @NamedTuple{idxs::Vector{Int}, hash::UInt}} + aliased_edgefs::IdDict{EdgeFunction, @NamedTuple{idxs::Vector{Int}, hash::UInt}} + unique_vnames::Dict{Symbol,Int} + unique_enames::Dict{Symbol,Int} + function IndexManager(g, dyn_states, edepth, vdepth, vertexf, edgef; valias, ealias) + aliased_vertexf_hashes = _aliased_hashes(VertexFunction, vertexf, valias) + aliased_edgef_hashes = _aliased_hashes(EdgeFunction, edgef, ealias) + unique_vnames = unique_mappings(getproperty.(vertexf, :name), 1:nv(g)) + unique_enames = unique_mappings(getproperty.(edgef, :name), 1:ne(g)) new{typeof(g)}(g, collect(edges(g)), (Vector{UnitRange{Int}}(undef, nv(g)) for i in 1:3)..., (Vector{UnitRange{Int}}(undef, ne(g)) for i in 1:5)..., edepth, vdepth, 0, dyn_states, 0, 0, 0, - vertexf, edgef) + vertexf, edgef, + aliased_vertexf_hashes, + aliased_edgef_hashes, + unique_vnames, + unique_enames) end end +function _aliased_hashes(T, cfs, aliastype) + hashdict = IdDict{T, @NamedTuple{idxs::Vector{Int}, hash::UInt}}() + if aliastype == :some + ag = aliasgroups(cfs) + for (c, idxs) in ag + h = hash(c) + hashdict[c] = (; idxs=idxs, hash=h) + end + elseif aliastype == :all + c = first(cfs) + hashdict[c] =(; idxs=collect(eachindex(cfs)), hash=hash(c)) + end + hashdict +end + dim(im::IndexManager) = im.lastidx_dynamic pdim(im::IndexManager) = im.lastidx_p sdim(im::IndexManager) = im.lastidx_static - im.lastidx_dynamic diff --git a/src/symbolicindexing.jl b/src/symbolicindexing.jl index 328e16c6..893a33b3 100644 --- a/src/symbolicindexing.jl +++ b/src/symbolicindexing.jl @@ -91,20 +91,35 @@ SSI Maintainer assured that f.sys is really only used for symbolic indexig so me SciMLBase.__has_sys(nw::Network) = true Base.getproperty(nw::Network, s::Symbol) = s===:sys ? nw : getfield(nw, s) -SII.symbolic_type(::Type{<:SymbolicIndex{Int,<:Union{Symbol,Int}}}) = SII.ScalarSymbolic() +SII.symbolic_type(::Type{<:SymbolicIndex{<:Union{Symbol,Int},<:Union{Symbol,Int}}}) = SII.ScalarSymbolic() SII.symbolic_type(::Type{<:SymbolicIndex}) = SII.ArraySymbolic() SII.hasname(::SymbolicIndex) = false -SII.hasname(::SymbolicIndex{Int,<:Union{Symbol,Int}}) = true -SII.getname(x::SymbolicVertexIndex) = Symbol("v$(x.compidx)₊$(x.subidx)") -SII.getname(x::SymbolicEdgeIndex) = Symbol("e$(x.compidx)₊$(x.subidx)") +SII.hasname(::SymbolicIndex{<:Union{Symbol,Int},<:Union{Symbol,Int}}) = true +function SII.getname(x::SymbolicVertexIndex) + prefix = x.compidx isa Int ? :v : Symbol() + Symbol(prefix, Symbol(x.compidx), :₊, Symbol(x.subidx)) +end +function SII.getname(x::SymbolicEdgeIndex) + prefix = x.compidx isa Int ? :e : Symbol() + Symbol(prefix, Symbol(x.compidx), :₊, Symbol(x.subidx)) +end -getcomp(nw::Network, sni::Union{EIndex{Int},EPIndex{Int}}) = nw.im.edgef[sni.compidx] -getcomp(nw::Network, sni::Union{VIndex{Int},VPIndex{Int}}) = nw.im.vertexf[sni.compidx] -getcomprange(nw::Network, sni::VIndex{Int}) = nw.im.v_data[sni.compidx] -getcomprange(nw::Network, sni::EIndex{Int}) = nw.im.e_data[sni.compidx] -getcompprange(nw::Network, sni::VPIndex{Int}) = nw.im.v_para[sni.compidx] -getcompprange(nw::Network, sni::EPIndex{Int}) = nw.im.e_para[sni.compidx] +resolvecompidx(nw::Network, sni::SymbolicIndex{Int}) = sni.compidx +function resolvecompidx(nw::Network, sni::SymbolicIndex{Symbol}) + dict = sni isa SymbolicVertexIndex ? nw.im.unique_vnames : nw.im.unique_enames + if haskey(dict, sni.compidx) + return dict[sni.compidx] + else + throw(ArgumentError("Could not resolve component index for $sni, the name might not be unique?")) + end +end +getcomp(nw::Network, sni::SymbolicEdgeIndex) = nw.im.edgef[resolvecompidx(nw, sni)] +getcomp(nw::Network, sni::SymbolicVertexIndex) = nw.im.vertexf[resolvecompidx(nw, sni)] +getcomprange(nw::Network, sni::VIndex{<:Union{Symbol,Int}}) = nw.im.v_data[resolvecompidx(nw, sni)] +getcomprange(nw::Network, sni::EIndex{<:Union{Symbol,Int}}) = nw.im.e_data[resolvecompidx(nw, sni)] +getcompprange(nw::Network, sni::VPIndex{<:Union{Symbol,Int}}) = nw.im.v_para[resolvecompidx(nw, sni)] +getcompprange(nw::Network, sni::EPIndex{<:Union{Symbol,Int}}) = nw.im.e_para[resolvecompidx(nw, sni)] subsym_has_idx(sym::Symbol, syms) = sym ∈ syms subsym_has_idx(idx::Int, syms) = 1 ≤ idx ≤ length(syms) @@ -114,7 +129,8 @@ subsym_to_idx(idx::Int, _) = idx #### #### Iterator/Broadcast interface for ArraySymbolic types #### -Base.broadcastable(si::SymbolicIndex{<:Union{Int,Colon},<:Union{Int,Symbol,Colon}}) = Ref(si) +# TODO: not broadcasting over idx with colon is weird +Base.broadcastable(si::SymbolicIndex{<:Union{Int,Symbol,Colon},<:Union{Int,Symbol,Colon}}) = Ref(si) const _IterableComponent = SymbolicIndex{<:Union{AbstractVector,Tuple},<:Union{Int,Symbol}} Base.length(si::_IterableComponent) = length(si.compidx) @@ -137,7 +153,7 @@ function Base.iterate(si::_IterableComponent, state=nothing) _similar(si, it[1], si.subidx), it[2] end -const _IterableSubcomponent = SymbolicIndex{Int,<:Union{AbstractVector,Tuple}} +const _IterableSubcomponent = SymbolicIndex{<:Union{Symbol,Int},<:Union{AbstractVector,Tuple}} Base.length(si::_IterableSubcomponent) = length(si.subidx) Base.size(si::_IterableSubcomponent) = (length(si),) Base.IteratorSize(si::_IterableSubcomponent) = Base.HasShape{1}() @@ -170,10 +186,10 @@ _resolve_colon(nw::Network, sni::VIndex{Colon}) = VIndex(1:nv(nw), sni.subidx) _resolve_colon(nw::Network, sni::EIndex{Colon}) = EIndex(1:ne(nw), sni.subidx) _resolve_colon(nw::Network, sni::VPIndex{Colon}) = VPIndex(1:nv(nw), sni.subidx) _resolve_colon(nw::Network, sni::EPIndex{Colon}) = EPIndex(1:ne(nw), sni.subidx) -_resolve_colon(nw::Network, sni::VIndex{Int,Colon}) = VIndex{Int, UnitRange{Int}}(sni.compidx, 1:dim(getcomp(nw,sni))) -_resolve_colon(nw::Network, sni::EIndex{Int,Colon}) = EIndex{Int, UnitRange{Int}}(sni.compidx, 1:dim(getcomp(nw,sni))) -_resolve_colon(nw::Network, sni::VPIndex{Int,Colon}) = VPIndex{Int, UnitRange{Int}}(sni.compidx, 1:pdim(getcomp(nw,sni))) -_resolve_colon(nw::Network, sni::EPIndex{Int,Colon}) = EPIndex{Int, UnitRange{Int}}(sni.compidx, 1:pdim(getcomp(nw,sni))) +_resolve_colon(nw::Network, sni::VIndex{<:Union{Symbol,Int},Colon}) = VIndex{Int, UnitRange{Int}}(sni.compidx, 1:dim(getcomp(nw,sni))) +_resolve_colon(nw::Network, sni::EIndex{<:Union{Symbol,Int},Colon}) = EIndex{Int, UnitRange{Int}}(sni.compidx, 1:dim(getcomp(nw,sni))) +_resolve_colon(nw::Network, sni::VPIndex{<:Union{Symbol,Int},Colon}) = VPIndex{Int, UnitRange{Int}}(sni.compidx, 1:pdim(getcomp(nw,sni))) +_resolve_colon(nw::Network, sni::EPIndex{<:Union{Symbol,Int},Colon}) = EPIndex{Int, UnitRange{Int}}(sni.compidx, 1:pdim(getcomp(nw,sni))) #### Implmentation of index provider interface @@ -195,13 +211,13 @@ function SII.is_variable(nw::Network, sni) if _hascolon(sni) SII.is_variable(nw, _resolve_colon(nw,sni)) elseif SII.symbolic_type(sni) === SII.ArraySymbolic() - all(s -> SII.is_variable(nw, s), sni) + all(Base.Fix1(SII.is_variable, nw), sni) else _is_variable(nw, sni) end end _is_variable(nw::Network, sni) = false -function _is_variable(nw::Network, sni::SymbolicStateIndex{Int,<:Union{Int,Symbol}}) +function _is_variable(nw::Network, sni::SymbolicStateIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}}) cf = getcomp(nw, sni) return isdynamic(cf) && subsym_has_idx(sni.subidx, sym(cf)) end @@ -215,7 +231,7 @@ function SII.variable_index(nw::Network, sni) _variable_index(nw, sni) end end -function _variable_index(nw::Network, sni::SymbolicStateIndex{Int,<:Union{Int,Symbol}}) +function _variable_index(nw::Network, sni::SymbolicStateIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}}) cf = getcomp(nw, sni) range = getcomprange(nw, sni) range[subsym_to_idx(sni.subidx, sym(cf))] @@ -242,14 +258,14 @@ function SII.is_parameter(nw::Network, sni) if _hascolon(sni) SII.is_parameter(nw, _resolve_colon(nw,sni)) elseif SII.symbolic_type(sni) === SII.ArraySymbolic() - all(s -> SII.is_parameter(nw, s), sni) + all(Base.Fix1(SII.is_parameter, nw), sni) else _is_parameter(nw, sni) end end _is_parameter(nw::Network, sni) = false function _is_parameter(nw::Network, - sni::SymbolicParameterIndex{Int,<:Union{Int,Symbol}}) + sni::SymbolicParameterIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}}) cf = getcomp(nw, sni) return subsym_has_idx(sni.subidx, psym(cf)) end @@ -264,7 +280,7 @@ function SII.parameter_index(nw::Network, sni) end end function _parameter_index(nw::Network, - sni::SymbolicParameterIndex{Int,<:Union{Int,Symbol}}) + sni::SymbolicParameterIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}}) cf = getcomp(nw, sni) range = getcompprange(nw, sni) range[subsym_to_idx(sni.subidx, psym(cf))] @@ -272,7 +288,6 @@ end function SII.parameter_symbols(nw::Network) syms = Vector{SymbolicParameterIndex{Int,Symbol}}(undef, pdim(nw)) - i = 1 for (ci, cf) in pairs(nw.im.vertexf) syms[nw.im.v_para[ci]] .= VPIndex.(ci, psym(cf)) end @@ -362,7 +377,7 @@ function SII.is_observed(nw::Network, sni) end end _is_observed(nw::Network, _) = false -function _is_observed(nw::Network, sni::SymbolicStateIndex{Int,<:Union{Int,Symbol}}) +function _is_observed(nw::Network, sni::SymbolicStateIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}}) cf = getcomp(nw, sni) if isdynamic(cf) @@ -443,7 +458,7 @@ function SII.observed(nw::Network, snis) flatidxs[i] = _range[subsym_to_idx(sni.subidx, sym(cf))] elseif subsym_has_idx(sni.subidx, obssym(cf)) #found in observed _idx = subsym_to_idx(sni.subidx, obssym(cf)) - _obsf = _get_observed_f(nw, cf, sni.compidx) + _obsf = _get_observed_f(nw, cf, resolvecompidx(nw, sni)) obsfuns[i] = (u, aggbuf, p, t) -> _obsf(u, aggbuf, p, t)[_idx] else throw(ArgumentError("Cannot resolve observable $sni")) @@ -522,6 +537,7 @@ end #### Default values #### function SII.default_values(nw::Network) + aliased_changed(nw; warn=true) defs = Dict{SymbolicIndex{Int,Symbol},Float64}() for (ci, cf) in pairs(nw.im.vertexf) for s in psym(cf) diff --git a/src/utils.jl b/src/utils.jl index f2414128..7f38181e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -26,3 +26,58 @@ end @inline unrolled_foreach(f, t) = unrolled_foreach(f, nofilt, t) nofilt(_) = true + +""" + unique_mappings([f=identity], from, to) + +Given two vectors `from` and `to`, find all keys in `from` which exist only once. +For those unique keys, return a dict maping `from_unique => f(to)` +""" +unique_mappings(from, to) = unique_mappings(identity, from, to) +function unique_mappings(f, from, to) + counts = Dict{eltype(from),Int}() + for k in from + counts[k] = get(counts, k, 0) + 1 + end + unique = Dict{eltype(from),eltype(to)}() + for (k, v) in zip(from, to) + if get(counts, k, 0) == 1 + unique[k] = f(v) + end + end + unique +end + + +""" + hash_fields(obj::T, h) + +This is @generated helper function which unrolls all fields of a struct `obj` and +recursively hashes them. +""" +@generated function hash_fields(obj::T, h::UInt) where {T} + fields = fieldnames(obj) + subhashes = Expr(:block, (:(h = hash(obj.$field, h)) for field in fields)...) + + quote + h = hash(T, h) + $subhashes + h + end +end + +""" + equal_fields(a::T, b::T) where {T} + +Thise @generated helper function unrolls all fields of two structs `a` and `b` and +compares them. +""" +@generated function equal_fields(a::T, b::T) where {T} + fields = fieldnames(T) + subequals = Expr(:block, (:(a.$field == b.$field || return false) for field in fields)...) + + quote + $subequals + return true + end +end diff --git a/test/ComponentLibrary.jl b/test/ComponentLibrary.jl index 664cc514..1912e5a2 100644 --- a/test/ComponentLibrary.jl +++ b/test/ComponentLibrary.jl @@ -56,11 +56,11 @@ end Base.@propagate_inbounds function kuramoto_edge!(e, θ_s, θ_d, (K,), t) e .= K .* sin(θ_s[1] - θ_d[1]) end -function kuramoto_edge() - StaticEdge(f=kuramoto_edge!, +function kuramoto_edge(; name=:kuramoto_edge) + StaticEdge(;f=kuramoto_edge!, dim=1, sym=[:P], pdim=1, psym=[:K], - coupling=AntiSymmetric()) + coupling=AntiSymmetric(), name) end Base.@propagate_inbounds function kuramoto_inertia!(dv, v, acc, p, t) @@ -68,10 +68,10 @@ Base.@propagate_inbounds function kuramoto_inertia!(dv, v, acc, p, t) dv[1] = v[2] dv[2] = 1 / M * (Pm - D * v[2] + acc[1]) end -function kuramoto_second() - ODEVertex(f=kuramoto_inertia!, +function kuramoto_second(; name=:kuramoto_second) + ODEVertex(; f=kuramoto_inertia!, dim=2, sym=[:δ, :ω], def=[0, 0], - pdim=3, psym=[:M, :D, :Pm], pdef=[1, 0.1, 1]) + pdim=3, psym=[:M, :D, :Pm], pdef=[1, 0.1, 1], name) end Base.@propagate_inbounds function kuramoto_vertex!(dθ, θ, esum, (ω,), t) diff --git a/test/aggregators_test.jl b/test/aggregators_test.jl index 6780db3d..001fdee8 100644 --- a/test/aggregators_test.jl +++ b/test/aggregators_test.jl @@ -29,8 +29,8 @@ using StableRNGs rng = StableRNG(1) g = watts_strogatz(10_000, 4, 0.8; rng, is_directed=true) - nvec = rand(rng, vtypes, nv(g)) - evec = rand(rng, etypes, ne(g)) + nvec = copy.(rand(rng, vtypes, nv(g))) + evec = copy.(rand(rng, etypes, ne(g))) basenw = Network(g, nvec, evec); states = rand(rng, basenw.im.lastidx_static) diff --git a/test/construction_test.jl b/test/construction_test.jl index 881c670f..98421ad7 100644 --- a/test/construction_test.jl +++ b/test/construction_test.jl @@ -74,14 +74,17 @@ end Network([v1,v2,v3], [e1,e2,e3]) # throws waring about 1->2 and 2->1 beeing present v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v1) - v2 = ODEVertex(x->x^2, 2, 0; name=:v2) - v3 = ODEVertex(x->x^3, 2, 0; name=:v3) - @test NetworkDynamics._unique_name_dict([v1,v2,v3]) == Dict(:v1=>1) + v2 = ODEVertex(x->x^2, 2, 0; name=:v2, vidx=2) + v3 = ODEVertex(x->x^3, 2, 0; name=:v3, vidx=3) + nw = Network([v1,v2,v3], [e1,e2,e3]) + @test nw.im.unique_vnames == Dict(:v1=>1, :v2=>2, :v3=>3) - v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v1) - v2 = ODEVertex(x->x^2, 2, 0; name=:v1) - v3 = ODEVertex(x->x^3, 2, 0; name=:v3) - @test NetworkDynamics._unique_name_dict([v1,v2,v3]) == Dict() + v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v2) + v2 = ODEVertex(x->x^2, 2, 0; name=:v2, vidx=2) + v3 = ODEVertex(x->x^3, 2, 0; name=:v3, vidx=3) + set_graphelement!(e2, 3=>2) + nw = Network([v1,v2,v3], [e1,e2,e3]) + @test nw.im.unique_vnames == Dict(:v3=>3) end @testset "Vertex batch" begin using NetworkDynamics: BatchStride, VertexBatch, parameter_range @@ -266,3 +269,82 @@ end kwargs = Dict(:sym=>[:a=>2,:b],:def=>[1,nothing], :pdim=>0 ) @test_throws ArgumentError _fill_defaults(ODEVertex, kwargs) end + +@testset "test dealias and copy of components" begin + using NetworkDynamics: aliasgroups + v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v1) + v2 = ODEVertex(x->x^2, 2, 0; name=:v2, vidx=2) + v3 = ODEVertex(x->x^3, 2, 0; name=:v3, vidx=3) + + e1 = StaticEdge(nothing, 0, Symmetric(); graphelement=(;src=1,dst=2)) + e2 = StaticEdge(nothing, 0, Symmetric(); src=:v2, dst=:v3) + e3 = StaticEdge(nothing, 0, Symmetric(); src=:v3, dst=:v1) + + @test isempty(aliasgroups([v1,v2,v3])) + @test aliasgroups([v1, v1, v3]) == IdDict(v1 => [1,2]) + @test aliasgroups([v3, v1, v3]) == IdDict(v3 => [1,3]) + + g = complete_graph(3) + # with dealiasing / copying + nw = Network(g, [v1,v1,v3],[e1,e2,e3]; dealias=true, check_graphelement=false) + @test nw.im.vertexf[1].f == nw.im.vertexf[2].f + @test nw.im.vertexf[1].metadata == nw.im.vertexf[2].metadata + @test nw.im.vertexf[1].metadata !== nw.im.vertexf[2].metadata + @test nw.im.vertexf[1].symmetadata == nw.im.vertexf[2].symmetadata + @test nw.im.vertexf[1].symmetadata !== nw.im.vertexf[2].symmetadata + s0 = NWState(nw) + @test isnan(s0.v[1,1]) + set_default!(nw.im.vertexf[1], :v₁, 3) + s1 = NWState(nw) + @test s1.v[1,1] == 3 + + # witout dealisasing + nw = Network(g, [v1,v1,v3],[e1,e2,e1]; check_graphelement=false) + @test nw.im.vertexf[1] === nw.im.vertexf[2] + @test nw.im.edgef[1] === nw.im.edgef[3] + @test keys(nw.im.aliased_vertexfs) == Set([v1]) + @test keys(nw.im.aliased_edgefs) == Set([e1]) + @test only(unique(values(nw.im.aliased_vertexfs))) == (;idxs=[1,2], hash=hash(v1)) + @test only(unique(values(nw.im.aliased_edgefs))) == (;idxs=[1,3], hash=hash(e1)) + s0 = NWState(nw) + @test isnan(s0.v[1,1]) + prehash = hash(v1) + set_default!(v1, :v₁, 3) + posthash = hash(v1) + @test prehash !== posthash + @test NetworkDynamics.aliased_changed(nw; warn=false) + s1 = NWState(nw) + + # test copy + v = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), symmetadata=Dict(:x=>Dict(:default=>1))) + v2 = copy(v) + @test v !== v2 + @test v == v2 + @test isequal(v, v2) + @test get_default(v, :x) == get_default(v2, :x) + set_default!(v, :x, 99) + @test get_default(v, :x) == 99 + @test get_default(v2, :x) == 1 + @test v != v2 + @test v1 != v2 +end + +@testset "test network-remake constructor" begin + v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v1) + v2 = ODEVertex(x->x^2, 2, 0; name=:v2, vidx=2) + v3 = ODEVertex(x->x^3, 2, 0; name=:v3, vidx=3) + + e1 = StaticEdge(nothing, 0, Symmetric(); graphelement=(;src=1,dst=2)) + e2 = StaticEdge(nothing, 0, Symmetric(); src=:v1, dst=:v3) + e3 = StaticEdge(nothing, 0, Symmetric(); src=:v2, dst=:v3) + + g = complete_graph(3) + nw = Network(g, [v1,v2,v3],[e1,e2,e3]) + nw2 = Network(nw) + nw2 = Network(nw; g=path_graph(3), edgef=[e1, e3]) + + for aggT in subtypes(NetworkDynamics.Aggregator) + @show aggT + @test hasmethod(NetworkDynamics.get_aggr_constructor, (aggT,)) + end +end diff --git a/test/symbolicindexing_test.jl b/test/symbolicindexing_test.jl index 2ff75d7a..f1adb6d5 100644 --- a/test/symbolicindexing_test.jl +++ b/test/symbolicindexing_test.jl @@ -426,3 +426,30 @@ nw = Network(g, [n1, n2, n3], [e1, e2]) @test SII.get_all_timeseries_indexes(nw, VIndex(1,:u)) == Set([SII.ContinuousTimeseries()]) @test SII.get_all_timeseries_indexes(nw, VPIndex(1,:p1)) == Set([1]) @test SII.get_all_timeseries_indexes(nw, [VIndex(1,:u), VPIndex(1,:p1)]) == Set([SII.ContinuousTimeseries(), 1]) + +# test named vertices and edges +@testset "test sym indices for named edges/vertices" begin + v1 = Lib.kuramoto_second(name=:v1) + v2 = Lib.kuramoto_second(name=:v2) + v3 = Lib.kuramoto_second(name=:v3) + e1 = Lib.kuramoto_edge(name=:e1) + e2 = Lib.kuramoto_edge(name=:e2) + e3 = Lib.kuramoto_edge(name=:e3) + g = complete_graph(3) + nw = Network(g, [v1, v2, v3], [e1, e2, e3]) + s = NWState(nw, collect(1:dim(nw)), collect(dim(nw)+1:dim(nw)+pdim(nw))) + @test_throws ArgumentError s.v[:v, 1] + @test s.v[:v1, 1] == s[VIndex(1,1)] + @test s.v[:v2, 1] == s[VIndex(2,1)] + @test s.v[:v3, 1] == s[VIndex(3,1)] + @test s.e[:e1, 1] == s[EIndex(1,1)] + @test s.e[:e2, 1] == s[EIndex(2,1)] + @test s.e[:e3, 1] == s[EIndex(3,1)] + @test_throws ArgumentError s.p.v[:v, 1] + @test s.p.v[:v1, 1] == s[VPIndex(1,1)] + @test s.p.v[:v2, 1] == s[VPIndex(2,1)] + @test s.p.v[:v3, 1] == s[VPIndex(3,1)] + @test s.p.e[:e1, 1] == s[EPIndex(1,1)] + @test s.p.e[:e2, 1] == s[EPIndex(2,1)] + @test s.p.e[:e3, 1] == s[EPIndex(3,1)] +end diff --git a/test/utils_test.jl b/test/utils_test.jl index eaeb41dc..46c53312 100644 --- a/test/utils_test.jl +++ b/test/utils_test.jl @@ -80,4 +80,23 @@ using NetworkDynamics align_strings(["row &with\n&line break", "second &row"]) end + + @testset "unique_mappings" begin + using NetworkDynamics: unique_mappings + a = [1,2,3] + b = [1,2,3] + @test unique_mappings(a, b) == Dict(1=>1, 2=>2, 3=>3) + a = [:a, :b, :c] + b = [1,2,3] + @test unique_mappings(a, b) == Dict(:a=>1, :b=>2, :c=>3) + a = [:a, :b, :c] + b = [1,1,3] + @test unique_mappings(a, b) == Dict(:a=>1, :b=>1, :c=>3) + a = [:a, :a, :c] + b = [1,1,3] + @test unique_mappings(a, b) == Dict(:c=>3) + a = [:a, :a, :c] + b = [1,1,-3] + @test unique_mappings(abs, a, b) == Dict(:c=>3) + end end