From 8fcd42d475844b7319d16fbb3bfebc104aa005af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Mon, 7 Oct 2024 08:19:19 +0200 Subject: [PATCH] fix construction with partial graphelement information --- src/component_functions.jl | 3 +-- src/construction.jl | 3 ++- src/initialization.jl | 28 +++++++++++++++++----------- test/construction_test.jl | 11 ++++++++++- 4 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/component_functions.jl b/src/component_functions.jl index d13ee51f..134cad71 100644 --- a/src/component_functions.jl +++ b/src/component_functions.jl @@ -760,8 +760,7 @@ set_metadata!(c::ComponentFunction, key::Symbol, val) = setindex!(metadata(c), v Checks if the edge or vetex function function has the `graphelement` metadata. """ -has_graphelement(c::EdgeFunction) = has_metadata(c, :graphelement) -has_graphelement(c::VertexFunction) = has_metadata(c, :graphelement) +has_graphelement(c::ComponentFunction) = has_metadata(c, :graphelement) """ get_graphelement(c::EdgeFunction)::@NamedTuple{src::T, dst::T} get_graphelement(c::VertexFunction)::Int diff --git a/src/construction.jl b/src/construction.jl index fc2a0cb8..38287efd 100644 --- a/src/construction.jl +++ b/src/construction.jl @@ -160,7 +160,7 @@ end function _unique_name_dict(cfs::AbstractVector{<:ComponentFunction}) # find all names to resolve names = getproperty.(cfs, :name) - dict = Dict(names .=> get_graphelement.(cfs)) + dict = Dict(cf.name => get_graphelement(cf) for cf in cfs if has_graphelement(cf)) # delete all names which occure multiple times for i in eachindex(names) if names[i] ∈ @views names[i+1:end] @@ -344,6 +344,7 @@ end Network(nw::Network; kwargs...) Rebuild the Network with same graph and vertex/edge functions but possibly different kwargs. +# FIXME : needs to take all Network kw arguments into acount! """ function Network(nw::Network; kwargs...) Network(nw.im.g, nw.im.vertexf, nw.im.edgef; kwargs...) diff --git a/src/initialization.jl b/src/initialization.jl index 3a70cc48..718c1f67 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -148,19 +148,25 @@ in the metadata field `init`. The `kwargs` are passed to the nonlinear solver. """ function initialize_component!(cf; verbose=true, kwargs...) - prob = initialization_problem(cf) - sol = SciMLBase.solve(prob; kwargs...) - - if sol.prob isa NonlinearLeastSquaresProblem && sol.retcode == SciMLBase.ReturnCode.Stalled - # https://github.com/SciML/NonlinearSolve.jl/issues/459 - res = LinearAlgebra.norm(sol.resid) - @warn "Initialization for componend stalled with residual $(res)" - elseif !SciMLBase.successful_retcode(sol.retcode) - throw(ArgumentError("Initialization failed. Solver returned $(sol.retcode)")) + prob = initialization_problem(cf; verbose) + + if !isempty(prob.u0) + sol = SciMLBase.solve(prob; kwargs...) + + if sol.prob isa NonlinearLeastSquaresProblem && sol.retcode == SciMLBase.ReturnCode.Stalled + # https://github.com/SciML/NonlinearSolve.jl/issues/459 + res = LinearAlgebra.norm(sol.resid) + @warn "Initialization for componend stalled with residual $(res)" + elseif !SciMLBase.successful_retcode(sol.retcode) + throw(ArgumentError("Initialization failed. Solver returned $(sol.retcode)")) + end + set_init!.(Ref(cf), SII.variable_symbols(sol), sol.u) + resid = sol.resid + else + resid = init_residual(cf; recalc=true) end - set_init!.(Ref(cf), SII.variable_symbols(sol), sol.u) - set_metadata!(cf, :init_residual, sol.resid) + set_metadata!(cf, :init_residual, resid) verbose && @info "Initialization successful with residual $(LinearAlgebra.norm(sol.resid))" cf diff --git a/test/construction_test.jl b/test/construction_test.jl index 0bc99d93..d534fb7f 100644 --- a/test/construction_test.jl +++ b/test/construction_test.jl @@ -72,8 +72,17 @@ end set_graphelement!(e3, (;src=2,dst=1)) Network([v1,v2,v3], [e1,e2,e3]) # throws waring about 1->2 and 2->1 beeing present -end + v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v1) + v2 = ODEVertex(x->x^2, 2, 0; name=:v2) + v3 = ODEVertex(x->x^3, 2, 0; name=:v3) + @test NetworkDynamics._unique_name_dict([v1,v2,v3]) == Dict(:v1=>1) + + v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v1) + v2 = ODEVertex(x->x^2, 2, 0; name=:v1) + v3 = ODEVertex(x->x^3, 2, 0; name=:v3) + @test NetworkDynamics._unique_name_dict([v1,v2,v3]) == Dict() +end @testset "Vertex batch" begin using NetworkDynamics: BatchStride, VertexBatch, parameter_range vb = VertexBatch{ODEVertex, typeof(sum), Vector{Int}}([1, 2, 3, 4], # vertices