Skip to content

Commit

Permalink
Merge pull request #162 from JuliaDynamics/hw/show
Browse files Browse the repository at this point in the history
minor improvmenets
  • Loading branch information
hexaeder authored Oct 17, 2024
2 parents 3858f6d + 22fe159 commit 3e99370
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 27 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand Down Expand Up @@ -57,6 +58,7 @@ NonlinearSolve = "3.13.0"
Polyester = "0.7.12"
PreallocationTools = "0.4.23"
PrecompileTools = "1.2.1"
Printf = "1.10.0"
Random = "1"
RecursiveArrayTools = "3.27.0"
SciMLBase = "2"
Expand Down
53 changes: 46 additions & 7 deletions ext/MTKExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ using ModelingToolkit.Symbolics: Symbolics, fixpoint_sub, substitute
using ArgCheck: @argcheck
using LinearAlgebra: Diagonal, I

using NetworkDynamics: Fiducial
import NetworkDynamics: ODEVertex, StaticEdge
using NetworkDynamics: Coupling, Fiducial, set_metadata!
import NetworkDynamics: ODEVertex, StaticEdge, ODEEdge

include("MTKUtils.jl")

function ODEVertex(sys::ODESystem, inputs, outputs; verbose=false, name=getname(sys))
function ODEVertex(sys::ODESystem, inputs, outputs; verbose=false, name=getname(sys), kwargs...)
warn_events(sys)
inputs = inputs isa AbstractVector ? inputs : [inputs]
outputs = outputs isa AbstractVector ? outputs : [outputs]
Expand All @@ -37,10 +37,47 @@ function ODEVertex(sys::ODESystem, inputs, outputs; verbose=false, name=getname(
obsf = gen.g_ip

mass_matrix = gen.mass_matrix
ODEVertex(;f, sym, psym, depth, inputsym, obssym, obsf, mass_matrix, name)
v = ODEVertex(;f, sym, psym, depth, inputsym, obssym, obsf, mass_matrix, name, kwargs...)
set_metadata!(v, :observed, gen.observed)
set_metadata!(v, :equations, gen.equations)
v
end

function StaticEdge(sys::ODESystem, srcin, dstin, outputs, coupling; verbose=false, name=getname(sys))
function ODEEdge(sys::ODESystem, srcin, dstin, outputs, coupling::Coupling; verbose=false, name=getname(sys), kwargs...)
warn_events(sys)
srcin = srcin isa AbstractVector ? srcin : [srcin]
dstin = dstin isa AbstractVector ? dstin : [dstin]
outputs = outputs isa AbstractVector ? outputs : [outputs]
gen = generate_io_function(sys, (srcin, dstin), outputs; type=:ode, verbose)

f = gen.f_ip

_sym = getname.(gen.states)
sym = [s => _get_metadata(sys, s) for s in _sym]

_psym = getname.(gen.params)
psym = [s => _get_metadata(sys, s) for s in _psym]

_obssym = getname.(gen.obsstates)
obssym = [s => _get_metadata(sys, s) for s in _obssym]

_inputsym_src = getname.(srcin)
inputsym_src = [s => _get_metadata(sys, s) for s in _inputsym_src]

_inputsym_dst = getname.(dstin)
inputsym_dst = [s => _get_metadata(sys, s) for s in _inputsym_dst]

depth = coupling isa Fiducial ? Int(length(outputs)/2) : length(outputs)
obsf = gen.g_ip

mass_matrix = gen.mass_matrix
e = ODEEdge(;f, sym, psym, depth, inputsym_src, inputsym_dst, obssym, obsf, coupling, mass_matrix, name, kwargs...)
set_metadata!(e, :observed, gen.observed)
set_metadata!(e, :equations, gen.equations)
e
end

function StaticEdge(sys::ODESystem, srcin, dstin, outputs, coupling::Coupling; verbose=false, name=getname(sys), kwargs...)
warn_events(sys)
srcin = srcin isa AbstractVector ? srcin : [srcin]
dstin = dstin isa AbstractVector ? dstin : [dstin]
Expand All @@ -67,7 +104,7 @@ function StaticEdge(sys::ODESystem, srcin, dstin, outputs, coupling; verbose=fal
depth = coupling isa Fiducial ? Int(length(outputs)/2) : length(outputs)
obsf = gen.g_ip

StaticEdge(;f, sym, psym, depth, inputsym_src, inputsym_dst, obssym, obsf, coupling, name)
StaticEdge(;f, sym, psym, depth, inputsym_src, inputsym_dst, obssym, obsf, coupling, name, kwargs...)
end

"""
Expand Down Expand Up @@ -151,7 +188,7 @@ function generate_io_function(_sys, inputss::Tuple, outputs;
fix_metadata!(obseqs, sys);
# obs can only depend on parameters (including allinputs) or states
obs_deps = mapreduce(eq -> get_variables(eq.rhs), union, obseqs, init=Symbolic[])
if !(obs_deps Set(_params) Set(_states))
if !(obs_deps Set(_params) Set(_states) independent_variables(sys))
@warn "obs_deps !⊆ parameters ∪ unknowns. Difference: $(setdiff(obs_deps, Set(_params) Set(_states)))"
end

Expand Down Expand Up @@ -272,6 +309,8 @@ function generate_io_function(_sys, inputss::Tuple, outputs;
inputss,
obsstates,
g_oop, g_ip,
equations=formulas,
observed=Dict(obsstates .=> obsformulas),
params)
end

Expand Down
4 changes: 3 additions & 1 deletion src/NetworkDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using StyledStrings: StyledStrings, @styled_str
using RecursiveArrayTools: DiffEqArray
using FastClosures: @closure
using ForwardDiff: ForwardDiff
using Printf: @sprintf

@static if VERSION v"1.11.0-0"
using Base: AnnotatedIOBuffer, AnnotatedString
Expand Down Expand Up @@ -82,7 +83,8 @@ println(s1, "\n", s2)
=#
const ND_FACES = [
:NetworkDynamics_inactive => StyledStrings.Face(foreground=:bright_black),
:NetworkDynamics_defaultval => StyledStrings.Face(foreground=:bright_black),
:NetworkDynamics_defaultval => StyledStrings.Face(weight=:light),
:NetworkDynamics_guessval => StyledStrings.Face(foreground=:bright_black, weight=:light),
:NetworkDynamics_fordstsrc => StyledStrings.Face(foreground=:bright_blue),
:NetworkDynamics_fordst => StyledStrings.Face(foreground=:bright_yellow),
:NetworkDynamics_forsrc => StyledStrings.Face(foreground=:bright_magenta),
Expand Down
30 changes: 30 additions & 0 deletions src/component_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,15 @@ function _fill_defaults(T, kwargs)
ge = pop!(dict, :graphelement)
metadata[:graphelement] = ge
end
if haskey(dict, :vidx) && T <: VertexFunction
vidx = pop!(dict, :vidx)
metadata[:graphelement] = vidx
end
if haskey(dict, :src) && haskey(dict, :dst) && T <: EdgeFunction
src = pop!(dict, :src)
dst = pop!(dict, :dst)
metadata[:graphelement] = (; src, dst)
end

# sym & dim
haskey(dict, :dim) || haskey(dict, :sym) || throw(ArgumentError("Either `dim` or `sym` must be provided to construct $T."))
Expand Down Expand Up @@ -724,11 +733,21 @@ function def(c::ComponentFunction)::Vector{Union{Nothing,Float64}}
has_default_or_init(c, s) ? get_default_or_init(c, s) : nothing
end
end
function guess(c::ComponentFunction)::Vector{Union{Nothing,Float64}}
map(c.sym) do s
has_guess(c, s) ? get_guess(c, s) : nothing
end
end
function pdef(c::ComponentFunction)::Vector{Union{Nothing,Float64}}
map(c.psym) do s
has_default_or_init(c, s) ? get_default_or_init(c, s) : nothing
end
end
function pguess(c::ComponentFunction)::Vector{Union{Nothing,Float64}}
map(c.psym) do s
has_guess(c, s) ? get_guess(c, s) : nothing
end
end

####
#### Component metadata
Expand Down Expand Up @@ -781,3 +800,14 @@ arguments `src` and `dst` which are either integer (vertex index) or symbol
"""
set_graphelement!(c::EdgeFunction, nt::@NamedTuple{src::T, dst::T}) where {T<:Union{Int,Symbol}} = set_metadata!(c, :graphelement, nt)
set_graphelement!(c::VertexFunction, vidx::Int) = set_metadata!(c, :graphelement, vidx)


function get_defaults(c::ComponentFunction, syms)
[has_default(c, sym) ? get_default(c, sym) : nothing for sym in syms]
end
function get_guesses(c::ComponentFunction, syms)
[has_guess(c, sym) ? get_guess(c, sym) : nothing for sym in syms]
end
function get_defaults_or_inits(c::ComponentFunction, syms)
[has_default_or_init(c, sym) ? get_default_or_init(c, sym) : nothing for sym in syms]
end
31 changes: 17 additions & 14 deletions src/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ function Network(g::AbstractGraph,
edepth=:auto,
vdepth=:auto,
aggregator=execution isa SequentialExecution ? SequentialAggregator(+) : PolyesterAggregator(+),
check_graphelement=true,
verbose=false)
reset_timer!()
@timeit_debug "Construct Network" begin
Expand All @@ -17,22 +18,24 @@ function Network(g::AbstractGraph,
@argcheck length(_edgef) == ne(g)

# check if graphelement is set correctly, warn otherwise
for (i, v) in pairs(_vertexf)
if has_graphelement(v)
if get_graphelement(v) != i
@warn "Vertex function $v has wrong `:graphelement` $(get_graphelement(v)) != $i. \
Using this constructor the provided `:graphelement` is ignored!"
if check_graphelement
for (i, v) in pairs(_vertexf)
if has_graphelement(v)
if get_graphelement(v) != i
@warn "Vertex function $v has wrong `:graphelement` $(get_graphelement(v)) != $i. \
Using this constructor the provided `:graphelement` is ignored!"
end
end
end
end
if any(has_graphelement, _edgef)
vnamedict = _unique_name_dict(_vertexf)
for (iteredge, ef) in zip(edges(g), _edgef)
if has_graphelement(ef)
ge = get_graphelement(ef)
if iteredge != _resolve_ge_to_edge(ge, vnamedict)
@warn "Edge function $ef has wrong `:graphelement` $(get_graphelement(ef)) != $iteredge. \
Using this constructor the provided `:graphelement` is ignored!"
if any(has_graphelement, _edgef)
vnamedict = _unique_name_dict(_vertexf)
for (iteredge, ef) in zip(edges(g), _edgef)
if has_graphelement(ef)
ge = get_graphelement(ef)
if iteredge != _resolve_ge_to_edge(ge, vnamedict)
@warn "Edge function $ef has wrong `:graphelement` $(get_graphelement(ef)) != $iteredge. \
Using this constructor the provided `:graphelement` is ignored!"
end
end
end
end
Expand Down
51 changes: 46 additions & 5 deletions src/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ Base.show(io::IO, s::PolyesterAggregator) = print(io, "PolyesterAggregator($(rep
function Base.show(io::IO, ::MIME"text/plain", c::EdgeFunction)
type = match(r"^(.*?)\{", string(typeof(c)))[1]
print(io, type, styled" :$(c.name) with $(_styled_coupling(coupling(c))) coupling of depth {NetworkDynamics_fordst:$(depth(c))}")
if has_graphelement(c)
ge = get_graphelement(c)
print(io, " @ Edge $(ge.src) => $(ge.dst)")
end

styling = Dict{Int,Symbol}()
if coupling(c) == Fiducial()
Expand Down Expand Up @@ -56,6 +60,9 @@ _styled_coupling(::Symmetric) = styled"{NetworkDynamics_fordstsrc:Symmetric}"
function Base.show(io::IO, ::MIME"text/plain", c::VertexFunction)
type = match(r"^(.*?)\{", string(typeof(c)))[1]
print(io, type, styled" :$(c.name) with depth {NetworkDynamics_forlayer:$(depth(c))}")
if has_graphelement(c)
print(io, " @ Vertex $(get_graphelement(c))")
end

styling = Dict{Int,Symbol}()
for i in 1:depth(c)
Expand All @@ -68,7 +75,7 @@ end
function print_states_params(io, c::ComponentFunction, styling)
info = AnnotatedString{String}[]
num, word = maybe_plural(dim(c), "state")
push!(info, styled"$num &$word: &&$(stylesymbolarray(c.sym, def(c), styling))")
push!(info, styled"$num &$word: &&$(stylesymbolarray(c.sym, def(c), guess(c), styling))")

if hasproperty(c, :mass_matrix) && c.mass_matrix != LinearAlgebra.I
if LinearAlgebra.isdiag(c.mass_matrix) && !(c.mass_matrix isa UniformScaling)
Expand All @@ -79,19 +86,40 @@ function print_states_params(io, c::ComponentFunction, styling)
end

num, word = maybe_plural(pdim(c), "param")
pdim(c) > 0 && push!(info, styled"$num &$word: &&$(stylesymbolarray(c.psym, pdef(c)))")
pdim(c) > 0 && push!(info, styled"$num &$word: &&$(stylesymbolarray(c.psym, pdef(c), pguess(c)))")

if !isnothing(c.inputsym)
if c isa VertexFunction
_, word = maybe_plural(length(c.inputsym), "input")
defs = get_defaults_or_inits(c, c.inputsym)
guesses = get_guesses(c, c.inputsym)
push!(info, styled"&$word: &&$(stylesymbolarray(c.inputsym, defs, guesses))")
elseif c isa EdgeFunction
_, word = maybe_plural(length(c.inputsym.src), "input")
srcdefs = get_defaults_or_inits(c, c.inputsym.src)
dstdefs = get_defaults_or_inits(c, c.inputsym.dst)
srcguesses = get_guesses(c, c.inputsym.src)
dstguesses = get_guesses(c, c.inputsym.dst)
push!(info, styled"src&$word: &&$(stylesymbolarray(c.inputsym.src, srcdefs, srcguesses))\n\
dst&$word: &&$(stylesymbolarray(c.inputsym.dst, srcdefs, srcguesses))")
end
end

print_treelike(io, align_strings(info))
end

function stylesymbolarray(syms, defaults, symstyles=Dict{Int,Symbol}())
function stylesymbolarray(syms, defaults, guesses, symstyles=Dict{Int,Symbol}())
@assert length(syms) == length(defaults)
ret = "["
for (i, sym, default) in zip(1:length(syms), syms, defaults)
for (i, sym, default, guess) in zip(1:length(syms), syms, defaults, guesses)
style = get(symstyles, i, :default)
ret = ret * styled"{$style:$(string(sym))}"
if !isnothing(default)
ret = ret * styled"{NetworkDynamics_defaultval:=$default}"
_str = str_significant(default; sigdigits=2)
ret = ret * styled"{NetworkDynamics_defaultval:=$(_str)}"
elseif !isnothing(guess)
_str = str_significant(guess; sigdigits=2)
ret = ret * styled"{NetworkDynamics_guessval:≈$(_str)}"
end
if i < length(syms)
ret = ret * ", "
Expand Down Expand Up @@ -291,3 +319,16 @@ function maybe_plural(num, word, substitution=s"\1s")
end
num, word
end

function str_significant(x; sigdigits)
(x == 0) && (return "0")
x = round(x; sigdigits)
n = length(@sprintf("%d", abs(x))) # length of the integer part
if (x -1 || x 1)
decimals = max(sigdigits - n, 0) # 'sig - n' decimals needed
else
Nzeros = ceil(Int, -log10(abs(x))) - 1 # No. zeros after decimal point before first number
decimals = sigdigits + Nzeros
end
return @sprintf("%.*f", decimals, x)
end

0 comments on commit 3e99370

Please sign in to comment.