Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

metadata overhaul #164

Merged
merged 17 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Pkg

using BenchmarkTools
using Chairmarks
using Graphs
using NetworkDynamics
using Serialization
Expand Down
16 changes: 9 additions & 7 deletions src/NetworkDynamics.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module NetworkDynamics
using Graphs: Graphs, AbstractGraph, SimpleEdge, edges, vertices, ne, nv,
SimpleGraph, SimpleDiGraph, add_edge!, has_edge
using TimerOutputs: @timeit_debug, reset_timer!
using TimerOutputs: TimerOutputs, @timeit_debug, @timeit

using ArgCheck: @argcheck
using PreallocationTools: PreallocationTools, DiffCache, get_tmp
Expand Down Expand Up @@ -39,12 +39,6 @@ export ODEVertex, StaticVertex, StaticEdge, ODEEdge
export Symmetric, AntiSymmetric, Directed, Fiducial
export dim, sym, pdim, psym, obssym, depth, hasinputsym, inputsym, coupling
export metadata, symmetadata
export has_metadata, get_metadata, set_metadata!
export has_default, get_default, set_default!
export has_guess, get_guess, set_guess!
export has_init, get_init, set_init!
export has_bounds, get_bounds, set_bounds!
export has_graphelement, get_graphelement, set_graphelement!
include("component_functions.jl")

export Network
Expand All @@ -64,6 +58,14 @@ export vidxs, eidxs, vpidxs, epidxs
export save_parameters!
include("symbolicindexing.jl")

export has_metadata, get_metadata, set_metadata!
export has_default, get_default, set_default!
export has_guess, get_guess, set_guess!
export has_init, get_init, set_init!
export has_bounds, get_bounds, set_bounds!
export has_graphelement, get_graphelement, set_graphelement!
include("metadata.jl")

using NonlinearSolve: AbstractNonlinearSolveAlgorithm, NonlinearFunction
using NonlinearSolve: NonlinearLeastSquaresProblem, NonlinearProblem
using SteadyStateDiffEq: SteadyStateProblem, SteadyStateDiffEqAlgorithm, SSRootfind
Expand Down
8 changes: 8 additions & 0 deletions src/aggregators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,14 @@ function aggregate!(a::SparseAggregator, aggbuf, data)
nothing
end

# functions to retrieve the constructor of an aggregator for remake of network
get_aggr_constructor(a::NaiveAggregator) = NaiveAggregator(a.f)
get_aggr_constructor(a::KAAggregator) = KAAggregator(a.f)
get_aggr_constructor(a::SequentialAggregator) = SequentialAggregator(a.f)
get_aggr_constructor(a::PolyesterAggregator) = PolyesterAggregator(a.f)
get_aggr_constructor(a::ThreadedAggregator) = ThreadedAggregator(a.f)
get_aggr_constructor(a::SparseAggregator) = SparseAggregator(+)

iscudacompatible(::Type{<:Aggregator}) = false
iscudacompatible(::Type{<:KAAggregator}) = true
iscudacompatible(::Type{<:SparseAggregator}) = true
253 changes: 35 additions & 218 deletions src/component_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ ODEVertex(; kwargs...) = _construct_comp(ODEVertex, kwargs)
ODEVertex(f; kwargs...) = ODEVertex(;f, kwargs...)
ODEVertex(f, dim; kwargs...) = ODEVertex(;f, _dimsym(dim)..., kwargs...)
ODEVertex(f, dim, pdim; kwargs...) = ODEVertex(;f, _dimsym(dim, pdim)..., kwargs...)
ODEVertex(v::ODEVertex; kwargs...) = _reconstruct_comp(ODEVertex, v, kwargs)

struct StaticVertex{F,OF} <: VertexFunction
@CommonFields
Expand All @@ -198,6 +199,7 @@ StaticVertex(; kwargs...) = _construct_comp(StaticVertex, kwargs)
StaticVertex(f; kwargs...) = StaticVertex(;f, kwargs...)
StaticVertex(f, dim; kwargs...) = StaticVertex(;f, _dimsym(dim)..., kwargs...)
StaticVertex(f, dim, pdim; kwargs...) = StaticVertex(;f, _dimsym(dim, pdim)..., kwargs...)
StaticVertex(v::StaticVertex; kwargs...) = _reconstruct_comp(StaticVertex, v, kwargs)
function ODEVertex(sv::StaticVertex)
d = Dict{Symbol,Any}()
for prop in propertynames(sv)
Expand All @@ -224,6 +226,7 @@ StaticEdge(; kwargs...) = _construct_comp(StaticEdge, kwargs)
StaticEdge(f; kwargs...) = StaticEdge(;f, kwargs...)
StaticEdge(f, dim, coupling; kwargs...) = StaticEdge(;f, _dimsym(dim)..., coupling, kwargs...)
StaticEdge(f, dim, pdim, coupling; kwargs...) = StaticEdge(;f, _dimsym(dim, pdim)..., coupling, kwargs...)
StaticEdge(e::StaticEdge; kwargs...) = _reconstruct_comp(StaticEdge, e, kwargs)

struct ODEEdge{C,F,OF,MM} <: EdgeFunction{C}
@CommonFields
Expand All @@ -234,6 +237,7 @@ ODEEdge(; kwargs...) = _construct_comp(ODEEdge, kwargs)
ODEEdge(f; kwargs...) = ODEEdge(;f, kwargs...)
ODEEdge(f, dim, coupling; kwargs...) = ODEEdge(;f, _dimsym(dim)..., coupling, kwargs...)
ODEEdge(f, dim, pdim, coupling; kwargs...) = ODEEdge(;f, _dimsym(dim, pdim)..., coupling, kwargs...)
ODEEdge(e::ODEEdge; kwargs...) = _reconstruct_comp(ODEEdge, e, kwargs)

statetype(::T) where {T<:ComponentFunction} = statetype(T)
statetype(::Type{<:ODEVertex}) = Dynamic()
Expand Down Expand Up @@ -313,6 +317,18 @@ function _construct_comp(::Type{T}, kwargs) where {T}
return c
end

function _reconstruct_comp(::Type{T}, cf::ComponentFunction, kwargs) where {T}
fields = fieldnames(T)
dict = Dict{Symbol, Any}()
for f in fields
dict[f] = getproperty(cf, f)
end
for (k, v) in kwargs
dict[k] = v
end
_construct_comp(T, dict)
end

"""
_fill_defaults(T, kwargs)

Expand Down Expand Up @@ -585,229 +601,30 @@ _valid_signature(::Type{<:ODEEdge}, f) = _takes_n_vectors(f, 5) #(du, u, src, ds

_takes_n_vectors(f, n) = hasmethod(f, (Tuple(Vector{Float64} for i in 1:n)..., Float64))


####
#### per sym metadata
####
"""
has_metadata(c::ComponentFunction, sym::Symbol, key::Symbol)

Checks if symbol metadata `key` is present for symbol `sym`.
"""
function has_metadata(c::ComponentFunction, sym::Symbol, key::Symbol)
md = symmetadata(c)
haskey(md, sym) && haskey(md[sym], key)
end
"""
get_metadata(c::ComponentFunction, sym::Symbol, key::Symbol)

Retrievs the metadata `key` for symbol `sym`.
"""
get_metadata(c::ComponentFunction, sym::Symbol, key::Symbol) = symmetadata(c)[sym][key]

"""
set_metadata!(c::ComponentFunction, sym::Symbol, key::Symbol, value)
set_metadata!(c::ComponentFunction, sym::Symbol, pair)

Sets the metadata `key` for symbol `sym` to `value`.
"""
function set_metadata!(c::ComponentFunction, sym::Symbol, key::Symbol, value)
d = get!(symmetadata(c), sym, Dict{Symbol,Any}())
d[key] = value
end
set_metadata!(c::ComponentFunction, sym::Symbol, pair::Pair) = set_metadata!(c, sym, pair.first, pair.second)

#### default
"""
has_default(c::ComponentFunction, sym::Symbol)

Checks if a `default` value is present for symbol `sym`.
"""
has_default(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :default)
"""
get_default(c::ComponentFunction, sym::Symbol)

Returns the `default` value for symbol `sym`.
"""
get_default(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :default)
"""
set_default!(c::ComponentFunction, sym::Symbol, value)

Sets the `default` value for symbol `sym` to `value`.
"""
set_default!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :default, value)

#### guess
"""
has_guess(c::ComponentFunction, sym::Symbol)

Checks if a `guess` value is present for symbol `sym`.
"""
has_guess(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :guess)
"""
get_guess(c::ComponentFunction, sym::Symbol)

Returns the `guess` value for symbol `sym`.
"""
get_guess(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :guess)
"""
set_guess!(c::ComponentFunction, sym::Symbol, value)

Sets the `guess` value for symbol `sym` to `value`.
"""
set_guess!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :guess, value)

#### init
"""
has_init(c::ComponentFunction, sym::Symbol)

Checks if a `init` value is present for symbol `sym`.
"""
has_init(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :init)
"""
get_init(c::ComponentFunction, sym::Symbol)

Returns the `init` value for symbol `sym`.
"""
get_init(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :init)
"""
set_init!(c::ComponentFunction, sym::Symbol, value)

Sets the `init` value for symbol `sym` to `value`.
"""
set_init!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :init, value)

#### bounds
"""
has_bounds(c::ComponentFunction, sym::Symbol)

Checks if a `bounds` value is present for symbol `sym`.
"""
has_bounds(c::ComponentFunction, sym::Symbol) = has_metadata(c, sym, :bounds)
"""
get_bounds(c::ComponentFunction, sym::Symbol)

Returns the `bounds` value for symbol `sym`.
"""
get_bounds(c::ComponentFunction, sym::Symbol) = get_metadata(c, sym, :bounds)
"""
set_bounds!(c::ComponentFunction, sym::Symbol, value)

Sets the `bounds` value for symbol `sym` to `value`.
"""
set_bounds!(c::ComponentFunction, sym::Symbol, value) = set_metadata!(c, sym, :bounds, value)

copy(c::NetworkDynamics.ComponentFunction)

#### default or init
Shallow copy of the component function. Creates a deepcopy of `metadata` and `symmetadata`
but references the same objects everywhere else.
"""
has_default_or_init(c::ComponentFunction, sym::Symbol)
@generated function Base.copy(c::ComponentFunction)
fields = fieldnames(c)
# fields to copy
cfields = (:metadata, :symmetadata)
# normal fields
nfields = setdiff(fields, cfields)
assign = Expr(:block,
(:($(field) = c.$field) for field in nfields)...,
(:($(field) = deepcopy(c.$field)) for field in cfields)...)
construct = Expr(:call, c, [:($field) for field in fields]...)

Checks if a `default` or `init` value is present for symbol `sym`.
"""
has_default_or_init(c::ComponentFunction, sym::Symbol) = has_default(c, sym) || has_init(c, sym)
"""
get_default_or_init(c::ComponentFunction, sym::Symbol)

Returns if a `default` value if available, otherwise returns `init` value for symbol `sym`.
"""
get_default_or_init(c::ComponentFunction, sym::Symbol) = has_default(c, sym) ? get_default(c, sym) : get_init(c, sym)

#### default or guess
"""
has_default_or_guess(c::ComponentFunction, sym::Symbol)

Checks if a `default` or `guess` value is present for symbol `sym`.
"""
has_default_or_guess(c::ComponentFunction, sym::Symbol) = has_default(c, sym) || has_guess(c, sym)
"""
get_default_or_guess(c::ComponentFunction, sym::Symbol)

Returns if a `default` value if available, otherwise returns `guess` value for symbol `sym`.
"""
get_default_or_guess(c::ComponentFunction, sym::Symbol) = has_default(c, sym) ? get_default(c, sym) : get_guess(c, sym)


# TODO: legacy, only used within show methods
function def(c::ComponentFunction)::Vector{Union{Nothing,Float64}}
map(c.sym) do s
has_default_or_init(c, s) ? get_default_or_init(c, s) : nothing
end
end
function guess(c::ComponentFunction)::Vector{Union{Nothing,Float64}}
map(c.sym) do s
has_guess(c, s) ? get_guess(c, s) : nothing
end
end
function pdef(c::ComponentFunction)::Vector{Union{Nothing,Float64}}
map(c.psym) do s
has_default_or_init(c, s) ? get_default_or_init(c, s) : nothing
end
end
function pguess(c::ComponentFunction)::Vector{Union{Nothing,Float64}}
map(c.psym) do s
has_guess(c, s) ? get_guess(c, s) : nothing
quote
$assign
$construct
end
end

####
#### Component metadata
####
"""
has_metadata(c::ComponentFunction, key::Symbol)

Checks if metadata `key` is present for the component.
"""
function has_metadata(c::ComponentFunction, key)
haskey(metadata(c), key)
end
"""
get_metadata(c::ComponentFunction, key::Symbol)

Retrieves the metadata `key` for the component.
"""
get_metadata(c::ComponentFunction, key::Symbol) = metadata(c)[key]
"""
set_metadata!(c::ComponentFunction, key::Symbol, value)

Sets the metadata `key` for the component to `value`.
"""
set_metadata!(c::ComponentFunction, key::Symbol, val) = setindex!(metadata(c), val, key)

#### graphelement field for edges and vertices
"""
has_graphelement(c)

Checks if the edge or vetex function function has the `graphelement` metadata.
"""
has_graphelement(c::ComponentFunction) = has_metadata(c, :graphelement)
"""
get_graphelement(c::EdgeFunction)::@NamedTuple{src::T, dst::T}
get_graphelement(c::VertexFunction)::Int

Retrieves the `graphelement` metadata for the component function. For edges this
returns a named tupe `(;src, dst)` where both are either integers (vertex index)
or symbols (vertex name).
"""
get_graphelement(c::EdgeFunction) = get_metadata(c, :graphelement)::@NamedTuple{src::T, dst::T} where {T<:Union{Int,Symbol}}
get_graphelement(c::VertexFunction) = get_metadata(c, :graphelement)::Int
"""
set_graphelement!(c::EdgeFunction, src, dst)
set_graphelement!(c::VertexFunction, vidx)

Sets the `graphelement` metadata for the edge function. For edges this takes two
arguments `src` and `dst` which are either integer (vertex index) or symbol
(vertex name). For vertices it takes a single integer `vidx`.
"""
set_graphelement!(c::EdgeFunction, nt::@NamedTuple{src::T, dst::T}) where {T<:Union{Int,Symbol}} = set_metadata!(c, :graphelement, nt)
set_graphelement!(c::VertexFunction, vidx::Int) = set_metadata!(c, :graphelement, vidx)


function get_defaults(c::ComponentFunction, syms)
[has_default(c, sym) ? get_default(c, sym) : nothing for sym in syms]
end
function get_guesses(c::ComponentFunction, syms)
[has_guess(c, sym) ? get_guess(c, sym) : nothing for sym in syms]
end
function get_defaults_or_inits(c::ComponentFunction, syms)
[has_default_or_init(c, sym) ? get_default_or_init(c, sym) : nothing for sym in syms]
Base.hash(cf::ComponentFunction, h::UInt) = hash_fields(cf, h)
function Base.:(==)(cf1::ComponentFunction, cf2::ComponentFunction)
typeof(cf1) == typeof(cf2) && equal_fields(cf1, cf2)
end
Loading
Loading