diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 622cb9b4b..545a1f06b 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -24,6 +24,7 @@ jobs: - Core5 - Core6 - Core7 + - Core8 - QA - SDE1 - SDE2 diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index aeb85ad6f..b6adcc6b9 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -425,6 +425,21 @@ function DiffEqBase._concrete_solve_adjoint( save_end = true, kwargs_fwd...) end + # Get gradients for the initialization problem if it exists + igs = if _prob.f.initialization_data.initializeprob != nothing + iprob = _prob.f.initialization_data.initializeprob + ip = parameter_values(iprob) + itunables, irepack, ialiases = canonicalize(Tunable(), ip) + igs, = Zygote.gradient(ip) do ip + iprob2 = remake(iprob, p = ip) + sol = solve(iprob2) + sum(Array(sol)) + end + igs + else + nothing + end + # Force `save_start` and `save_end` in the forward pass This forces the # solver to do the backsolve all the way back to `u0` Since the start aliases # `_prob.u0`, this doesn't actually use more memory But it cleans up the diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 2b1f48849..938264891 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -1068,9 +1068,10 @@ end function build_param_jac_config(alg, pf, u, p) if alg_autodiff(alg) - jac_config = ForwardDiff.JacobianConfig(pf, u, p, + tunables, repack, aliases = canonicalize(Tunable(), p) + jac_config = ForwardDiff.JacobianConfig(pf, u, tunables, ForwardDiff.Chunk{ - determine_chunksize(p, + determine_chunksize(tunables, alg)}()) else if diff_type(alg) != Val{:complex} diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl index 01820360d..dab8e83a9 100644 --- a/src/steadystate_adjoint.jl +++ b/src/steadystate_adjoint.jl @@ -50,13 +50,12 @@ end sense = SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp, f, f.colorvec, needs_jac) (; diffcache, y, sol, λ, vjp, linsolve) = sense - if needs_jac if SciMLBase.has_jac(f) f.jac(diffcache.J, y, p, nothing) else if DiffEqBase.isinplace(sol.prob) - jacobian!(diffcache.J, diffcache.uf, y, diffcache.f_cache, + jacobian!(diffcache.J.du, diffcache.uf, y, diffcache.f_cache, sensealg, diffcache.jac_config) else diffcache.J .= jacobian(diffcache.uf, y, sensealg) @@ -103,7 +102,8 @@ end else if linsolve === nothing && isempty(sensealg.linsolve_kwargs) # For the default case use `\` to avoid any form of unnecessary cache allocation - vec(λ) .= diffcache.J' \ vec(dgdu_val) + linear_problem = LinearProblem(diffcache.J.du', vec(dgdu_val'); u0 = vec(λ)) + solve(linear_problem, linsolve; alias = LinearAliasSpecifier(alias_A = true), sensealg.linsolve_kwargs...) # u is vec(λ) else linear_problem = LinearProblem(diffcache.J', vec(dgdu_val'); u0 = vec(λ)) solve(linear_problem, linsolve; alias = LinearAliasSpecifier(alias_A = true), sensealg.linsolve_kwargs...) # u is vec(λ) @@ -111,7 +111,9 @@ end end try - vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad = vjp, dy = nothing) + tunables, repack, aliases = canonicalize(Tunable(), p) + vjp_tunables, vjp_repack, vjp_aliases = canonicalize(Tunable(), vjp) + vecjacobian!(vec(dgdu_val), y, λ, tunables, nothing, sense; dgrad = vjp_tunables, dy = nothing) catch e if sense.sensealg.autojacvec === nothing @warn "Automatic AD choice of autojacvec failed in nonlinear solve adjoint, failing back to ODE adjoint + numerical vjp" diff --git a/test/desauty_dae_mwe.jl b/test/desauty_dae_mwe.jl new file mode 100644 index 000000000..cc00080fa --- /dev/null +++ b/test/desauty_dae_mwe.jl @@ -0,0 +1,64 @@ +using ModelingToolkit, OrdinaryDiffEq +using ModelingToolkitStandardLibrary.Electrical +using ModelingToolkitStandardLibrary.Blocks: Sine +using NonlinearSolve +import SciMLStructures as SS +import SciMLSensitivity +using Zygote + +function create_model(; C₁ = 3e-5, C₂ = 1e-6) + @variables t + @named resistor1 = Resistor(R = 5.0) + @named resistor2 = Resistor(R = 2.0) + @named capacitor1 = Capacitor(C = C₁) + @named capacitor2 = Capacitor(C = C₂) + @named source = Voltage() + @named input_signal = Sine(frequency = 100.0) + @named ground = Ground() + @named ampermeter = CurrentSensor() + + eqs = [connect(input_signal.output, source.V) + connect(source.p, capacitor1.n, capacitor2.n) + connect(source.n, resistor1.p, resistor2.p, ground.g) + connect(resistor1.n, capacitor1.p, ampermeter.n) + connect(resistor2.n, capacitor2.p, ampermeter.p)] + + @named circuit_model = ODESystem(eqs, t, + systems = [ + resistor1, resistor2, capacitor1, capacitor2, + source, input_signal, ground, ampermeter, + ]) +end + +desauty_model = create_model() +sys = structural_simplify(desauty_model) + + +prob = ODEProblem(sys, [], (0.0, 0.1), guesses = [sys.resistor1.v => 1.]) +iprob = prob.f.initialization_data.initializeprob +isys = iprob.f.sys + +tunables, repack, aliases = SS.canonicalize(SS.Tunable(), parameter_values(iprob)) + +linsolve = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.QRFactorization) +sensealg = SciMLSensitivity.SteadyStateAdjoint(autojacvec = SciMLSensitivity.ZygoteVJP(), linsolve = linsolve) +igs, = Zygote.gradient(tunables) do p + iprob2 = remake(iprob, p = repack(p)) + sol = solve(iprob2, + sensealg = sensealg + ) + sum(Array(sol)) +end + +@test !iszero(sum(igs)) + + +# tunable_parameters(isys) .=> gs + +# gradient_unk1_idx = only(findfirst(x -> isequal(x, Initial(sys.capacitor1.v)), tunable_parameters(isys))) + +# gs[gradient_unk1_idx] + +# prob.f.initialization_data.update_initializeprob!(iprob, prob) +# prob.f.initialization_data.update_initializeprob!(iprob, ::Vector) +# prob.f.initialization_data.update_initializeprob!(iprob, gs) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 3e9af1197..05eac46c0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -102,6 +102,12 @@ end end end + if GROUP == "All" || GROUP == "Core8" + @testset "Core 8" begin + @time @safetestset "Initialization with MTK" include("desauty_dae_mwe.jl") + end + end + if GROUP == "All" || GROUP == "QA" @time @safetestset "Quality Assurance" include("aqua.jl") end