From bdc8ece4084e70f1b864c26c0ae2eafeb1a35091 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Fri, 18 Oct 2024 09:24:25 +0200 Subject: [PATCH 01/17] move metadata related functions to new file --- src/NetworkDynamics.jl | 14 ++- src/component_functions.jl | 227 ------------------------------------- src/metadata.jl | 225 ++++++++++++++++++++++++++++++++++++ 3 files changed, 233 insertions(+), 233 deletions(-) create mode 100644 src/metadata.jl diff --git a/src/NetworkDynamics.jl b/src/NetworkDynamics.jl index 38e4c3e0..bc52cf56 100644 --- a/src/NetworkDynamics.jl +++ b/src/NetworkDynamics.jl @@ -39,12 +39,6 @@ export ODEVertex, StaticVertex, StaticEdge, ODEEdge export Symmetric, AntiSymmetric, Directed, Fiducial export dim, sym, pdim, psym, obssym, depth, hasinputsym, inputsym, coupling export metadata, symmetadata -export has_metadata, get_metadata, set_metadata! -export has_default, get_default, set_default! -export has_guess, get_guess, set_guess! -export has_init, get_init, set_init! -export has_bounds, get_bounds, set_bounds! -export has_graphelement, get_graphelement, set_graphelement! include("component_functions.jl") export Network @@ -64,6 +58,14 @@ export vidxs, eidxs, vpidxs, epidxs export save_parameters! include("symbolicindexing.jl") +export has_metadata, get_metadata, set_metadata! +export has_default, get_default, set_default! +export has_guess, get_guess, set_guess! +export has_init, get_init, set_init! +export has_bounds, get_bounds, set_bounds! +export has_graphelement, get_graphelement, set_graphelement! +include("metadata.jl") + using NonlinearSolve: AbstractNonlinearSolveAlgorithm, NonlinearFunction using NonlinearSolve: NonlinearLeastSquaresProblem, NonlinearProblem using SteadyStateDiffEq: SteadyStateProblem, SteadyStateDiffEqAlgorithm, SSRootfind diff --git a/src/component_functions.jl b/src/component_functions.jl index c9befd5d..4dafb8a0 100644 --- a/src/component_functions.jl +++ b/src/component_functions.jl @@ -584,230 +584,3 @@ _valid_signature(::Type{<:StaticEdge}, f) = _takes_n_vectors(f, 4) #(u, src, dst _valid_signature(::Type{<:ODEEdge}, f) = _takes_n_vectors(f, 5) #(du, u, src, dst, p, t) _takes_n_vectors(f, n) = hasmethod(f, (Tuple(Vector{Float64} for i in 1:n)..., Float64)) - - -#### -#### per sym metadata -#### -""" - has_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) - -Checks if symbol metadata `key` is present for symbol `sym`. -""" -function has_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) - md = symmetadata(c) - haskey(md, sym) && haskey(md[sym], key) -end -""" - get_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) - -Retrievs the metadata `key` for symbol `sym`. -""" -get_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) = symmetadata(c)[sym][key] - -""" - set_metadata!(c::ComponentFunction, sym::Symbol, key::Symbol, value) - set_metadata!(c::ComponentFunction, sym::Symbol, pair) - -Sets the metadata `key` for symbol `sym` to `value`. -""" -function set_metadata!(c::ComponentFunction, sym::Symbol, key::Symbol, value) - d = get!(symmetadata(c), sym, Dict{Symbol,Any}()) - d[key] = value -end -set_metadata!(c::ComponentFunction, sym::Symbol, pair::Pair) = set_metadata!(c, sym, pair.first, pair.second) - -#### default -""" - has_default(c::ComponentFunction, sym::Symbol) - -Checks if a `default` value is present for symbol `sym`. -""" -has_default(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :default) -""" - get_default(c::ComponentFunction, sym::Symbol) - -Returns the `default` value for symbol `sym`. -""" -get_default(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :default) -""" - set_default!(c::ComponentFunction, sym::Symbol, value) - -Sets the `default` value for symbol `sym` to `value`. -""" -set_default!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :default, value) - -#### guess -""" - has_guess(c::ComponentFunction, sym::Symbol) - -Checks if a `guess` value is present for symbol `sym`. -""" -has_guess(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :guess) -""" - get_guess(c::ComponentFunction, sym::Symbol) - -Returns the `guess` value for symbol `sym`. -""" -get_guess(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :guess) -""" - set_guess!(c::ComponentFunction, sym::Symbol, value) - -Sets the `guess` value for symbol `sym` to `value`. -""" -set_guess!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :guess, value) - -#### init -""" - has_init(c::ComponentFunction, sym::Symbol) - -Checks if a `init` value is present for symbol `sym`. -""" -has_init(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :init) -""" - get_init(c::ComponentFunction, sym::Symbol) - -Returns the `init` value for symbol `sym`. -""" -get_init(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :init) -""" - set_init!(c::ComponentFunction, sym::Symbol, value) - -Sets the `init` value for symbol `sym` to `value`. -""" -set_init!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :init, value) - -#### bounds -""" - has_bounds(c::ComponentFunction, sym::Symbol) - -Checks if a `bounds` value is present for symbol `sym`. -""" -has_bounds(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :bounds) -""" - get_bounds(c::ComponentFunction, sym::Symbol) - -Returns the `bounds` value for symbol `sym`. -""" -get_bounds(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :bounds) -""" - set_bounds!(c::ComponentFunction, sym::Symbol, value) - -Sets the `bounds` value for symbol `sym` to `value`. -""" -set_bounds!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :bounds, value) - - -#### default or init -""" - has_default_or_init(c::ComponentFunction, sym::Symbol) - -Checks if a `default` or `init` value is present for symbol `sym`. -""" -has_default_or_init(c::ComponentFunction, sym::Symbol) = has_default(c, sym) || has_init(c, sym) -""" - get_default_or_init(c::ComponentFunction, sym::Symbol) - -Returns if a `default` value if available, otherwise returns `init` value for symbol `sym`. -""" -get_default_or_init(c::ComponentFunction, sym::Symbol) = has_default(c, sym) ? get_default(c, sym) : get_init(c, sym) - -#### default or guess -""" - has_default_or_guess(c::ComponentFunction, sym::Symbol) - -Checks if a `default` or `guess` value is present for symbol `sym`. -""" -has_default_or_guess(c::ComponentFunction, sym::Symbol) = has_default(c, sym) || has_guess(c, sym) -""" - get_default_or_guess(c::ComponentFunction, sym::Symbol) - -Returns if a `default` value if available, otherwise returns `guess` value for symbol `sym`. -""" -get_default_or_guess(c::ComponentFunction, sym::Symbol) = has_default(c, sym) ? get_default(c, sym) : get_guess(c, sym) - - -# TODO: legacy, only used within show methods -function def(c::ComponentFunction)::Vector{Union{Nothing,Float64}} - map(c.sym) do s - has_default_or_init(c, s) ? get_default_or_init(c, s) : nothing - end -end -function guess(c::ComponentFunction)::Vector{Union{Nothing,Float64}} - map(c.sym) do s - has_guess(c, s) ? get_guess(c, s) : nothing - end -end -function pdef(c::ComponentFunction)::Vector{Union{Nothing,Float64}} - map(c.psym) do s - has_default_or_init(c, s) ? get_default_or_init(c, s) : nothing - end -end -function pguess(c::ComponentFunction)::Vector{Union{Nothing,Float64}} - map(c.psym) do s - has_guess(c, s) ? get_guess(c, s) : nothing - end -end - -#### -#### Component metadata -#### -""" - has_metadata(c::ComponentFunction, key::Symbol) - -Checks if metadata `key` is present for the component. -""" -function has_metadata(c::ComponentFunction, key) - haskey(metadata(c), key) -end -""" - get_metadata(c::ComponentFunction, key::Symbol) - -Retrieves the metadata `key` for the component. -""" -get_metadata(c::ComponentFunction, key::Symbol) = metadata(c)[key] -""" - set_metadata!(c::ComponentFunction, key::Symbol, value) - -Sets the metadata `key` for the component to `value`. -""" -set_metadata!(c::ComponentFunction, key::Symbol, val) = setindex!(metadata(c), val, key) - -#### graphelement field for edges and vertices -""" - has_graphelement(c) - -Checks if the edge or vetex function function has the `graphelement` metadata. -""" -has_graphelement(c::ComponentFunction) = has_metadata(c, :graphelement) -""" - get_graphelement(c::EdgeFunction)::@NamedTuple{src::T, dst::T} - get_graphelement(c::VertexFunction)::Int - -Retrieves the `graphelement` metadata for the component function. For edges this -returns a named tupe `(;src, dst)` where both are either integers (vertex index) -or symbols (vertex name). -""" -get_graphelement(c::EdgeFunction) = get_metadata(c, :graphelement)::@NamedTuple{src::T, dst::T} where {T<:Union{Int,Symbol}} -get_graphelement(c::VertexFunction) = get_metadata(c, :graphelement)::Int -""" - set_graphelement!(c::EdgeFunction, src, dst) - set_graphelement!(c::VertexFunction, vidx) - -Sets the `graphelement` metadata for the edge function. For edges this takes two -arguments `src` and `dst` which are either integer (vertex index) or symbol -(vertex name). For vertices it takes a single integer `vidx`. -""" -set_graphelement!(c::EdgeFunction, nt::@NamedTuple{src::T, dst::T}) where {T<:Union{Int,Symbol}} = set_metadata!(c, :graphelement, nt) -set_graphelement!(c::VertexFunction, vidx::Int) = set_metadata!(c, :graphelement, vidx) - - -function get_defaults(c::ComponentFunction, syms) - [has_default(c, sym) ? get_default(c, sym) : nothing for sym in syms] -end -function get_guesses(c::ComponentFunction, syms) - [has_guess(c, sym) ? get_guess(c, sym) : nothing for sym in syms] -end -function get_defaults_or_inits(c::ComponentFunction, syms) - [has_default_or_init(c, sym) ? get_default_or_init(c, sym) : nothing for sym in syms] -end diff --git a/src/metadata.jl b/src/metadata.jl new file mode 100644 index 00000000..2d5e3d4e --- /dev/null +++ b/src/metadata.jl @@ -0,0 +1,225 @@ +#### +#### per sym metadata +#### +""" + has_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) + +Checks if symbol metadata `key` is present for symbol `sym`. +""" +function has_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) + md = symmetadata(c) + haskey(md, sym) && haskey(md[sym], key) +end +""" + get_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) + +Retrievs the metadata `key` for symbol `sym`. +""" +get_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) = symmetadata(c)[sym][key] + +""" + set_metadata!(c::ComponentFunction, sym::Symbol, key::Symbol, value) + set_metadata!(c::ComponentFunction, sym::Symbol, pair) + +Sets the metadata `key` for symbol `sym` to `value`. +""" +function set_metadata!(c::ComponentFunction, sym::Symbol, key::Symbol, value) + d = get!(symmetadata(c), sym, Dict{Symbol,Any}()) + d[key] = value +end +set_metadata!(c::ComponentFunction, sym::Symbol, pair::Pair) = set_metadata!(c, sym, pair.first, pair.second) + +#### default +""" + has_default(c::ComponentFunction, sym::Symbol) + +Checks if a `default` value is present for symbol `sym`. +""" +has_default(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :default) +""" + get_default(c::ComponentFunction, sym::Symbol) + +Returns the `default` value for symbol `sym`. +""" +get_default(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :default) +""" + set_default!(c::ComponentFunction, sym::Symbol, value) + +Sets the `default` value for symbol `sym` to `value`. +""" +set_default!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :default, value) + +#### guess +""" + has_guess(c::ComponentFunction, sym::Symbol) + +Checks if a `guess` value is present for symbol `sym`. +""" +has_guess(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :guess) +""" + get_guess(c::ComponentFunction, sym::Symbol) + +Returns the `guess` value for symbol `sym`. +""" +get_guess(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :guess) +""" + set_guess!(c::ComponentFunction, sym::Symbol, value) + +Sets the `guess` value for symbol `sym` to `value`. +""" +set_guess!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :guess, value) + +#### init +""" + has_init(c::ComponentFunction, sym::Symbol) + +Checks if a `init` value is present for symbol `sym`. +""" +has_init(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :init) +""" + get_init(c::ComponentFunction, sym::Symbol) + +Returns the `init` value for symbol `sym`. +""" +get_init(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :init) +""" + set_init!(c::ComponentFunction, sym::Symbol, value) + +Sets the `init` value for symbol `sym` to `value`. +""" +set_init!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :init, value) + +#### bounds +""" + has_bounds(c::ComponentFunction, sym::Symbol) + +Checks if a `bounds` value is present for symbol `sym`. +""" +has_bounds(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :bounds) +""" + get_bounds(c::ComponentFunction, sym::Symbol) + +Returns the `bounds` value for symbol `sym`. +""" +get_bounds(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :bounds) +""" + set_bounds!(c::ComponentFunction, sym::Symbol, value) + +Sets the `bounds` value for symbol `sym` to `value`. +""" +set_bounds!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :bounds, value) + + +#### default or init +""" + has_default_or_init(c::ComponentFunction, sym::Symbol) + +Checks if a `default` or `init` value is present for symbol `sym`. +""" +has_default_or_init(c::ComponentFunction, sym::Symbol) = has_default(c, sym) || has_init(c, sym) +""" + get_default_or_init(c::ComponentFunction, sym::Symbol) + +Returns if a `default` value if available, otherwise returns `init` value for symbol `sym`. +""" +get_default_or_init(c::ComponentFunction, sym::Symbol) = has_default(c, sym) ? get_default(c, sym) : get_init(c, sym) + +#### default or guess +""" + has_default_or_guess(c::ComponentFunction, sym::Symbol) + +Checks if a `default` or `guess` value is present for symbol `sym`. +""" +has_default_or_guess(c::ComponentFunction, sym::Symbol) = has_default(c, sym) || has_guess(c, sym) +""" + get_default_or_guess(c::ComponentFunction, sym::Symbol) + +Returns if a `default` value if available, otherwise returns `guess` value for symbol `sym`. +""" +get_default_or_guess(c::ComponentFunction, sym::Symbol) = has_default(c, sym) ? get_default(c, sym) : get_guess(c, sym) + + +# TODO: legacy, only used within show methods +function def(c::ComponentFunction)::Vector{Union{Nothing,Float64}} + map(c.sym) do s + has_default_or_init(c, s) ? get_default_or_init(c, s) : nothing + end +end +function guess(c::ComponentFunction)::Vector{Union{Nothing,Float64}} + map(c.sym) do s + has_guess(c, s) ? get_guess(c, s) : nothing + end +end +function pdef(c::ComponentFunction)::Vector{Union{Nothing,Float64}} + map(c.psym) do s + has_default_or_init(c, s) ? get_default_or_init(c, s) : nothing + end +end +function pguess(c::ComponentFunction)::Vector{Union{Nothing,Float64}} + map(c.psym) do s + has_guess(c, s) ? get_guess(c, s) : nothing + end +end + +#### +#### Component metadata +#### +""" + has_metadata(c::ComponentFunction, key::Symbol) + +Checks if metadata `key` is present for the component. +""" +function has_metadata(c::ComponentFunction, key) + haskey(metadata(c), key) +end +""" + get_metadata(c::ComponentFunction, key::Symbol) + +Retrieves the metadata `key` for the component. +""" +get_metadata(c::ComponentFunction, key::Symbol) = metadata(c)[key] +""" + set_metadata!(c::ComponentFunction, key::Symbol, value) + +Sets the metadata `key` for the component to `value`. +""" +set_metadata!(c::ComponentFunction, key::Symbol, val) = setindex!(metadata(c), val, key) + +#### graphelement field for edges and vertices +""" + has_graphelement(c) + +Checks if the edge or vetex function function has the `graphelement` metadata. +""" +has_graphelement(c::ComponentFunction) = has_metadata(c, :graphelement) +""" + get_graphelement(c::EdgeFunction)::@NamedTuple{src::T, dst::T} + get_graphelement(c::VertexFunction)::Int + +Retrieves the `graphelement` metadata for the component function. For edges this +returns a named tupe `(;src, dst)` where both are either integers (vertex index) +or symbols (vertex name). +""" +get_graphelement(c::EdgeFunction) = get_metadata(c, :graphelement)::@NamedTuple{src::T, dst::T} where {T<:Union{Int,Symbol}} +get_graphelement(c::VertexFunction) = get_metadata(c, :graphelement)::Int +""" + set_graphelement!(c::EdgeFunction, src, dst) + set_graphelement!(c::VertexFunction, vidx) + +Sets the `graphelement` metadata for the edge function. For edges this takes two +arguments `src` and `dst` which are either integer (vertex index) or symbol +(vertex name). For vertices it takes a single integer `vidx`. +""" +set_graphelement!(c::EdgeFunction, nt::@NamedTuple{src::T, dst::T}) where {T<:Union{Int,Symbol}} = set_metadata!(c, :graphelement, nt) +set_graphelement!(c::VertexFunction, vidx::Int) = set_metadata!(c, :graphelement, vidx) + + +function get_defaults(c::ComponentFunction, syms) + [has_default(c, sym) ? get_default(c, sym) : nothing for sym in syms] +end +function get_guesses(c::ComponentFunction, syms) + [has_guess(c, sym) ? get_guess(c, sym) : nothing for sym in syms] +end +function get_defaults_or_inits(c::ComponentFunction, syms) + [has_default_or_init(c, sym) ? get_default_or_init(c, sym) : nothing for sym in syms] +end From 1560c13231170dba73f1cf0ead48ebf07f1b97a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Fri, 18 Oct 2024 09:39:30 +0200 Subject: [PATCH 02/17] generate default accessors using eval --- src/metadata.jl | 96 ++++++++++++++----------------------------------- 1 file changed, 26 insertions(+), 70 deletions(-) diff --git a/src/metadata.jl b/src/metadata.jl index 2d5e3d4e..9795dbb6 100644 --- a/src/metadata.jl +++ b/src/metadata.jl @@ -29,85 +29,41 @@ function set_metadata!(c::ComponentFunction, sym::Symbol, key::Symbol, value) end set_metadata!(c::ComponentFunction, sym::Symbol, pair::Pair) = set_metadata!(c, sym, pair.first, pair.second) -#### default -""" - has_default(c::ComponentFunction, sym::Symbol) - -Checks if a `default` value is present for symbol `sym`. -""" -has_default(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :default) -""" - get_default(c::ComponentFunction, sym::Symbol) - -Returns the `default` value for symbol `sym`. -""" -get_default(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :default) -""" - set_default!(c::ComponentFunction, sym::Symbol, value) - -Sets the `default` value for symbol `sym` to `value`. -""" -set_default!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :default, value) - -#### guess -""" - has_guess(c::ComponentFunction, sym::Symbol) - -Checks if a `guess` value is present for symbol `sym`. -""" -has_guess(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :guess) -""" - get_guess(c::ComponentFunction, sym::Symbol) - -Returns the `guess` value for symbol `sym`. -""" -get_guess(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :guess) -""" - set_guess!(c::ComponentFunction, sym::Symbol, value) +# generate default methods for some per-symbol metadata fields +for md in [:default, :guess, :init, :bounds] + fname_has = Symbol(:has_, md) + fname_get = Symbol(:get_, md) + fname_set = Symbol(:set_, md, :!) + @eval begin + """ + has_$($(QuoteNode(md)))(c::ComponentFunction, sym::Symbol) -Sets the `guess` value for symbol `sym` to `value`. -""" -set_guess!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :guess, value) + Checks if a `$($(QuoteNode(md)))` value is present for symbol `sym`. -#### init -""" - has_init(c::ComponentFunction, sym::Symbol) + See also [`get_$($(QuoteNode(md)))`](@ref), [`set_$($(QuoteNode(md)))`](@ref). + """ + $fname_has(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, $(QuoteNode(md))) -Checks if a `init` value is present for symbol `sym`. -""" -has_init(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :init) -""" - get_init(c::ComponentFunction, sym::Symbol) + """ + get_$($(QuoteNode(md)))(c::ComponentFunction, sym::Symbol) -Returns the `init` value for symbol `sym`. -""" -get_init(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :init) -""" - set_init!(c::ComponentFunction, sym::Symbol, value) + Returns the `$($(QuoteNode(md)))` value for symbol `sym`. -Sets the `init` value for symbol `sym` to `value`. -""" -set_init!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :init, value) + See also [`has_$($(QuoteNode(md)))`](@ref), [`set_$($(QuoteNode(md)))`](@ref). + """ + $fname_get(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, $(QuoteNode(md))) -#### bounds -""" - has_bounds(c::ComponentFunction, sym::Symbol) -Checks if a `bounds` value is present for symbol `sym`. -""" -has_bounds(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :bounds) -""" - get_bounds(c::ComponentFunction, sym::Symbol) + """ + set_$($(QuoteNode(md)))(c::ComponentFunction, sym::Symbol, value) -Returns the `bounds` value for symbol `sym`. -""" -get_bounds(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :bounds) -""" - set_bounds!(c::ComponentFunction, sym::Symbol, value) + Sets the `$($(QuoteNode(md)))` value for symbol `sym` to `value`. -Sets the `bounds` value for symbol `sym` to `value`. -""" -set_bounds!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :bounds, value) + See also [`has_$($(QuoteNode(md)))`](@ref), [`get_$($(QuoteNode(md)))`](@ref). + """ + $fname_set(c::ComponentFunction, sym::Symbol, val) = set_metadata!(c, sym, $(QuoteNode(md)), val) + end +end #### default or init From 37fbf62f97899ff80edcad7eb05abc720758557f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Mon, 21 Oct 2024 18:31:52 +0200 Subject: [PATCH 03/17] allow symbolic indices based on name of the component --- src/construction.jl | 69 ++++++++++++++++------------------- src/network_structure.jl | 6 ++- src/symbolicindexing.jl | 65 ++++++++++++++++++++------------- src/utils.jl | 21 +++++++++++ test/ComponentLibrary.jl | 12 +++--- test/symbolicindexing_test.jl | 28 ++++++++++++++ test/utils_test.jl | 19 ++++++++++ 7 files changed, 151 insertions(+), 69 deletions(-) diff --git a/src/construction.jl b/src/construction.jl index ef9ef630..a1e6ab81 100644 --- a/src/construction.jl +++ b/src/construction.jl @@ -6,6 +6,7 @@ function Network(g::AbstractGraph, vdepth=:auto, aggregator=execution isa SequentialExecution ? SequentialAggregator(+) : PolyesterAggregator(+), check_graphelement=true, + set_graphelement=false, verbose=false) reset_timer!() @timeit_debug "Construct Network" begin @@ -17,29 +18,7 @@ function Network(g::AbstractGraph, @argcheck length(_vertexf) == nv(g) @argcheck length(_edgef) == ne(g) - # check if graphelement is set correctly, warn otherwise - if check_graphelement - for (i, v) in pairs(_vertexf) - if has_graphelement(v) - if get_graphelement(v) != i - @warn "Vertex function $v has wrong `:graphelement` $(get_graphelement(v)) != $i. \ - Using this constructor the provided `:graphelement` is ignored!" - end - end - end - if any(has_graphelement, _edgef) - vnamedict = _unique_name_dict(_vertexf) - for (iteredge, ef) in zip(edges(g), _edgef) - if has_graphelement(ef) - ge = get_graphelement(ef) - if iteredge != _resolve_ge_to_edge(ge, vnamedict) - @warn "Edge function $ef has wrong `:graphelement` $(get_graphelement(ef)) != $iteredge. \ - Using this constructor the provided `:graphelement` is ignored!" - end - end - end - end - end + # check if components alias metadata and copy if necessary verbose && println("Create dynamic network with $(nv(g)) vertices and $(ne(g)) edges:") @@ -66,6 +45,33 @@ function Network(g::AbstractGraph, # create index manager im = IndexManager(g, dynstates, edepth, vdepth, _vertexf, _edgef) + + # check graph_element metadata and attach if necessary + for (i, vf) in pairs(_vertexf) + if check_graphelement && has_graphelement(vf) + if get_graphelement(vf) != i + @warn "Vertex function $(vf.name) is placed at node index $i bus has \ + `graphelement` $(get_graphelement(vf)) stored in metadata. \ + The wrong data will be " * (set_graphelement ? "overwritten!" : "ignored!") * + " Use `check_graphelement` and `set_graphelement` keywords to alter this behavior." + end + end + set_graphelement && set_graphelement!(vf, i) + end + for (iteredge, ef) in zip(im.edgevec, _edgef) + if check_graphelement && has_graphelement(ef) + ge = get_graphelement(ef) + src = get(im.unique_vnames, ge.src, ge.src) + dst = get(im.unique_vnames, ge.dst, ge.dst) + if iteredge.src != src || iteredge.dst != dst + @warn "Edge function $(ef.name) at $(iteredge.src) => $(iteredge.dst) has wrong `:graphelement` $src => $dst). \ + The wrong data will be " * (set_graphelement ? "overwritten!" : "ignored!") * + " Use `check_graphelement` and `set_graphelement` keywords to alter this behavior." + end + end + set_graphelement && set_graphelement!(ef, (;src=iteredge.src, dst=iteredge.dst)) + end + # batch identical edge and vertex functions @timeit_debug "batch identical vertexes" begin vidxs = _find_identical(vertexf, 1:nv(g)) @@ -143,7 +149,8 @@ function Network(vertexfs, edgefs; kwargs...) vdict = Dict(vidxs .=> vertexfs) - vnamedict = _unique_name_dict(vertexfs) + # find unique maapings from name => graphelement + vnamedict = unique_mappings(getproperty.(vertexfs, :name), get_graphelement.(vertexfs)) simpleedges = map(edgefs) do e ge = get_graphelement(e) @@ -169,21 +176,9 @@ function Network(vertexfs, edgefs; kwargs...) vfs_ordered = [vdict[k] for k in vertices(g)] efs_ordered = [edict[k] for k in edges(g)] - Network(g, vfs_ordered, efs_ordered; kwargs...) + Network(g, vfs_ordered, efs_ordered; check_graphelement=false, set_graphelement=false, kwargs...) end -function _unique_name_dict(cfs::AbstractVector{<:ComponentFunction}) - # find all names to resolve - names = getproperty.(cfs, :name) - dict = Dict(cf.name => get_graphelement(cf) for cf in cfs if has_graphelement(cf)) - # delete all names which occure multiple times - for i in eachindex(names) - if names[i] ∈ @views names[i+1:end] - delete!(dict, names[i]) - end - end - dict -end # resolve the graphelement ge (named tuple) to simple edge with potential lookup in vertex name dict dict function _resolve_ge_to_edge(ge, vnamedict) src = if ge.src isa Symbol diff --git a/src/network_structure.jl b/src/network_structure.jl index 2112444b..fe242b10 100644 --- a/src/network_structure.jl +++ b/src/network_structure.jl @@ -22,13 +22,17 @@ mutable struct IndexManager{G} lastidx_gbuf::Int vertexf::Vector{VertexFunction} edgef::Vector{EdgeFunction} + unique_vnames::Dict{Symbol,Int} + unique_enames::Dict{Symbol,Int} function IndexManager(g, dyn_states, edepth, vdepth, vertexf, edgef) new{typeof(g)}(g, collect(edges(g)), (Vector{UnitRange{Int}}(undef, nv(g)) for i in 1:3)..., (Vector{UnitRange{Int}}(undef, ne(g)) for i in 1:5)..., edepth, vdepth, 0, dyn_states, 0, 0, 0, - vertexf, edgef) + vertexf, edgef, + unique_mappings(getproperty.(vertexf, :name), 1:nv(g)), + unique_mappings(getproperty.(edgef, :name), 1:ne(g))) end end dim(im::IndexManager) = im.lastidx_dynamic diff --git a/src/symbolicindexing.jl b/src/symbolicindexing.jl index 328e16c6..94b28b36 100644 --- a/src/symbolicindexing.jl +++ b/src/symbolicindexing.jl @@ -91,20 +91,35 @@ SSI Maintainer assured that f.sys is really only used for symbolic indexig so me SciMLBase.__has_sys(nw::Network) = true Base.getproperty(nw::Network, s::Symbol) = s===:sys ? nw : getfield(nw, s) -SII.symbolic_type(::Type{<:SymbolicIndex{Int,<:Union{Symbol,Int}}}) = SII.ScalarSymbolic() +SII.symbolic_type(::Type{<:SymbolicIndex{<:Union{Symbol,Int},<:Union{Symbol,Int}}}) = SII.ScalarSymbolic() SII.symbolic_type(::Type{<:SymbolicIndex}) = SII.ArraySymbolic() SII.hasname(::SymbolicIndex) = false -SII.hasname(::SymbolicIndex{Int,<:Union{Symbol,Int}}) = true -SII.getname(x::SymbolicVertexIndex) = Symbol("v$(x.compidx)₊$(x.subidx)") -SII.getname(x::SymbolicEdgeIndex) = Symbol("e$(x.compidx)₊$(x.subidx)") +SII.hasname(::SymbolicIndex{<:Union{Symbol,Int},<:Union{Symbol,Int}}) = true +function SII.getname(x::SymbolicVertexIndex) + prefix = x.compidx isa Int ? :v : Symbol() + Symbol(prefix, Symbol(x.compidx), :₊, Symbol(x.subidx)) +end +function SII.getname(x::SymbolicEdgeIndex) + prefix = x.compidx isa Int ? :e : Symbol() + Symbol(prefix, Symbol(x.compidx), :₊, Symbol(x.subidx)) +end -getcomp(nw::Network, sni::Union{EIndex{Int},EPIndex{Int}}) = nw.im.edgef[sni.compidx] -getcomp(nw::Network, sni::Union{VIndex{Int},VPIndex{Int}}) = nw.im.vertexf[sni.compidx] -getcomprange(nw::Network, sni::VIndex{Int}) = nw.im.v_data[sni.compidx] -getcomprange(nw::Network, sni::EIndex{Int}) = nw.im.e_data[sni.compidx] -getcompprange(nw::Network, sni::VPIndex{Int}) = nw.im.v_para[sni.compidx] -getcompprange(nw::Network, sni::EPIndex{Int}) = nw.im.e_para[sni.compidx] +resolvecompidx(nw::Network, sni::SymbolicIndex{Int}) = sni.compidx +function resolvecompidx(nw::Network, sni::SymbolicIndex{Symbol}) + dict = sni isa SymbolicVertexIndex ? nw.im.unique_vnames : nw.im.unique_enames + if haskey(dict, sni.compidx) + return dict[sni.compidx] + else + throw(ArgumentError("Could not resolve component index for $sni, the name might not be unique?")) + end +end +getcomp(nw::Network, sni::SymbolicEdgeIndex) = nw.im.edgef[resolvecompidx(nw, sni)] +getcomp(nw::Network, sni::SymbolicVertexIndex) = nw.im.vertexf[resolvecompidx(nw, sni)] +getcomprange(nw::Network, sni::VIndex{<:Union{Symbol,Int}}) = nw.im.v_data[resolvecompidx(nw, sni)] +getcomprange(nw::Network, sni::EIndex{<:Union{Symbol,Int}}) = nw.im.e_data[resolvecompidx(nw, sni)] +getcompprange(nw::Network, sni::VPIndex{<:Union{Symbol,Int}}) = nw.im.v_para[resolvecompidx(nw, sni)] +getcompprange(nw::Network, sni::EPIndex{<:Union{Symbol,Int}}) = nw.im.e_para[resolvecompidx(nw, sni)] subsym_has_idx(sym::Symbol, syms) = sym ∈ syms subsym_has_idx(idx::Int, syms) = 1 ≤ idx ≤ length(syms) @@ -114,7 +129,8 @@ subsym_to_idx(idx::Int, _) = idx #### #### Iterator/Broadcast interface for ArraySymbolic types #### -Base.broadcastable(si::SymbolicIndex{<:Union{Int,Colon},<:Union{Int,Symbol,Colon}}) = Ref(si) +# TODO: not broadcasting over idx with colon is weird +Base.broadcastable(si::SymbolicIndex{<:Union{Int,Symbol,Colon},<:Union{Int,Symbol,Colon}}) = Ref(si) const _IterableComponent = SymbolicIndex{<:Union{AbstractVector,Tuple},<:Union{Int,Symbol}} Base.length(si::_IterableComponent) = length(si.compidx) @@ -137,7 +153,7 @@ function Base.iterate(si::_IterableComponent, state=nothing) _similar(si, it[1], si.subidx), it[2] end -const _IterableSubcomponent = SymbolicIndex{Int,<:Union{AbstractVector,Tuple}} +const _IterableSubcomponent = SymbolicIndex{<:Union{Symbol,Int},<:Union{AbstractVector,Tuple}} Base.length(si::_IterableSubcomponent) = length(si.subidx) Base.size(si::_IterableSubcomponent) = (length(si),) Base.IteratorSize(si::_IterableSubcomponent) = Base.HasShape{1}() @@ -170,10 +186,10 @@ _resolve_colon(nw::Network, sni::VIndex{Colon}) = VIndex(1:nv(nw), sni.subidx) _resolve_colon(nw::Network, sni::EIndex{Colon}) = EIndex(1:ne(nw), sni.subidx) _resolve_colon(nw::Network, sni::VPIndex{Colon}) = VPIndex(1:nv(nw), sni.subidx) _resolve_colon(nw::Network, sni::EPIndex{Colon}) = EPIndex(1:ne(nw), sni.subidx) -_resolve_colon(nw::Network, sni::VIndex{Int,Colon}) = VIndex{Int, UnitRange{Int}}(sni.compidx, 1:dim(getcomp(nw,sni))) -_resolve_colon(nw::Network, sni::EIndex{Int,Colon}) = EIndex{Int, UnitRange{Int}}(sni.compidx, 1:dim(getcomp(nw,sni))) -_resolve_colon(nw::Network, sni::VPIndex{Int,Colon}) = VPIndex{Int, UnitRange{Int}}(sni.compidx, 1:pdim(getcomp(nw,sni))) -_resolve_colon(nw::Network, sni::EPIndex{Int,Colon}) = EPIndex{Int, UnitRange{Int}}(sni.compidx, 1:pdim(getcomp(nw,sni))) +_resolve_colon(nw::Network, sni::VIndex{<:Union{Symbol,Int},Colon}) = VIndex{Int, UnitRange{Int}}(sni.compidx, 1:dim(getcomp(nw,sni))) +_resolve_colon(nw::Network, sni::EIndex{<:Union{Symbol,Int},Colon}) = EIndex{Int, UnitRange{Int}}(sni.compidx, 1:dim(getcomp(nw,sni))) +_resolve_colon(nw::Network, sni::VPIndex{<:Union{Symbol,Int},Colon}) = VPIndex{Int, UnitRange{Int}}(sni.compidx, 1:pdim(getcomp(nw,sni))) +_resolve_colon(nw::Network, sni::EPIndex{<:Union{Symbol,Int},Colon}) = EPIndex{Int, UnitRange{Int}}(sni.compidx, 1:pdim(getcomp(nw,sni))) #### Implmentation of index provider interface @@ -195,13 +211,13 @@ function SII.is_variable(nw::Network, sni) if _hascolon(sni) SII.is_variable(nw, _resolve_colon(nw,sni)) elseif SII.symbolic_type(sni) === SII.ArraySymbolic() - all(s -> SII.is_variable(nw, s), sni) + all(Base.Fix1(SII.is_variable, nw), sni) else _is_variable(nw, sni) end end _is_variable(nw::Network, sni) = false -function _is_variable(nw::Network, sni::SymbolicStateIndex{Int,<:Union{Int,Symbol}}) +function _is_variable(nw::Network, sni::SymbolicStateIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}}) cf = getcomp(nw, sni) return isdynamic(cf) && subsym_has_idx(sni.subidx, sym(cf)) end @@ -215,7 +231,7 @@ function SII.variable_index(nw::Network, sni) _variable_index(nw, sni) end end -function _variable_index(nw::Network, sni::SymbolicStateIndex{Int,<:Union{Int,Symbol}}) +function _variable_index(nw::Network, sni::SymbolicStateIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}}) cf = getcomp(nw, sni) range = getcomprange(nw, sni) range[subsym_to_idx(sni.subidx, sym(cf))] @@ -242,14 +258,14 @@ function SII.is_parameter(nw::Network, sni) if _hascolon(sni) SII.is_parameter(nw, _resolve_colon(nw,sni)) elseif SII.symbolic_type(sni) === SII.ArraySymbolic() - all(s -> SII.is_parameter(nw, s), sni) + all(Base.Fix1(SII.is_parameter, nw), sni) else _is_parameter(nw, sni) end end _is_parameter(nw::Network, sni) = false function _is_parameter(nw::Network, - sni::SymbolicParameterIndex{Int,<:Union{Int,Symbol}}) + sni::SymbolicParameterIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}}) cf = getcomp(nw, sni) return subsym_has_idx(sni.subidx, psym(cf)) end @@ -264,7 +280,7 @@ function SII.parameter_index(nw::Network, sni) end end function _parameter_index(nw::Network, - sni::SymbolicParameterIndex{Int,<:Union{Int,Symbol}}) + sni::SymbolicParameterIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}}) cf = getcomp(nw, sni) range = getcompprange(nw, sni) range[subsym_to_idx(sni.subidx, psym(cf))] @@ -272,7 +288,6 @@ end function SII.parameter_symbols(nw::Network) syms = Vector{SymbolicParameterIndex{Int,Symbol}}(undef, pdim(nw)) - i = 1 for (ci, cf) in pairs(nw.im.vertexf) syms[nw.im.v_para[ci]] .= VPIndex.(ci, psym(cf)) end @@ -362,7 +377,7 @@ function SII.is_observed(nw::Network, sni) end end _is_observed(nw::Network, _) = false -function _is_observed(nw::Network, sni::SymbolicStateIndex{Int,<:Union{Int,Symbol}}) +function _is_observed(nw::Network, sni::SymbolicStateIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}}) cf = getcomp(nw, sni) if isdynamic(cf) @@ -443,7 +458,7 @@ function SII.observed(nw::Network, snis) flatidxs[i] = _range[subsym_to_idx(sni.subidx, sym(cf))] elseif subsym_has_idx(sni.subidx, obssym(cf)) #found in observed _idx = subsym_to_idx(sni.subidx, obssym(cf)) - _obsf = _get_observed_f(nw, cf, sni.compidx) + _obsf = _get_observed_f(nw, cf, resolvecompidx(nw, sni)) obsfuns[i] = (u, aggbuf, p, t) -> _obsf(u, aggbuf, p, t)[_idx] else throw(ArgumentError("Cannot resolve observable $sni")) diff --git a/src/utils.jl b/src/utils.jl index f2414128..ba5f0ad4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -26,3 +26,24 @@ end @inline unrolled_foreach(f, t) = unrolled_foreach(f, nofilt, t) nofilt(_) = true + +""" + unique_mappings([f=identity], from, to) + +Given two vectors `from` and `to`, find all keys in `from` which exist only once. +For those unique keys, return a dict maping `from_unique => f(to)` +""" +unique_mappings(from, to) = unique_mappings(identity, from, to) +function unique_mappings(f, from, to) + counts = Dict{eltype(from),Int}() + for k in from + counts[k] = get(counts, k, 0) + 1 + end + unique = Dict{eltype(from),eltype(to)}() + for (k, v) in zip(from, to) + if get(counts, k, 0) == 1 + unique[k] = f(v) + end + end + unique +end diff --git a/test/ComponentLibrary.jl b/test/ComponentLibrary.jl index 664cc514..1912e5a2 100644 --- a/test/ComponentLibrary.jl +++ b/test/ComponentLibrary.jl @@ -56,11 +56,11 @@ end Base.@propagate_inbounds function kuramoto_edge!(e, θ_s, θ_d, (K,), t) e .= K .* sin(θ_s[1] - θ_d[1]) end -function kuramoto_edge() - StaticEdge(f=kuramoto_edge!, +function kuramoto_edge(; name=:kuramoto_edge) + StaticEdge(;f=kuramoto_edge!, dim=1, sym=[:P], pdim=1, psym=[:K], - coupling=AntiSymmetric()) + coupling=AntiSymmetric(), name) end Base.@propagate_inbounds function kuramoto_inertia!(dv, v, acc, p, t) @@ -68,10 +68,10 @@ Base.@propagate_inbounds function kuramoto_inertia!(dv, v, acc, p, t) dv[1] = v[2] dv[2] = 1 / M * (Pm - D * v[2] + acc[1]) end -function kuramoto_second() - ODEVertex(f=kuramoto_inertia!, +function kuramoto_second(; name=:kuramoto_second) + ODEVertex(; f=kuramoto_inertia!, dim=2, sym=[:δ, :ω], def=[0, 0], - pdim=3, psym=[:M, :D, :Pm], pdef=[1, 0.1, 1]) + pdim=3, psym=[:M, :D, :Pm], pdef=[1, 0.1, 1], name) end Base.@propagate_inbounds function kuramoto_vertex!(dθ, θ, esum, (ω,), t) diff --git a/test/symbolicindexing_test.jl b/test/symbolicindexing_test.jl index 2ff75d7a..609dade1 100644 --- a/test/symbolicindexing_test.jl +++ b/test/symbolicindexing_test.jl @@ -426,3 +426,31 @@ nw = Network(g, [n1, n2, n3], [e1, e2]) @test SII.get_all_timeseries_indexes(nw, VIndex(1,:u)) == Set([SII.ContinuousTimeseries()]) @test SII.get_all_timeseries_indexes(nw, VPIndex(1,:p1)) == Set([1]) @test SII.get_all_timeseries_indexes(nw, [VIndex(1,:u), VPIndex(1,:p1)]) == Set([SII.ContinuousTimeseries(), 1]) + +# test named vertices and edges +@testset "test sym indices for named edges/vertices" begin + using ModelingToolkit: @named + @named v1 = Lib.kuramoto_second() + @named v2 = Lib.kuramoto_second() + @named v3 = Lib.kuramoto_second() + @named e1 = Lib.kuramoto_edge() + @named e2 = Lib.kuramoto_edge() + @named e3 = Lib.kuramoto_edge() + g = complete_graph(3) + nw = Network(g, [v1, v2, v3], [e1, e2, e3]) + s = NWState(nw, collect(1:dim(nw)), collect(dim(nw)+1:dim(nw)+pdim(nw))) + @test_throws ArgumentError s.v[:v, 1] + @test s.v[:v1, 1] == s[VIndex(1,1)] + @test s.v[:v2, 1] == s[VIndex(2,1)] + @test s.v[:v3, 1] == s[VIndex(3,1)] + @test s.e[:e1, 1] == s[EIndex(1,1)] + @test s.e[:e2, 1] == s[EIndex(2,1)] + @test s.e[:e3, 1] == s[EIndex(3,1)] + @test_throws ArgumentError s.p.v[:v, 1] + @test s.p.v[:v1, 1] == s[VPIndex(1,1)] + @test s.p.v[:v2, 1] == s[VPIndex(2,1)] + @test s.p.v[:v3, 1] == s[VPIndex(3,1)] + @test s.p.e[:e1, 1] == s[EPIndex(1,1)] + @test s.p.e[:e2, 1] == s[EPIndex(2,1)] + @test s.p.e[:e3, 1] == s[EPIndex(3,1)] +end diff --git a/test/utils_test.jl b/test/utils_test.jl index eaeb41dc..46c53312 100644 --- a/test/utils_test.jl +++ b/test/utils_test.jl @@ -80,4 +80,23 @@ using NetworkDynamics align_strings(["row &with\n&line break", "second &row"]) end + + @testset "unique_mappings" begin + using NetworkDynamics: unique_mappings + a = [1,2,3] + b = [1,2,3] + @test unique_mappings(a, b) == Dict(1=>1, 2=>2, 3=>3) + a = [:a, :b, :c] + b = [1,2,3] + @test unique_mappings(a, b) == Dict(:a=>1, :b=>2, :c=>3) + a = [:a, :b, :c] + b = [1,1,3] + @test unique_mappings(a, b) == Dict(:a=>1, :b=>1, :c=>3) + a = [:a, :a, :c] + b = [1,1,3] + @test unique_mappings(a, b) == Dict(:c=>3) + a = [:a, :a, :c] + b = [1,1,-3] + @test unique_mappings(abs, a, b) == Dict(:c=>3) + end end From 25cdb6f6133f01251c8a91922a7414fa0c5d4369 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Mon, 21 Oct 2024 18:34:24 +0200 Subject: [PATCH 04/17] aliaschecks for components If components share the same metadata objects (for example because of `vertices = [v1, v1, v2]`), create copys to dealias! Also, for homogenous networks all components are copied on NW construction. --- src/component_functions.jl | 18 +++++++++++++ src/construction.jl | 54 ++++++++++++++++++++++++++++++++++++-- test/construction_test.jl | 18 +++++++++++++ 3 files changed, 88 insertions(+), 2 deletions(-) diff --git a/src/component_functions.jl b/src/component_functions.jl index 4dafb8a0..896bed20 100644 --- a/src/component_functions.jl +++ b/src/component_functions.jl @@ -584,3 +584,21 @@ _valid_signature(::Type{<:StaticEdge}, f) = _takes_n_vectors(f, 4) #(u, src, dst _valid_signature(::Type{<:ODEEdge}, f) = _takes_n_vectors(f, 5) #(du, u, src, dst, p, t) _takes_n_vectors(f, n) = hasmethod(f, (Tuple(Vector{Float64} for i in 1:n)..., Float64)) + +""" + copy(c::NetworkDynamics.ComponentFunction) + +Shallow copy of the component function. Creates a deepcopy of `metadata` and `symmetadata` +but references the same objects everywhere else. +""" +function Base.copy(c::ComponentFunction) + T = typeof(c) + args = map(fieldnames(T)) do fn + if fn ∈ (:metadata, :symmetadata) + deepcopy(getproperty(c, fn)) + else + getproperty(c, fn) + end + end + return T(args...) +end diff --git a/src/construction.jl b/src/construction.jl index a1e6ab81..f8064bf8 100644 --- a/src/construction.jl +++ b/src/construction.jl @@ -11,14 +11,16 @@ function Network(g::AbstractGraph, reset_timer!() @timeit_debug "Construct Network" begin # collect all vertex/edgf to vector - _vertexf = vertexf isa Vector ? vertexf : [vertexf for _ in vertices(g)] - _edgef = edgef isa Vector ? edgef : [edgef for _ in edges(g)] + _vertexf = vertexf isa Vector ? vertexf : [copy(vertexf) for _ in vertices(g)] + _edgef = edgef isa Vector ? edgef : [copy(edgef) for _ in edges(g)] @argcheck _vertexf isa Vector{<:VertexFunction} "Expected VertexFuncions, got $(eltype(_vertexf))" @argcheck _edgef isa Vector{<:EdgeFunction} "Expected EdgeFuncions, got $(eltype(_vertexf))" @argcheck length(_vertexf) == nv(g) @argcheck length(_edgef) == ne(g) # check if components alias metadata and copy if necessary + _dealias!(_vertexf) + _dealias!(_edgef) verbose && println("Create dynamic network with $(nv(g)) vertices and $(ne(g)) edges:") @@ -179,6 +181,54 @@ function Network(vertexfs, edgefs; kwargs...) Network(g, vfs_ordered, efs_ordered; check_graphelement=false, set_graphelement=false, kwargs...) end +""" + _dealisas!(cfs::Vector{<:ComponentFunction}) + +Checks if any component functions reference the same metadtata/symmetada fields and +creates copies of them if necessary. +""" +function _dealias!(cfs::Vector{<:ComponentFunction}) + smd_dict = IdDict{Dict{Symbol,Dict{Symbol, Any}},Vector{Int}}() + md_dict = IdDict{Dict{Symbol,Any},Vector{Int}}() + + needscopy = false + for (i, cf) in pairs(cfs) + if haskey(md_dict, metadata(cf)) + needscopy = true + push!(md_dict[metadata(cf)], i) + else + md_dict[metadata(cf)] = [i] + end + if haskey(smd_dict, symmetadata(cf)) + needscopy = true + push!(smd_dict[symmetadata(cf)], i) + else + smd_dict[symmetadata(cf)] = [i] + end + end + + if !needscopy + return cfs + end + + copyidxs = Int[] + for v in values(smd_dict) + length(v) > 1 && append!(copyidxs, v) + end + for v in values(md_dict) + length(v) > 1 && append!(copyidxs, v) + end + unique!(copyidxs) + + comp = first(cfs) isa VertexFunction ? "Vertices" : "Edges" + + @warn "$comp $copyidxs reference the same metadata and will be copied. This can happen if \ + the same component reference multiple times. Manually `copy` the component functions to \ + avoid this warning." + + cfs[copyidxs] = copy.(cfs[copyidxs]) +end + # resolve the graphelement ge (named tuple) to simple edge with potential lookup in vertex name dict dict function _resolve_ge_to_edge(ge, vnamedict) src = if ge.src isa Symbol diff --git a/test/construction_test.jl b/test/construction_test.jl index 881c670f..b061bb5a 100644 --- a/test/construction_test.jl +++ b/test/construction_test.jl @@ -266,3 +266,21 @@ end kwargs = Dict(:sym=>[:a=>2,:b],:def=>[1,nothing], :pdim=>0 ) @test_throws ArgumentError _fill_defaults(ODEVertex, kwargs) end + +@testset "test dealias and copy of components" begin + v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v1) + v2 = ODEVertex(x->x^2, 2, 0; name=:v2, vidx=2) + v3 = ODEVertex(x->x^3, 2, 0; name=:v3, vidx=3) + + e1 = StaticEdge(nothing, 0, Symmetric(); graphelement=(;src=1,dst=2)) + e2 = StaticEdge(nothing, 0, Symmetric(); src=:v2, dst=:v3) + e3 = StaticEdge(nothing, 0, Symmetric(); src=:v3, dst=:v1) + + g = complete_graph(3) + nw = Network(g, [v1,v1,v3],[e1,e2,e3]) + @test nw.im.vertexf[1].f == nw.im.vertexf[2].f + @test nw.im.vertexf[1].metadata == nw.im.vertexf[2].metadata + @test nw.im.vertexf[1].metadata !== nw.im.vertexf[2].metadata + @test nw.im.vertexf[1].symmetadata == nw.im.vertexf[2].symmetadata + @test nw.im.vertexf[1].symmetadata !== nw.im.vertexf[2].symmetadata +end From 08a2ace621518b3a7eca6463cb15b1d3e35a2bdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Mon, 21 Oct 2024 18:36:12 +0200 Subject: [PATCH 05/17] add restructuring constructor for component functions can be used to change single things about them, like the name for example --- src/component_functions.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/component_functions.jl b/src/component_functions.jl index 896bed20..52a53829 100644 --- a/src/component_functions.jl +++ b/src/component_functions.jl @@ -190,6 +190,7 @@ ODEVertex(; kwargs...) = _construct_comp(ODEVertex, kwargs) ODEVertex(f; kwargs...) = ODEVertex(;f, kwargs...) ODEVertex(f, dim; kwargs...) = ODEVertex(;f, _dimsym(dim)..., kwargs...) ODEVertex(f, dim, pdim; kwargs...) = ODEVertex(;f, _dimsym(dim, pdim)..., kwargs...) +ODEVertex(v::ODEVertex; kwargs...) = _reconstruct_comp(ODEVertex, v, kwargs) struct StaticVertex{F,OF} <: VertexFunction @CommonFields @@ -198,6 +199,7 @@ StaticVertex(; kwargs...) = _construct_comp(StaticVertex, kwargs) StaticVertex(f; kwargs...) = StaticVertex(;f, kwargs...) StaticVertex(f, dim; kwargs...) = StaticVertex(;f, _dimsym(dim)..., kwargs...) StaticVertex(f, dim, pdim; kwargs...) = StaticVertex(;f, _dimsym(dim, pdim)..., kwargs...) +StaticVertex(v::StaticVertex; kwargs...) = _reconstruct_comp(StaticVertex, v, kwargs) function ODEVertex(sv::StaticVertex) d = Dict{Symbol,Any}() for prop in propertynames(sv) @@ -224,6 +226,7 @@ StaticEdge(; kwargs...) = _construct_comp(StaticEdge, kwargs) StaticEdge(f; kwargs...) = StaticEdge(;f, kwargs...) StaticEdge(f, dim, coupling; kwargs...) = StaticEdge(;f, _dimsym(dim)..., coupling, kwargs...) StaticEdge(f, dim, pdim, coupling; kwargs...) = StaticEdge(;f, _dimsym(dim, pdim)..., coupling, kwargs...) +StaticEdge(e::StaticEdge; kwargs...) = _reconstruct_comp(StaticEdge, e, kwargs) struct ODEEdge{C,F,OF,MM} <: EdgeFunction{C} @CommonFields @@ -234,6 +237,7 @@ ODEEdge(; kwargs...) = _construct_comp(ODEEdge, kwargs) ODEEdge(f; kwargs...) = ODEEdge(;f, kwargs...) ODEEdge(f, dim, coupling; kwargs...) = ODEEdge(;f, _dimsym(dim)..., coupling, kwargs...) ODEEdge(f, dim, pdim, coupling; kwargs...) = ODEEdge(;f, _dimsym(dim, pdim)..., coupling, kwargs...) +ODEEdge(e::ODEEdge; kwargs...) = _reconstruct_comp(ODEEdge, e, kwargs) statetype(::T) where {T<:ComponentFunction} = statetype(T) statetype(::Type{<:ODEVertex}) = Dynamic() @@ -313,6 +317,18 @@ function _construct_comp(::Type{T}, kwargs) where {T} return c end +function _reconstruct_comp(::Type{T}, cf::ComponentFunction, kwargs) where {T} + fields = fieldnames(T) + dict = Dict{Symbol, Any}() + for f in fields + dict[f] = getproperty(cf, f) + end + for (k, v) in kwargs + dict[k] = v + end + _construct_comp(T, dict) +end + """ _fill_defaults(T, kwargs) From ecd77d91edb88a1311ccfe6f3ea3f4412fecb2ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Mon, 21 Oct 2024 19:10:25 +0200 Subject: [PATCH 06/17] fix remake constructor for network --- src/aggregators.jl | 8 ++++++++ src/construction.jl | 28 ++++++++++++++++++++++++---- test/construction_test.jl | 20 ++++++++++++++++++++ 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/src/aggregators.jl b/src/aggregators.jl index bd31e809..51911374 100644 --- a/src/aggregators.jl +++ b/src/aggregators.jl @@ -315,6 +315,14 @@ function aggregate!(a::SparseAggregator, aggbuf, data) nothing end +# functions to retrieve the constructor of an aggregator for remake of network +get_aggr_constructor(a::NaiveAggregator) = NaiveAggregator(a.f) +get_aggr_constructor(a::KAAggregator) = KAAggregator(a.f) +get_aggr_constructor(a::SequentialAggregator) = SequentialAggregator(a.f) +get_aggr_constructor(a::PolyesterAggregator) = PolyesterAggregator(a.f) +get_aggr_constructor(a::ThreadedAggregator) = ThreadedAggregator(a.f) +get_aggr_constructor(a::SparseAggregator) = SparseAggregator(+) + iscudacompatible(::Type{<:Aggregator}) = false iscudacompatible(::Type{<:KAAggregator}) = true iscudacompatible(::Type{<:SparseAggregator}) = true diff --git a/src/construction.jl b/src/construction.jl index f8064bf8..5b5c7555 100644 --- a/src/construction.jl +++ b/src/construction.jl @@ -385,11 +385,31 @@ function _check_massmatrix(c) end """ - Network(nw::Network; kwargs...) + Network(nw::Network; g, vertexf, edgef, kwargs...) Rebuild the Network with same graph and vertex/edge functions but possibly different kwargs. -# FIXME : needs to take all Network kw arguments into acount! """ -function Network(nw::Network; kwargs...) - Network(nw.im.g, nw.im.vertexf, nw.im.edgef; kwargs...) +function Network(nw::Network; + g = nw.im.g, + vertexf = copy.(nw.im.vertexf), + edgef = copy.(nw.im.edgef), + kwargs...) + + _kwargs = Dict(:execution => executionstyle(nw), + :edepth => :auto, + :vdepth => :auto, + :aggregator => get_aggr_constructor(nw.layer.aggregator), + :check_graphelement => true, + :set_graphelement => false, + :verbose => false) + for (k, v) in kwargs + _kwargs[k] = v + end + + # check, that we actually provide all of the arguments + # mainly so we don't forget to add it here if we introduce new kw arg to main constructor + m = only(methods(Network, [typeof(g), typeof(vertexf), typeof(edgef)])) + @assert keys(_kwargs) == Set(Base.kwarg_decl(m)) + + Network(g, vertexf, edgef; _kwargs...) end diff --git a/test/construction_test.jl b/test/construction_test.jl index b061bb5a..c077f2ac 100644 --- a/test/construction_test.jl +++ b/test/construction_test.jl @@ -284,3 +284,23 @@ end @test nw.im.vertexf[1].symmetadata == nw.im.vertexf[2].symmetadata @test nw.im.vertexf[1].symmetadata !== nw.im.vertexf[2].symmetadata end + +@testset "test network-remake constructor" begin + v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v1) + v2 = ODEVertex(x->x^2, 2, 0; name=:v2, vidx=2) + v3 = ODEVertex(x->x^3, 2, 0; name=:v3, vidx=3) + + e1 = StaticEdge(nothing, 0, Symmetric(); graphelement=(;src=1,dst=2)) + e2 = StaticEdge(nothing, 0, Symmetric(); src=:v1, dst=:v3) + e3 = StaticEdge(nothing, 0, Symmetric(); src=:v2, dst=:v3) + + g = complete_graph(3) + nw = Network(g, [v1,v2,v3],[e1,e2,e3]) + nw2 = Network(nw) + nw2 = Network(nw; g=path_graph(3), edgef=[e1, e3]) + + for aggT in subtypes(NetworkDynamics.Aggregator) + @show aggT + @test hasmethod(NetworkDynamics.get_aggr_constructor, (aggT,)) + end +end From 7cb44a88334f014b7787ad1cee1bcfbdae84549a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Mon, 21 Oct 2024 19:23:05 +0200 Subject: [PATCH 07/17] add conv method for graphelement from pair --- src/metadata.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/metadata.jl b/src/metadata.jl index 9795dbb6..5d559db9 100644 --- a/src/metadata.jl +++ b/src/metadata.jl @@ -64,6 +64,7 @@ for md in [:default, :guess, :init, :bounds] $fname_set(c::ComponentFunction, sym::Symbol, val) = set_metadata!(c, sym, $(QuoteNode(md)), val) end end +set_graphelement!(c::EdgeFunction, p::Pair) = set_graphelement!(c, (;src=p.first, dst=p.second)) #### default or init From 15c4ca1e23e2c8b759c5ed79945ffc914d3d1918 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Mon, 21 Oct 2024 19:23:38 +0200 Subject: [PATCH 08/17] fix tests --- test/aggregators_test.jl | 4 ++-- test/construction_test.jl | 17 ++++++++++------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/test/aggregators_test.jl b/test/aggregators_test.jl index 6780db3d..001fdee8 100644 --- a/test/aggregators_test.jl +++ b/test/aggregators_test.jl @@ -29,8 +29,8 @@ using StableRNGs rng = StableRNG(1) g = watts_strogatz(10_000, 4, 0.8; rng, is_directed=true) - nvec = rand(rng, vtypes, nv(g)) - evec = rand(rng, etypes, ne(g)) + nvec = copy.(rand(rng, vtypes, nv(g))) + evec = copy.(rand(rng, etypes, ne(g))) basenw = Network(g, nvec, evec); states = rand(rng, basenw.im.lastidx_static) diff --git a/test/construction_test.jl b/test/construction_test.jl index c077f2ac..c76e8ad5 100644 --- a/test/construction_test.jl +++ b/test/construction_test.jl @@ -74,14 +74,17 @@ end Network([v1,v2,v3], [e1,e2,e3]) # throws waring about 1->2 and 2->1 beeing present v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v1) - v2 = ODEVertex(x->x^2, 2, 0; name=:v2) - v3 = ODEVertex(x->x^3, 2, 0; name=:v3) - @test NetworkDynamics._unique_name_dict([v1,v2,v3]) == Dict(:v1=>1) + v2 = ODEVertex(x->x^2, 2, 0; name=:v2, vidx=2) + v3 = ODEVertex(x->x^3, 2, 0; name=:v3, vidx=3) + nw = Network([v1,v2,v3], [e1,e2,e3]) + @test nw.im.unique_vnames == Dict(:v1=>1, :v2=>2, :v3=>3) - v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v1) - v2 = ODEVertex(x->x^2, 2, 0; name=:v1) - v3 = ODEVertex(x->x^3, 2, 0; name=:v3) - @test NetworkDynamics._unique_name_dict([v1,v2,v3]) == Dict() + v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v2) + v2 = ODEVertex(x->x^2, 2, 0; name=:v2, vidx=2) + v3 = ODEVertex(x->x^3, 2, 0; name=:v3, vidx=3) + set_graphelement!(e2, 3=>2) + nw = Network([v1,v2,v3], [e1,e2,e3]) + @test nw.im.unique_vnames == Dict(:v3=>3) end @testset "Vertex batch" begin using NetworkDynamics: BatchStride, VertexBatch, parameter_range From b324cf9ad2b4fde9c371fd21742b4e90674a7a15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Tue, 22 Oct 2024 07:10:53 +0200 Subject: [PATCH 09/17] remove named --- test/symbolicindexing_test.jl | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/test/symbolicindexing_test.jl b/test/symbolicindexing_test.jl index 609dade1..f1adb6d5 100644 --- a/test/symbolicindexing_test.jl +++ b/test/symbolicindexing_test.jl @@ -429,13 +429,12 @@ nw = Network(g, [n1, n2, n3], [e1, e2]) # test named vertices and edges @testset "test sym indices for named edges/vertices" begin - using ModelingToolkit: @named - @named v1 = Lib.kuramoto_second() - @named v2 = Lib.kuramoto_second() - @named v3 = Lib.kuramoto_second() - @named e1 = Lib.kuramoto_edge() - @named e2 = Lib.kuramoto_edge() - @named e3 = Lib.kuramoto_edge() + v1 = Lib.kuramoto_second(name=:v1) + v2 = Lib.kuramoto_second(name=:v2) + v3 = Lib.kuramoto_second(name=:v3) + e1 = Lib.kuramoto_edge(name=:e1) + e2 = Lib.kuramoto_edge(name=:e2) + e3 = Lib.kuramoto_edge(name=:e3) g = complete_graph(3) nw = Network(g, [v1, v2, v3], [e1, e2, e3]) s = NWState(nw, collect(1:dim(nw)), collect(dim(nw)+1:dim(nw)+pdim(nw))) From f351a80597965083d0fd45a534285f1e50c4aad3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Tue, 22 Oct 2024 15:09:56 +0200 Subject: [PATCH 10/17] use generated function to speed up copy --- src/component_functions.jl | 23 ++++++++++++++--------- test/construction_test.jl | 6 ++++++ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/component_functions.jl b/src/component_functions.jl index 52a53829..2b3b43cf 100644 --- a/src/component_functions.jl +++ b/src/component_functions.jl @@ -607,14 +607,19 @@ _takes_n_vectors(f, n) = hasmethod(f, (Tuple(Vector{Float64} for i in 1:n)..., F Shallow copy of the component function. Creates a deepcopy of `metadata` and `symmetadata` but references the same objects everywhere else. """ -function Base.copy(c::ComponentFunction) - T = typeof(c) - args = map(fieldnames(T)) do fn - if fn ∈ (:metadata, :symmetadata) - deepcopy(getproperty(c, fn)) - else - getproperty(c, fn) - end +@generated function Base.copy(c::ComponentFunction) + fields = fieldnames(c) + # fields to copy + cfields = (:metadata, :symmetadata) + # normal fields + nfields = setdiff(fields, cfields) + assign = Expr(:block, + (:($(field) = c.$field) for field in nfields)..., + (:($(field) = deepcopy(c.$field)) for field in cfields)...) + construct = Expr(:call, c, [:($field) for field in fields]...) + + quote + $assign + $construct end - return T(args...) end diff --git a/test/construction_test.jl b/test/construction_test.jl index c76e8ad5..07ca0f57 100644 --- a/test/construction_test.jl +++ b/test/construction_test.jl @@ -286,6 +286,12 @@ end @test nw.im.vertexf[1].metadata !== nw.im.vertexf[2].metadata @test nw.im.vertexf[1].symmetadata == nw.im.vertexf[2].symmetadata @test nw.im.vertexf[1].symmetadata !== nw.im.vertexf[2].symmetadata + + v = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), symmetadata=Dict(:x=>Dict(:default=>1))) + v2 = copy(v) + @test get_default(v, :x) == get_default(v2, :x) + set_default!(v, :x, 99) + @test get_default(v2, :x) == 1 end @testset "test network-remake constructor" begin From 4cd39680a27a640cddc9bd38557f325d0a89b629 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Wed, 23 Oct 2024 11:50:12 +0200 Subject: [PATCH 11/17] remove _resolve_to_edge function --- src/construction.jl | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/src/construction.jl b/src/construction.jl index 5b5c7555..e62f25e7 100644 --- a/src/construction.jl +++ b/src/construction.jl @@ -156,7 +156,12 @@ function Network(vertexfs, edgefs; kwargs...) simpleedges = map(edgefs) do e ge = get_graphelement(e) - _resolve_ge_to_edge(ge, vnamedict) + src = get(vnamedict, ge.src, ge.src) + dst = get(vnamedict, ge.dst, ge.dst) + if src isa Symbol || dst isa Symbol + throw(ArgumentError("Edge graphelement $src => $dst continas non-unique or unknown vertex names!")) + end + SimpleEdge(src, dst) end allunique(simpleedges) || throw(ArgumentError("Some edge functions have the same `graphelement`!")) edict = Dict(simpleedges .=> edgefs) @@ -229,22 +234,6 @@ function _dealias!(cfs::Vector{<:ComponentFunction}) cfs[copyidxs] = copy.(cfs[copyidxs]) end -# resolve the graphelement ge (named tuple) to simple edge with potential lookup in vertex name dict dict -function _resolve_ge_to_edge(ge, vnamedict) - src = if ge.src isa Symbol - haskey(vnamedict, ge.src) || throw(ArgumentError("Edge function has unknown or non-unique source vertex name $(ge.src)")) - vnamedict[ge.src] - else - ge.src - end - dst = if ge.dst isa Symbol - haskey(vnamedict, ge.dst) || throw(ArgumentError("Edge function has unknown or non-unique source vertex name $(ge.dst)")) - vnamedict[ge.dst] - else - ge.dst - end - SimpleEdge(src, dst) -end function VertexBatch(im::IndexManager, idxs::Vector{Int}; verbose) components = @view im.vertexf[idxs] From d527fad65dc39be929e77378504365b1bec224ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Wed, 23 Oct 2024 11:51:58 +0200 Subject: [PATCH 12/17] make dealias optional --- src/construction.jl | 78 +++++++++++++++++++++------------------------ 1 file changed, 37 insertions(+), 41 deletions(-) diff --git a/src/construction.jl b/src/construction.jl index e62f25e7..841381a8 100644 --- a/src/construction.jl +++ b/src/construction.jl @@ -7,20 +7,28 @@ function Network(g::AbstractGraph, aggregator=execution isa SequentialExecution ? SequentialAggregator(+) : PolyesterAggregator(+), check_graphelement=true, set_graphelement=false, + dealias=false, verbose=false) reset_timer!() @timeit_debug "Construct Network" begin # collect all vertex/edgf to vector - _vertexf = vertexf isa Vector ? vertexf : [copy(vertexf) for _ in vertices(g)] - _edgef = edgef isa Vector ? edgef : [copy(edgef) for _ in edges(g)] + maybecopy = dealias ? copy : identity + _vertexf = vertexf isa Vector ? vertexf : [maybecopy(vertexf) for _ in vertices(g)] + _edgef = edgef isa Vector ? edgef : [maybecopy(edgef) for _ in edges(g)] + @argcheck _vertexf isa Vector{<:VertexFunction} "Expected VertexFuncions, got $(eltype(_vertexf))" @argcheck _edgef isa Vector{<:EdgeFunction} "Expected EdgeFuncions, got $(eltype(_vertexf))" @argcheck length(_vertexf) == nv(g) @argcheck length(_edgef) == ne(g) - # check if components alias metadata and copy if necessary - _dealias!(_vertexf) - _dealias!(_edgef) + # check if components alias eachother copy if necessary + # allready dealiase if provided as single functions + if dealias && vertexf isa Vector + dealias!(_vertexf) + end + if dealias && edgef isa Vector + dealias!(_edgef) + end verbose && println("Create dynamic network with $(nv(g)) vertices and $(ne(g)) edges:") @@ -187,53 +195,41 @@ function Network(vertexfs, edgefs; kwargs...) end """ - _dealisas!(cfs::Vector{<:ComponentFunction}) + dealias!(cfs::Vector{<:ComponentFunction}) Checks if any component functions reference the same metadtata/symmetada fields and creates copies of them if necessary. """ -function _dealias!(cfs::Vector{<:ComponentFunction}) - smd_dict = IdDict{Dict{Symbol,Dict{Symbol, Any}},Vector{Int}}() - md_dict = IdDict{Dict{Symbol,Any},Vector{Int}}() +function dealias!(cfs::Vector{<:ComponentFunction}; warn=false) + ag = aliasgroups(cfs) - needscopy = false - for (i, cf) in pairs(cfs) - if haskey(md_dict, metadata(cf)) - needscopy = true - push!(md_dict[metadata(cf)], i) - else - md_dict[metadata(cf)] = [i] - end - if haskey(smd_dict, symmetadata(cf)) - needscopy = true - push!(smd_dict[symmetadata(cf)], i) - else - smd_dict[symmetadata(cf)] = [i] - end - end + isempty(ag) && return cfs # nothign to do - if !needscopy - return cfs - end + copyidxs = reduce(vcat, values(ag)) - copyidxs = Int[] - for v in values(smd_dict) - length(v) > 1 && append!(copyidxs, v) - end - for v in values(md_dict) - length(v) > 1 && append!(copyidxs, v) - end - unique!(copyidxs) + cfs[copyidxs] = copy.(cfs[copyidxs]) +end - comp = first(cfs) isa VertexFunction ? "Vertices" : "Edges" - @warn "$comp $copyidxs reference the same metadata and will be copied. This can happen if \ - the same component reference multiple times. Manually `copy` the component functions to \ - avoid this warning." +""" + aliasgroups(cfs::Vector{<:ComponentFunction}) - cfs[copyidxs] = copy.(cfs[copyidxs]) -end +Returns a dict `cf => idxs` which contains all the component functions +which appear multiple times with all their indices. +""" +function aliasgroups(cfs::Vector{T}) where {T<:ComponentFunction} + d = IdDict{T, Vector{Int}}() + for (i, cf) in pairs(cfs) + if haskey(d, cf) # c allready present + push!(d[cf], i) + else + d[cf] = [i] + end + end + + filter!(x -> length(x.second) > 1, d) +end function VertexBatch(im::IndexManager, idxs::Vector{Int}; verbose) components = @view im.vertexf[idxs] From 2bbdd8944b6f7529879e32a9477e7c8f7e989e41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Wed, 23 Oct 2024 14:43:10 +0200 Subject: [PATCH 13/17] remove set_graphelement option --- src/construction.jl | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/construction.jl b/src/construction.jl index 841381a8..3f99cb72 100644 --- a/src/construction.jl +++ b/src/construction.jl @@ -6,7 +6,6 @@ function Network(g::AbstractGraph, vdepth=:auto, aggregator=execution isa SequentialExecution ? SequentialAggregator(+) : PolyesterAggregator(+), check_graphelement=true, - set_graphelement=false, dealias=false, verbose=false) reset_timer!() @@ -62,11 +61,9 @@ function Network(g::AbstractGraph, if get_graphelement(vf) != i @warn "Vertex function $(vf.name) is placed at node index $i bus has \ `graphelement` $(get_graphelement(vf)) stored in metadata. \ - The wrong data will be " * (set_graphelement ? "overwritten!" : "ignored!") * - " Use `check_graphelement` and `set_graphelement` keywords to alter this behavior." + The wrong data will be ignored! Use `check_graphelement=false` tu supress this warning." end end - set_graphelement && set_graphelement!(vf, i) end for (iteredge, ef) in zip(im.edgevec, _edgef) if check_graphelement && has_graphelement(ef) @@ -75,11 +72,9 @@ function Network(g::AbstractGraph, dst = get(im.unique_vnames, ge.dst, ge.dst) if iteredge.src != src || iteredge.dst != dst @warn "Edge function $(ef.name) at $(iteredge.src) => $(iteredge.dst) has wrong `:graphelement` $src => $dst). \ - The wrong data will be " * (set_graphelement ? "overwritten!" : "ignored!") * - " Use `check_graphelement` and `set_graphelement` keywords to alter this behavior." + The wrong data will be ignored! Use `check_graphelement=false` tu supress this warning." end end - set_graphelement && set_graphelement!(ef, (;src=iteredge.src, dst=iteredge.dst)) end # batch identical edge and vertex functions @@ -191,7 +186,7 @@ function Network(vertexfs, edgefs; kwargs...) vfs_ordered = [vdict[k] for k in vertices(g)] efs_ordered = [edict[k] for k in edges(g)] - Network(g, vfs_ordered, efs_ordered; check_graphelement=false, set_graphelement=false, kwargs...) + Network(g, vfs_ordered, efs_ordered; check_graphelement=false, kwargs...) end """ @@ -385,7 +380,6 @@ function Network(nw::Network; :vdepth => :auto, :aggregator => get_aggr_constructor(nw.layer.aggregator), :check_graphelement => true, - :set_graphelement => false, :verbose => false) for (k, v) in kwargs _kwargs[k] = v From 48c2a213edadaf7879b2747a6848a17699afcc23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Wed, 23 Oct 2024 14:44:24 +0200 Subject: [PATCH 14/17] track state of aliased objects track the hashes of aliased vertex/edge functions, such that we can warn when they are changed. For large networks which do not use the meta data machinery this is much better then eagerly copying every component. --- benchmark/benchmarks.jl | 2 +- src/component_functions.jl | 5 +++++ src/construction.jl | 2 +- src/metadata.jl | 41 +++++++++++++++++++++++++++++++++++++ src/network_structure.jl | 26 ++++++++++++++++++++--- src/symbolicindexing.jl | 1 + src/utils.jl | 34 ++++++++++++++++++++++++++++++ test/construction_test.jl | 42 +++++++++++++++++++++++++++++++++++++- 8 files changed, 147 insertions(+), 6 deletions(-) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 19983a9c..3dcbb905 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -1,6 +1,6 @@ using Pkg -using BenchmarkTools +using Chairmarks using Graphs using NetworkDynamics using Serialization diff --git a/src/component_functions.jl b/src/component_functions.jl index 2b3b43cf..b9d3b1fb 100644 --- a/src/component_functions.jl +++ b/src/component_functions.jl @@ -623,3 +623,8 @@ but references the same objects everywhere else. $construct end end + +Base.hash(cf::ComponentFunction, h::UInt) = hash_fields(cf, h) +function Base.:(==)(cf1::ComponentFunction, cf2::ComponentFunction) + typeof(cf1) == typeof(cf2) && equal_fields(cf1, cf2) +end diff --git a/src/construction.jl b/src/construction.jl index 3f99cb72..cd26f881 100644 --- a/src/construction.jl +++ b/src/construction.jl @@ -52,7 +52,7 @@ function Network(g::AbstractGraph, end # create index manager - im = IndexManager(g, dynstates, edepth, vdepth, _vertexf, _edgef) + im = IndexManager(g, dynstates, edepth, vdepth, _vertexf, _edgef; mightalias=!dealias) # check graph_element metadata and attach if necessary diff --git a/src/metadata.jl b/src/metadata.jl index 5d559db9..b23d2c30 100644 --- a/src/metadata.jl +++ b/src/metadata.jl @@ -180,3 +180,44 @@ end function get_defaults_or_inits(c::ComponentFunction, syms) [has_default_or_init(c, sym) ? get_default_or_init(c, sym) : nothing for sym in syms] end + + +#### +#### Metadata Accessors through Network +#### +function aliased_changed(nw::Network; warn=true) + vchanged = _has_changed_hash(nw.im.aliased_vertexfs) + echanged = _has_changed_hash(nw.im.aliased_edgefs) + changed = vchanged || echanged + if changed && warn + s = if vchanged && echanged + "vertices and edges" + elseif vchanged + "vertices" + else + "edges" + end + @warn """ + The metadata of at least one of your aliased $s changed! Proceed with caution! + + Some edgef/vertexf provided to to the `Network` constructor alias eachother. + Which means, the Network object references the same component function in + multiple places. Thus, metadata changes (such as changing of default values or + component initialization) will be reflected in multiple components. To prevent + this use the `dealias=true` keyword or manualy `copy` edge/vertex functions + before creating the network. + """ + end + changed +end +function _has_changed_hash(aliased_cfs) + isempty(aliased_cfs) && return false + changed = false + for (k, v) in aliased_cfs + if hash(k) != v.hash + changed = true + break + end + end + changed +end diff --git a/src/network_structure.jl b/src/network_structure.jl index fe242b10..2db7e5dd 100644 --- a/src/network_structure.jl +++ b/src/network_structure.jl @@ -22,19 +22,39 @@ mutable struct IndexManager{G} lastidx_gbuf::Int vertexf::Vector{VertexFunction} edgef::Vector{EdgeFunction} + aliased_vertexfs::IdDict{VertexFunction, @NamedTuple{idxs::Vector{Int}, hash::UInt}} + aliased_edgefs::IdDict{EdgeFunction, @NamedTuple{idxs::Vector{Int}, hash::UInt}} unique_vnames::Dict{Symbol,Int} unique_enames::Dict{Symbol,Int} - function IndexManager(g, dyn_states, edepth, vdepth, vertexf, edgef) + function IndexManager(g, dyn_states, edepth, vdepth, vertexf, edgef; mightalias) + aliased_vertexf_hashes = _aliased_hashes(VertexFunction, vertexf, mightalias) + aliased_edgef_hashes = _aliased_hashes(EdgeFunction, edgef, mightalias) + unique_vnames = unique_mappings(getproperty.(vertexf, :name), 1:nv(g)) + unique_enames = unique_mappings(getproperty.(edgef, :name), 1:ne(g)) new{typeof(g)}(g, collect(edges(g)), (Vector{UnitRange{Int}}(undef, nv(g)) for i in 1:3)..., (Vector{UnitRange{Int}}(undef, ne(g)) for i in 1:5)..., edepth, vdepth, 0, dyn_states, 0, 0, 0, vertexf, edgef, - unique_mappings(getproperty.(vertexf, :name), 1:nv(g)), - unique_mappings(getproperty.(edgef, :name), 1:ne(g))) + aliased_vertexf_hashes, + aliased_edgef_hashes, + unique_vnames, + unique_enames) end end +function _aliased_hashes(T, cfs, mightalias) + hashdict = IdDict{T, @NamedTuple{idxs::Vector{Int}, hash::UInt}}() + if mightalias + ag = aliasgroups(cfs) + for (c, idxs) in ag + h = hash(c) + hashdict[c] = (; idxs=idxs, hash=h) + end + end + hashdict +end + dim(im::IndexManager) = im.lastidx_dynamic pdim(im::IndexManager) = im.lastidx_p sdim(im::IndexManager) = im.lastidx_static - im.lastidx_dynamic diff --git a/src/symbolicindexing.jl b/src/symbolicindexing.jl index 94b28b36..893a33b3 100644 --- a/src/symbolicindexing.jl +++ b/src/symbolicindexing.jl @@ -537,6 +537,7 @@ end #### Default values #### function SII.default_values(nw::Network) + aliased_changed(nw; warn=true) defs = Dict{SymbolicIndex{Int,Symbol},Float64}() for (ci, cf) in pairs(nw.im.vertexf) for s in psym(cf) diff --git a/src/utils.jl b/src/utils.jl index ba5f0ad4..7f38181e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -47,3 +47,37 @@ function unique_mappings(f, from, to) end unique end + + +""" + hash_fields(obj::T, h) + +This is @generated helper function which unrolls all fields of a struct `obj` and +recursively hashes them. +""" +@generated function hash_fields(obj::T, h::UInt) where {T} + fields = fieldnames(obj) + subhashes = Expr(:block, (:(h = hash(obj.$field, h)) for field in fields)...) + + quote + h = hash(T, h) + $subhashes + h + end +end + +""" + equal_fields(a::T, b::T) where {T} + +Thise @generated helper function unrolls all fields of two structs `a` and `b` and +compares them. +""" +@generated function equal_fields(a::T, b::T) where {T} + fields = fieldnames(T) + subequals = Expr(:block, (:(a.$field == b.$field || return false) for field in fields)...) + + quote + $subequals + return true + end +end diff --git a/test/construction_test.jl b/test/construction_test.jl index 07ca0f57..f220a888 100644 --- a/test/construction_test.jl +++ b/test/construction_test.jl @@ -271,6 +271,7 @@ end end @testset "test dealias and copy of components" begin + using NetworkDynamics: aliasgroups v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v1) v2 = ODEVertex(x->x^2, 2, 0; name=:v2, vidx=2) v3 = ODEVertex(x->x^3, 2, 0; name=:v3, vidx=3) @@ -279,19 +280,58 @@ end e2 = StaticEdge(nothing, 0, Symmetric(); src=:v2, dst=:v3) e3 = StaticEdge(nothing, 0, Symmetric(); src=:v3, dst=:v1) + @test isempty(aliasgroups([v1,v2,v3])) + @test aliasgroups([v1, v1, v3]) == IdDict(v1 => [1,2]) + @test aliasgroups([v3, v1, v3]) == IdDict(v3 => [1,3]) + g = complete_graph(3) - nw = Network(g, [v1,v1,v3],[e1,e2,e3]) + # with dealiasing / copying + nw = Network(g, [v1,v1,v3],[e1,e2,e3]; dealias=true, check_graphelement=false) @test nw.im.vertexf[1].f == nw.im.vertexf[2].f @test nw.im.vertexf[1].metadata == nw.im.vertexf[2].metadata @test nw.im.vertexf[1].metadata !== nw.im.vertexf[2].metadata @test nw.im.vertexf[1].symmetadata == nw.im.vertexf[2].symmetadata @test nw.im.vertexf[1].symmetadata !== nw.im.vertexf[2].symmetadata + s0 = NWState(nw) + @test isnan(s0.v[1,1]) + set_default!(nw.im.vertexf[1], :v₁, 3) + s1 = NWState(nw) + @test s1.v[1,1] == 3 + + # witout dealisasing + nw = Network(g, [v1,v1,v3],[e1,e2,e1]; check_graphelement=false) + @test nw.im.vertexf[1] === nw.im.vertexf[2] + @test nw.im.edgef[1] === nw.im.edgef[3] + @test keys(nw.im.aliased_vertexfs) == Set([v1]) + @test keys(nw.im.aliased_edgefs) == Set([e1]) + @test only(unique(values(nw.im.aliased_vertexfs))) == (;idxs=[1,2], hash=hash(v1)) + @test only(unique(values(nw.im.aliased_edgefs))) == (;idxs=[1,3], hash=hash(e1)) + s0 = NWState(nw) + @test isnan(s0.v[1,1]) + prehash = hash(v1) + set_default!(v1, :v₁, 3) + posthash = hash(v1) + @test prehash !== posthash + @test NetworkDynamics.aliased_changed(nw; warn=false) + s1 = NWState(nw) + + # test hashes + h1 = hash(v1) + set_metadata!(v1, :foo, :bar) + h2 = hash(v1) + @test h1 !== h2 v = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), symmetadata=Dict(:x=>Dict(:default=>1))) v2 = copy(v) + @test v !== v2 + @test v == v2 + @test isequal(v, v2) @test get_default(v, :x) == get_default(v2, :x) set_default!(v, :x, 99) + @test get_default(v, :x) == 99 @test get_default(v2, :x) == 1 + @test v != v2 + @test v1 != v2 end @testset "test network-remake constructor" begin From d135cef6e7e342f4e5352988b2febe339e2cdf6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Wed, 23 Oct 2024 14:51:56 +0200 Subject: [PATCH 15/17] add dealias argument to remake constructor --- src/construction.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/construction.jl b/src/construction.jl index cd26f881..9fcc7425 100644 --- a/src/construction.jl +++ b/src/construction.jl @@ -380,6 +380,7 @@ function Network(nw::Network; :vdepth => :auto, :aggregator => get_aggr_constructor(nw.layer.aggregator), :check_graphelement => true, + :dealias => false, :verbose => false) for (k, v) in kwargs _kwargs[k] = v From 40e875125c2bddadde6752f21ca79013da251352 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Wed, 23 Oct 2024 14:57:42 +0200 Subject: [PATCH 16/17] rm redudant test --- test/construction_test.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/test/construction_test.jl b/test/construction_test.jl index f220a888..98421ad7 100644 --- a/test/construction_test.jl +++ b/test/construction_test.jl @@ -315,12 +315,7 @@ end @test NetworkDynamics.aliased_changed(nw; warn=false) s1 = NWState(nw) - # test hashes - h1 = hash(v1) - set_metadata!(v1, :foo, :bar) - h2 = hash(v1) - @test h1 !== h2 - + # test copy v = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), symmetadata=Dict(:x=>Dict(:default=>1))) v2 = copy(v) @test v !== v2 From 597d32a078a343fa8074e71ea35cd7bd9aa64b8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Wed, 23 Oct 2024 16:23:53 +0200 Subject: [PATCH 17/17] improve constructor for single edges/vertices given --- src/NetworkDynamics.jl | 2 +- src/construction.jl | 66 ++++++++++++++++++++++++++-------------- src/network_structure.jl | 13 +++++--- 3 files changed, 52 insertions(+), 29 deletions(-) diff --git a/src/NetworkDynamics.jl b/src/NetworkDynamics.jl index bc52cf56..d5e2df16 100644 --- a/src/NetworkDynamics.jl +++ b/src/NetworkDynamics.jl @@ -1,7 +1,7 @@ module NetworkDynamics using Graphs: Graphs, AbstractGraph, SimpleEdge, edges, vertices, ne, nv, SimpleGraph, SimpleDiGraph, add_edge!, has_edge -using TimerOutputs: @timeit_debug, reset_timer! +using TimerOutputs: TimerOutputs, @timeit_debug, @timeit using ArgCheck: @argcheck using PreallocationTools: PreallocationTools, DiffCache, get_tmp diff --git a/src/construction.jl b/src/construction.jl index 9fcc7425..4b36a608 100644 --- a/src/construction.jl +++ b/src/construction.jl @@ -8,12 +8,14 @@ function Network(g::AbstractGraph, check_graphelement=true, dealias=false, verbose=false) - reset_timer!() + # TimerOutputs.reset_timer!() @timeit_debug "Construct Network" begin # collect all vertex/edgf to vector + all_same_v = vertexf isa VertexFunction + all_same_e = edgef isa EdgeFunction maybecopy = dealias ? copy : identity - _vertexf = vertexf isa Vector ? vertexf : [maybecopy(vertexf) for _ in vertices(g)] - _edgef = edgef isa Vector ? edgef : [maybecopy(edgef) for _ in edges(g)] + _vertexf = all_same_v ? [maybecopy(vertexf) for _ in vertices(g)] : vertexf + _edgef = all_same_e ? [maybecopy(edgef) for _ in edges(g)] : edgef @argcheck _vertexf isa Vector{<:VertexFunction} "Expected VertexFuncions, got $(eltype(_vertexf))" @argcheck _edgef isa Vector{<:EdgeFunction} "Expected EdgeFuncions, got $(eltype(_vertexf))" @@ -22,10 +24,10 @@ function Network(g::AbstractGraph, # check if components alias eachother copy if necessary # allready dealiase if provided as single functions - if dealias && vertexf isa Vector + if dealias && !all_same_v dealias!(_vertexf) end - if dealias && edgef isa Vector + if dealias && !all_same_e dealias!(_edgef) end @@ -52,27 +54,45 @@ function Network(g::AbstractGraph, end # create index manager - im = IndexManager(g, dynstates, edepth, vdepth, _vertexf, _edgef; mightalias=!dealias) + @timeit_debug "Construct Index manager" begin + valias = dealias ? :none : (all_same_v ? :all : :some) + ealias = dealias ? :none : (all_same_e ? :all : :some) + im = IndexManager(g, dynstates, edepth, vdepth, _vertexf, _edgef; valias, ealias) + end # check graph_element metadata and attach if necessary - for (i, vf) in pairs(_vertexf) - if check_graphelement && has_graphelement(vf) - if get_graphelement(vf) != i - @warn "Vertex function $(vf.name) is placed at node index $i bus has \ - `graphelement` $(get_graphelement(vf)) stored in metadata. \ - The wrong data will be ignored! Use `check_graphelement=false` tu supress this warning." + if check_graphelement + @timeit_debug "Check graph element" begin + if !all_same_v + for (i, vf) in pairs(_vertexf) + if has_graphelement(vf) + if get_graphelement(vf) != i + @warn "Vertex function $(vf.name) is placed at node index $i bus has \ + `graphelement` $(get_graphelement(vf)) stored in metadata. \ + The wrong data will be ignored! Use `check_graphelement=false` tu supress this warning." + end + end + end + elseif has_graphelement(vertexf) + @warn "Provided vertex function has assigned `graphelement` metadata. \ + but is used at every vertex. The `graphelement` will be ignored." end - end - end - for (iteredge, ef) in zip(im.edgevec, _edgef) - if check_graphelement && has_graphelement(ef) - ge = get_graphelement(ef) - src = get(im.unique_vnames, ge.src, ge.src) - dst = get(im.unique_vnames, ge.dst, ge.dst) - if iteredge.src != src || iteredge.dst != dst - @warn "Edge function $(ef.name) at $(iteredge.src) => $(iteredge.dst) has wrong `:graphelement` $src => $dst). \ - The wrong data will be ignored! Use `check_graphelement=false` tu supress this warning." + if !all_same_e + for (iteredge, ef) in zip(im.edgevec, _edgef) + if has_graphelement(ef) + ge = get_graphelement(ef) + src = get(im.unique_vnames, ge.src, ge.src) + dst = get(im.unique_vnames, ge.dst, ge.dst) + if iteredge.src != src || iteredge.dst != dst + @warn "Edge function $(ef.name) at $(iteredge.src) => $(iteredge.dst) has wrong `:graphelement` $src => $dst). \ + The wrong data will be ignored! Use `check_graphelement=false` tu supress this warning." + end + end + end + elseif has_graphelement(edgef) + @warn "Provided edge function has assigned `graphelement` metadata. \ + but is used for all edges. The `graphelement` will be ignored." end end end @@ -140,7 +160,7 @@ function Network(g::AbstractGraph, ) end - # print_timer() + # TimerOutputs.print_timer() return nw end diff --git a/src/network_structure.jl b/src/network_structure.jl index 2db7e5dd..f5e310d1 100644 --- a/src/network_structure.jl +++ b/src/network_structure.jl @@ -26,9 +26,9 @@ mutable struct IndexManager{G} aliased_edgefs::IdDict{EdgeFunction, @NamedTuple{idxs::Vector{Int}, hash::UInt}} unique_vnames::Dict{Symbol,Int} unique_enames::Dict{Symbol,Int} - function IndexManager(g, dyn_states, edepth, vdepth, vertexf, edgef; mightalias) - aliased_vertexf_hashes = _aliased_hashes(VertexFunction, vertexf, mightalias) - aliased_edgef_hashes = _aliased_hashes(EdgeFunction, edgef, mightalias) + function IndexManager(g, dyn_states, edepth, vdepth, vertexf, edgef; valias, ealias) + aliased_vertexf_hashes = _aliased_hashes(VertexFunction, vertexf, valias) + aliased_edgef_hashes = _aliased_hashes(EdgeFunction, edgef, ealias) unique_vnames = unique_mappings(getproperty.(vertexf, :name), 1:nv(g)) unique_enames = unique_mappings(getproperty.(edgef, :name), 1:ne(g)) new{typeof(g)}(g, collect(edges(g)), @@ -43,14 +43,17 @@ mutable struct IndexManager{G} unique_enames) end end -function _aliased_hashes(T, cfs, mightalias) +function _aliased_hashes(T, cfs, aliastype) hashdict = IdDict{T, @NamedTuple{idxs::Vector{Int}, hash::UInt}}() - if mightalias + if aliastype == :some ag = aliasgroups(cfs) for (c, idxs) in ag h = hash(c) hashdict[c] = (; idxs=idxs, hash=h) end + elseif aliastype == :all + c = first(cfs) + hashdict[c] =(; idxs=collect(eachindex(cfs)), hash=hash(c)) end hashdict end