Skip to content

Commit

Permalink
fix construction with partial graphelement information
Browse files Browse the repository at this point in the history
  • Loading branch information
hexaeder committed Oct 7, 2024
1 parent 0e41b2f commit 8fcd42d
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 15 deletions.
3 changes: 1 addition & 2 deletions src/component_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -760,8 +760,7 @@ set_metadata!(c::ComponentFunction, key::Symbol, val) = setindex!(metadata(c), v
Checks if the edge or vetex function function has the `graphelement` metadata.
"""
has_graphelement(c::EdgeFunction) = has_metadata(c, :graphelement)
has_graphelement(c::VertexFunction) = has_metadata(c, :graphelement)
has_graphelement(c::ComponentFunction) = has_metadata(c, :graphelement)
"""
get_graphelement(c::EdgeFunction)::@NamedTuple{src::T, dst::T}
get_graphelement(c::VertexFunction)::Int
Expand Down
3 changes: 2 additions & 1 deletion src/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ end
function _unique_name_dict(cfs::AbstractVector{<:ComponentFunction})
# find all names to resolve
names = getproperty.(cfs, :name)
dict = Dict(names .=> get_graphelement.(cfs))
dict = Dict(cf.name => get_graphelement(cf) for cf in cfs if has_graphelement(cf))
# delete all names which occure multiple times
for i in eachindex(names)
if names[i] @views names[i+1:end]
Expand Down Expand Up @@ -344,6 +344,7 @@ end
Network(nw::Network; kwargs...)
Rebuild the Network with same graph and vertex/edge functions but possibly different kwargs.
# FIXME : needs to take all Network kw arguments into acount!
"""
function Network(nw::Network; kwargs...)
Network(nw.im.g, nw.im.vertexf, nw.im.edgef; kwargs...)
Expand Down
28 changes: 17 additions & 11 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,25 @@ in the metadata field `init`.
The `kwargs` are passed to the nonlinear solver.
"""
function initialize_component!(cf; verbose=true, kwargs...)
prob = initialization_problem(cf)
sol = SciMLBase.solve(prob; kwargs...)

if sol.prob isa NonlinearLeastSquaresProblem && sol.retcode == SciMLBase.ReturnCode.Stalled
# https://github.com/SciML/NonlinearSolve.jl/issues/459
res = LinearAlgebra.norm(sol.resid)
@warn "Initialization for componend stalled with residual $(res)"
elseif !SciMLBase.successful_retcode(sol.retcode)
throw(ArgumentError("Initialization failed. Solver returned $(sol.retcode)"))
prob = initialization_problem(cf; verbose)

if !isempty(prob.u0)
sol = SciMLBase.solve(prob; kwargs...)

if sol.prob isa NonlinearLeastSquaresProblem && sol.retcode == SciMLBase.ReturnCode.Stalled
# https://github.com/SciML/NonlinearSolve.jl/issues/459
res = LinearAlgebra.norm(sol.resid)
@warn "Initialization for componend stalled with residual $(res)"
elseif !SciMLBase.successful_retcode(sol.retcode)
throw(ArgumentError("Initialization failed. Solver returned $(sol.retcode)"))
end
set_init!.(Ref(cf), SII.variable_symbols(sol), sol.u)
resid = sol.resid
else
resid = init_residual(cf; recalc=true)
end
set_init!.(Ref(cf), SII.variable_symbols(sol), sol.u)

set_metadata!(cf, :init_residual, sol.resid)
set_metadata!(cf, :init_residual, resid)

verbose && @info "Initialization successful with residual $(LinearAlgebra.norm(sol.resid))"
cf
Expand Down
11 changes: 10 additions & 1 deletion test/construction_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,17 @@ end

set_graphelement!(e3, (;src=2,dst=1))
Network([v1,v2,v3], [e1,e2,e3]) # throws waring about 1->2 and 2->1 beeing present
end

v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v1)
v2 = ODEVertex(x->x^2, 2, 0; name=:v2)
v3 = ODEVertex(x->x^3, 2, 0; name=:v3)
@test NetworkDynamics._unique_name_dict([v1,v2,v3]) == Dict(:v1=>1)

v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v1)
v2 = ODEVertex(x->x^2, 2, 0; name=:v1)
v3 = ODEVertex(x->x^3, 2, 0; name=:v3)
@test NetworkDynamics._unique_name_dict([v1,v2,v3]) == Dict()
end
@testset "Vertex batch" begin
using NetworkDynamics: BatchStride, VertexBatch, parameter_range
vb = VertexBatch{ODEVertex, typeof(sum), Vector{Int}}([1, 2, 3, 4], # vertices
Expand Down

0 comments on commit 8fcd42d

Please sign in to comment.