Skip to content

Commit

Permalink
Merge pull request #164 from JuliaDynamics/hw/metadata
Browse files Browse the repository at this point in the history
metadata overhaul
  • Loading branch information
hexaeder authored Oct 23, 2024
2 parents 3e99370 + 597d32a commit ae9fb7a
Show file tree
Hide file tree
Showing 14 changed files with 668 additions and 327 deletions.
2 changes: 1 addition & 1 deletion benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Pkg

using BenchmarkTools
using Chairmarks
using Graphs
using NetworkDynamics
using Serialization
Expand Down
16 changes: 9 additions & 7 deletions src/NetworkDynamics.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module NetworkDynamics
using Graphs: Graphs, AbstractGraph, SimpleEdge, edges, vertices, ne, nv,
SimpleGraph, SimpleDiGraph, add_edge!, has_edge
using TimerOutputs: @timeit_debug, reset_timer!
using TimerOutputs: TimerOutputs, @timeit_debug, @timeit

using ArgCheck: @argcheck
using PreallocationTools: PreallocationTools, DiffCache, get_tmp
Expand Down Expand Up @@ -39,12 +39,6 @@ export ODEVertex, StaticVertex, StaticEdge, ODEEdge
export Symmetric, AntiSymmetric, Directed, Fiducial
export dim, sym, pdim, psym, obssym, depth, hasinputsym, inputsym, coupling
export metadata, symmetadata
export has_metadata, get_metadata, set_metadata!
export has_default, get_default, set_default!
export has_guess, get_guess, set_guess!
export has_init, get_init, set_init!
export has_bounds, get_bounds, set_bounds!
export has_graphelement, get_graphelement, set_graphelement!
include("component_functions.jl")

export Network
Expand All @@ -64,6 +58,14 @@ export vidxs, eidxs, vpidxs, epidxs
export save_parameters!
include("symbolicindexing.jl")

export has_metadata, get_metadata, set_metadata!
export has_default, get_default, set_default!
export has_guess, get_guess, set_guess!
export has_init, get_init, set_init!
export has_bounds, get_bounds, set_bounds!
export has_graphelement, get_graphelement, set_graphelement!
include("metadata.jl")

using NonlinearSolve: AbstractNonlinearSolveAlgorithm, NonlinearFunction
using NonlinearSolve: NonlinearLeastSquaresProblem, NonlinearProblem
using SteadyStateDiffEq: SteadyStateProblem, SteadyStateDiffEqAlgorithm, SSRootfind
Expand Down
8 changes: 8 additions & 0 deletions src/aggregators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,14 @@ function aggregate!(a::SparseAggregator, aggbuf, data)
nothing
end

# functions to retrieve the constructor of an aggregator for remake of network
get_aggr_constructor(a::NaiveAggregator) = NaiveAggregator(a.f)
get_aggr_constructor(a::KAAggregator) = KAAggregator(a.f)
get_aggr_constructor(a::SequentialAggregator) = SequentialAggregator(a.f)
get_aggr_constructor(a::PolyesterAggregator) = PolyesterAggregator(a.f)
get_aggr_constructor(a::ThreadedAggregator) = ThreadedAggregator(a.f)
get_aggr_constructor(a::SparseAggregator) = SparseAggregator(+)

iscudacompatible(::Type{<:Aggregator}) = false
iscudacompatible(::Type{<:KAAggregator}) = true
iscudacompatible(::Type{<:SparseAggregator}) = true
253 changes: 35 additions & 218 deletions src/component_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ ODEVertex(; kwargs...) = _construct_comp(ODEVertex, kwargs)
ODEVertex(f; kwargs...) = ODEVertex(;f, kwargs...)
ODEVertex(f, dim; kwargs...) = ODEVertex(;f, _dimsym(dim)..., kwargs...)
ODEVertex(f, dim, pdim; kwargs...) = ODEVertex(;f, _dimsym(dim, pdim)..., kwargs...)
ODEVertex(v::ODEVertex; kwargs...) = _reconstruct_comp(ODEVertex, v, kwargs)

struct StaticVertex{F,OF} <: VertexFunction
@CommonFields
Expand All @@ -198,6 +199,7 @@ StaticVertex(; kwargs...) = _construct_comp(StaticVertex, kwargs)
StaticVertex(f; kwargs...) = StaticVertex(;f, kwargs...)
StaticVertex(f, dim; kwargs...) = StaticVertex(;f, _dimsym(dim)..., kwargs...)
StaticVertex(f, dim, pdim; kwargs...) = StaticVertex(;f, _dimsym(dim, pdim)..., kwargs...)
StaticVertex(v::StaticVertex; kwargs...) = _reconstruct_comp(StaticVertex, v, kwargs)
function ODEVertex(sv::StaticVertex)
d = Dict{Symbol,Any}()
for prop in propertynames(sv)
Expand All @@ -224,6 +226,7 @@ StaticEdge(; kwargs...) = _construct_comp(StaticEdge, kwargs)
StaticEdge(f; kwargs...) = StaticEdge(;f, kwargs...)
StaticEdge(f, dim, coupling; kwargs...) = StaticEdge(;f, _dimsym(dim)..., coupling, kwargs...)
StaticEdge(f, dim, pdim, coupling; kwargs...) = StaticEdge(;f, _dimsym(dim, pdim)..., coupling, kwargs...)
StaticEdge(e::StaticEdge; kwargs...) = _reconstruct_comp(StaticEdge, e, kwargs)

struct ODEEdge{C,F,OF,MM} <: EdgeFunction{C}
@CommonFields
Expand All @@ -234,6 +237,7 @@ ODEEdge(; kwargs...) = _construct_comp(ODEEdge, kwargs)
ODEEdge(f; kwargs...) = ODEEdge(;f, kwargs...)
ODEEdge(f, dim, coupling; kwargs...) = ODEEdge(;f, _dimsym(dim)..., coupling, kwargs...)
ODEEdge(f, dim, pdim, coupling; kwargs...) = ODEEdge(;f, _dimsym(dim, pdim)..., coupling, kwargs...)
ODEEdge(e::ODEEdge; kwargs...) = _reconstruct_comp(ODEEdge, e, kwargs)

statetype(::T) where {T<:ComponentFunction} = statetype(T)
statetype(::Type{<:ODEVertex}) = Dynamic()
Expand Down Expand Up @@ -313,6 +317,18 @@ function _construct_comp(::Type{T}, kwargs) where {T}
return c
end

function _reconstruct_comp(::Type{T}, cf::ComponentFunction, kwargs) where {T}
fields = fieldnames(T)
dict = Dict{Symbol, Any}()
for f in fields
dict[f] = getproperty(cf, f)
end
for (k, v) in kwargs
dict[k] = v
end
_construct_comp(T, dict)
end

"""
_fill_defaults(T, kwargs)
Expand Down Expand Up @@ -585,229 +601,30 @@ _valid_signature(::Type{<:ODEEdge}, f) = _takes_n_vectors(f, 5) #(du, u, src, ds

_takes_n_vectors(f, n) = hasmethod(f, (Tuple(Vector{Float64} for i in 1:n)..., Float64))


####
#### per sym metadata
####
"""
has_metadata(c::ComponentFunction, sym::Symbol, key::Symbol)
Checks if symbol metadata `key` is present for symbol `sym`.
"""
function has_metadata(c::ComponentFunction, sym::Symbol, key::Symbol)
md = symmetadata(c)
haskey(md, sym) && haskey(md[sym], key)
end
"""
get_metadata(c::ComponentFunction, sym::Symbol, key::Symbol)
Retrievs the metadata `key` for symbol `sym`.
"""
get_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) = symmetadata(c)[sym][key]

"""
set_metadata!(c::ComponentFunction, sym::Symbol, key::Symbol, value)
set_metadata!(c::ComponentFunction, sym::Symbol, pair)
Sets the metadata `key` for symbol `sym` to `value`.
"""
function set_metadata!(c::ComponentFunction, sym::Symbol, key::Symbol, value)
d = get!(symmetadata(c), sym, Dict{Symbol,Any}())
d[key] = value
end
set_metadata!(c::ComponentFunction, sym::Symbol, pair::Pair) = set_metadata!(c, sym, pair.first, pair.second)

#### default
"""
has_default(c::ComponentFunction, sym::Symbol)
Checks if a `default` value is present for symbol `sym`.
"""
has_default(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :default)
"""
get_default(c::ComponentFunction, sym::Symbol)
Returns the `default` value for symbol `sym`.
"""
get_default(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :default)
"""
set_default!(c::ComponentFunction, sym::Symbol, value)
Sets the `default` value for symbol `sym` to `value`.
"""
set_default!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :default, value)

#### guess
"""
has_guess(c::ComponentFunction, sym::Symbol)
Checks if a `guess` value is present for symbol `sym`.
"""
has_guess(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :guess)
"""
get_guess(c::ComponentFunction, sym::Symbol)
Returns the `guess` value for symbol `sym`.
"""
get_guess(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :guess)
"""
set_guess!(c::ComponentFunction, sym::Symbol, value)
Sets the `guess` value for symbol `sym` to `value`.
"""
set_guess!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :guess, value)

#### init
"""
has_init(c::ComponentFunction, sym::Symbol)
Checks if a `init` value is present for symbol `sym`.
"""
has_init(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :init)
"""
get_init(c::ComponentFunction, sym::Symbol)
Returns the `init` value for symbol `sym`.
"""
get_init(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :init)
"""
set_init!(c::ComponentFunction, sym::Symbol, value)
Sets the `init` value for symbol `sym` to `value`.
"""
set_init!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :init, value)

#### bounds
"""
has_bounds(c::ComponentFunction, sym::Symbol)
Checks if a `bounds` value is present for symbol `sym`.
"""
has_bounds(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :bounds)
"""
get_bounds(c::ComponentFunction, sym::Symbol)
Returns the `bounds` value for symbol `sym`.
"""
get_bounds(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :bounds)
"""
set_bounds!(c::ComponentFunction, sym::Symbol, value)
Sets the `bounds` value for symbol `sym` to `value`.
"""
set_bounds!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :bounds, value)

copy(c::NetworkDynamics.ComponentFunction)
#### default or init
Shallow copy of the component function. Creates a deepcopy of `metadata` and `symmetadata`
but references the same objects everywhere else.
"""
has_default_or_init(c::ComponentFunction, sym::Symbol)
@generated function Base.copy(c::ComponentFunction)
fields = fieldnames(c)
# fields to copy
cfields = (:metadata, :symmetadata)
# normal fields
nfields = setdiff(fields, cfields)
assign = Expr(:block,
(:($(field) = c.$field) for field in nfields)...,
(:($(field) = deepcopy(c.$field)) for field in cfields)...)
construct = Expr(:call, c, [:($field) for field in fields]...)

Checks if a `default` or `init` value is present for symbol `sym`.
"""
has_default_or_init(c::ComponentFunction, sym::Symbol) = has_default(c, sym) || has_init(c, sym)
"""
get_default_or_init(c::ComponentFunction, sym::Symbol)
Returns if a `default` value if available, otherwise returns `init` value for symbol `sym`.
"""
get_default_or_init(c::ComponentFunction, sym::Symbol) = has_default(c, sym) ? get_default(c, sym) : get_init(c, sym)

#### default or guess
"""
has_default_or_guess(c::ComponentFunction, sym::Symbol)
Checks if a `default` or `guess` value is present for symbol `sym`.
"""
has_default_or_guess(c::ComponentFunction, sym::Symbol) = has_default(c, sym) || has_guess(c, sym)
"""
get_default_or_guess(c::ComponentFunction, sym::Symbol)
Returns if a `default` value if available, otherwise returns `guess` value for symbol `sym`.
"""
get_default_or_guess(c::ComponentFunction, sym::Symbol) = has_default(c, sym) ? get_default(c, sym) : get_guess(c, sym)


# TODO: legacy, only used within show methods
function def(c::ComponentFunction)::Vector{Union{Nothing,Float64}}
map(c.sym) do s
has_default_or_init(c, s) ? get_default_or_init(c, s) : nothing
end
end
function guess(c::ComponentFunction)::Vector{Union{Nothing,Float64}}
map(c.sym) do s
has_guess(c, s) ? get_guess(c, s) : nothing
end
end
function pdef(c::ComponentFunction)::Vector{Union{Nothing,Float64}}
map(c.psym) do s
has_default_or_init(c, s) ? get_default_or_init(c, s) : nothing
end
end
function pguess(c::ComponentFunction)::Vector{Union{Nothing,Float64}}
map(c.psym) do s
has_guess(c, s) ? get_guess(c, s) : nothing
quote
$assign
$construct
end
end

####
#### Component metadata
####
"""
has_metadata(c::ComponentFunction, key::Symbol)
Checks if metadata `key` is present for the component.
"""
function has_metadata(c::ComponentFunction, key)
haskey(metadata(c), key)
end
"""
get_metadata(c::ComponentFunction, key::Symbol)
Retrieves the metadata `key` for the component.
"""
get_metadata(c::ComponentFunction, key::Symbol) = metadata(c)[key]
"""
set_metadata!(c::ComponentFunction, key::Symbol, value)
Sets the metadata `key` for the component to `value`.
"""
set_metadata!(c::ComponentFunction, key::Symbol, val) = setindex!(metadata(c), val, key)

#### graphelement field for edges and vertices
"""
has_graphelement(c)
Checks if the edge or vetex function function has the `graphelement` metadata.
"""
has_graphelement(c::ComponentFunction) = has_metadata(c, :graphelement)
"""
get_graphelement(c::EdgeFunction)::@NamedTuple{src::T, dst::T}
get_graphelement(c::VertexFunction)::Int
Retrieves the `graphelement` metadata for the component function. For edges this
returns a named tupe `(;src, dst)` where both are either integers (vertex index)
or symbols (vertex name).
"""
get_graphelement(c::EdgeFunction) = get_metadata(c, :graphelement)::@NamedTuple{src::T, dst::T} where {T<:Union{Int,Symbol}}
get_graphelement(c::VertexFunction) = get_metadata(c, :graphelement)::Int
"""
set_graphelement!(c::EdgeFunction, src, dst)
set_graphelement!(c::VertexFunction, vidx)
Sets the `graphelement` metadata for the edge function. For edges this takes two
arguments `src` and `dst` which are either integer (vertex index) or symbol
(vertex name). For vertices it takes a single integer `vidx`.
"""
set_graphelement!(c::EdgeFunction, nt::@NamedTuple{src::T, dst::T}) where {T<:Union{Int,Symbol}} = set_metadata!(c, :graphelement, nt)
set_graphelement!(c::VertexFunction, vidx::Int) = set_metadata!(c, :graphelement, vidx)


function get_defaults(c::ComponentFunction, syms)
[has_default(c, sym) ? get_default(c, sym) : nothing for sym in syms]
end
function get_guesses(c::ComponentFunction, syms)
[has_guess(c, sym) ? get_guess(c, sym) : nothing for sym in syms]
end
function get_defaults_or_inits(c::ComponentFunction, syms)
[has_default_or_init(c, sym) ? get_default_or_init(c, sym) : nothing for sym in syms]
Base.hash(cf::ComponentFunction, h::UInt) = hash_fields(cf, h)
function Base.:(==)(cf1::ComponentFunction, cf2::ComponentFunction)
typeof(cf1) == typeof(cf2) && equal_fields(cf1, cf2)
end
Loading

0 comments on commit ae9fb7a

Please sign in to comment.