Feat: Handle Adjoints through Initialization #1168

Open · wants to merge 6 commits into base: master
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
@@ -24,6 +24,7 @@ jobs:
- Core5
- Core6
- Core7
- Core8
- QA
- SDE1
- SDE2
15 changes: 15 additions & 0 deletions src/concrete_solve.jl
@@ -425,6 +425,21 @@ function DiffEqBase._concrete_solve_adjoint(
save_end = true, kwargs_fwd...)
end

# Get gradients for the initialization problem if it exists
igs = if _prob.f.initialization_data.initializeprob != nothing
Member: This should be before the solve, since you can use the initialization solution from here in the remakes of lines 397-405 in order to set the new u0 and p, and thus skip running the initialization a second time.

Member (Author): How can I indicate to `solve` that it should avoid running the initialization?

Member: `initializealg = NoInit()`. Should probably just do `CheckInit()` for safety, but either is fine.
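A rough sketch of the restructuring being suggested here (not the PR's code): the surrounding `_prob` and `alg` come from `_concrete_solve_adjoint`, and `new_u0`/`new_p` are placeholders for values extracted from the initialization solution, which is exactly the mapping discussed further down in this thread.

```julia
# Sketch, under the assumptions stated above: run the initialization problem
# once, reuse its result for u0/p, and skip re-initialization in the main solve.
idata = _prob.f.initialization_data
init_sol = solve(idata.initializeprob)          # initialization solve, done only once

# new_u0/new_p are placeholders for values mapped back from `init_sol`
_prob2 = remake(_prob; u0 = new_u0, p = new_p)
sol = solve(_prob2, alg; initializealg = SciMLBase.NoInit())  # or SciMLBase.CheckInit() for safety
```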

iprob = _prob.f.initialization_data.initializeprob
ip = parameter_values(iprob)
itunables, irepack, ialiases = canonicalize(Tunable(), ip)
igs, = Zygote.gradient(ip) do ip
Member: This gradient isn't used? I think this would go into the backpass, and if I'm thinking clearly, the resulting return is dp .* igs?

Member (Author): Not yet. These gradients are currently with respect to the parameters of the initialization problem, not the system exactly, and the mapping between the two is ill-defined, so we cannot simply accumulate them.

I spoke with @AayushSabharwal about a way to map them; it seems initialization_data.initializeprobmap might have some support for returning the correctly shaped vector, but there are cases where we cannot know the ordering of dp either.

Member (Author): There's another subtlety. I am not sure we haven't missed some part of the cfg by manually handling the accumulation of gradients, or some transform we would need to calculate gradients for. The regular AD graph building typically took care of these details for us, but in this case we would have to guard against incorrect gradients manually.

Member: Oh yes, you need to use the initializeprobmap (https://github.com/SciML/SciMLBase.jl/blob/master/src/initialization.jl#L268) to map it back to the shape of the initial parameters.

> but there are cases where we cannot know the ordering of dp either.

p and dp just need the same ordering, so initializeprobmap should do the trick.

> There's another subtlety. I am not sure we haven't missed some part of the cfg by manually handling the accumulation of gradients. […]

This is the only change to (u0, p) before solving, so it would account for that, given that initializeprobmap is just an index map and thus effectively an identity function.
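A hedged sketch of the mapping idea from this thread. `init_to_sys_idxs` is a hypothetical index map standing in for whatever `initializeprobmap` actually provides, and the final accumulation is the `dp .* igs` form proposed above, not a settled design:

```julia
# Hypothetical illustration only: scatter the initialization-problem gradient
# `igs` into the system's parameter ordering, then combine it with the solve
# adjoint `dp` in the backpass.
dp_sys = zeros(eltype(igs), length(dp))
for (i_init, i_sys) in pairs(init_to_sys_idxs)   # hypothetical index map
    dp_sys[i_sys] = igs[i_init]
end
dp_total = dp .* dp_sys   # "dp .* igs" as suggested above; the exact chain rule is still under discussion
```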

iprob2 = remake(iprob, p = ip)
sol = solve(iprob2)
sum(Array(sol))
end
igs
else
nothing
end

# Force `save_start` and `save_end` in the forward pass. This forces the
# solver to do the backsolve all the way back to `u0`. Since the start aliases
# `_prob.u0`, this doesn't actually use more memory. But it cleans up the
5 changes: 3 additions & 2 deletions src/derivative_wrappers.jl
@@ -1068,9 +1068,10 @@ end

function build_param_jac_config(alg, pf, u, p)
if alg_autodiff(alg)
- jac_config = ForwardDiff.JacobianConfig(pf, u, p,
+ tunables, repack, aliases = canonicalize(Tunable(), p)
+ jac_config = ForwardDiff.JacobianConfig(pf, u, tunables,
ForwardDiff.Chunk{
- determine_chunksize(p,
+ determine_chunksize(tunables,
alg)}())
else
if diff_type(alg) != Val{:complex}
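For context on the change above, a short sketch of the SciMLStructures call the new lines rely on; this is a fragment that assumes `pf`, `u`, and `p` from the surrounding `build_param_jac_config`:

```julia
using SciMLStructures: canonicalize, Tunable
import ForwardDiff

# canonicalize(Tunable(), p) splits a structured parameter object into a flat
# vector of tunable values, a `repack` function that rebuilds an object of the
# original type from such a vector, and an aliasing flag.
tunables, repack, aliases = canonicalize(Tunable(), p)

# Differentiation configs are then built against the flat tunable vector only.
jac_config = ForwardDiff.JacobianConfig(pf, u, tunables)
```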
10 changes: 6 additions & 4 deletions src/steadystate_adjoint.jl
@@ -50,13 +50,12 @@ end
sense = SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp,
f, f.colorvec, needs_jac)
(; diffcache, y, sol, λ, vjp, linsolve) = sense

if needs_jac
if SciMLBase.has_jac(f)
f.jac(diffcache.J, y, p, nothing)
else
if DiffEqBase.isinplace(sol.prob)
- jacobian!(diffcache.J, diffcache.uf, y, diffcache.f_cache,
+ jacobian!(diffcache.J.du, diffcache.uf, y, diffcache.f_cache,
sensealg, diffcache.jac_config)
else
diffcache.J .= jacobian(diffcache.uf, y, sensealg)
@@ -103,15 +102,18 @@
else
if linsolve === nothing && isempty(sensealg.linsolve_kwargs)
# For the default case use `\` to avoid any form of unnecessary cache allocation
Member: Yeah, I don't know about that comment; I think it's just old. (a) \ always allocates because it uses lu instead of lu!, so it re-allocates the whole matrix, which is larger than any LinearSolve allocation, and (b) we have since 2023 set up tests on StaticArrays, so the immutable path is non-allocating. I don't think (b) was true when this was written.

Member (Author): So glad we can remove this branch altogether.
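To make the allocation point concrete, a small self-contained example (plain LinearAlgebra, independent of this PR) of why `\` on a dense matrix allocates a copy while an in-place factorization does not:

```julia
using LinearAlgebra

A = rand(4, 4); b = rand(4)

x1 = A \ b          # internally calls lu(A), which factorizes a fresh copy of A
F  = lu!(copy(A))   # lu! factorizes in place; here the copy is explicit and made once
x2 = F \ b          # reuses the factorization; only the solution vector is allocated

x1 ≈ x2             # true up to floating-point roundoff
```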

- vec(λ) .= diffcache.J' \ vec(dgdu_val)
+ linear_problem = LinearProblem(diffcache.J.du', vec(dgdu_val'); u0 = vec(λ))
+ solve(linear_problem, linsolve; alias = LinearAliasSpecifier(alias_A = true), sensealg.linsolve_kwargs...) # u is vec(λ)
else
linear_problem = LinearProblem(diffcache.J', vec(dgdu_val'); u0 = vec(λ))
solve(linear_problem, linsolve; alias = LinearAliasSpecifier(alias_A = true), sensealg.linsolve_kwargs...) # u is vec(λ)
end
end

try
- vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad = vjp, dy = nothing)
+ tunables, repack, aliases = canonicalize(Tunable(), p)
+ vjp_tunables, vjp_repack, vjp_aliases = canonicalize(Tunable(), vjp)
+ vecjacobian!(vec(dgdu_val), y, λ, tunables, nothing, sense; dgrad = vjp_tunables, dy = nothing)
catch e
if sense.sensealg.autojacvec === nothing
@warn "Automatic AD choice of autojacvec failed in nonlinear solve adjoint, falling back to ODE adjoint + numerical vjp"
64 changes: 64 additions & 0 deletions test/desauty_dae_mwe.jl
@@ -0,0 +1,64 @@
using ModelingToolkit, OrdinaryDiffEq
using ModelingToolkitStandardLibrary.Electrical
using ModelingToolkitStandardLibrary.Blocks: Sine
using NonlinearSolve
import LinearSolve   # needed for LinearSolve.DefaultLinearSolver below
import SciMLStructures as SS
import SciMLSensitivity
using Zygote

function create_model(; C₁ = 3e-5, C₂ = 1e-6)
@variables t
@named resistor1 = Resistor(R = 5.0)
@named resistor2 = Resistor(R = 2.0)
@named capacitor1 = Capacitor(C = C₁)
@named capacitor2 = Capacitor(C = C₂)
@named source = Voltage()
@named input_signal = Sine(frequency = 100.0)
@named ground = Ground()
@named ampermeter = CurrentSensor()

eqs = [connect(input_signal.output, source.V)
connect(source.p, capacitor1.n, capacitor2.n)
connect(source.n, resistor1.p, resistor2.p, ground.g)
connect(resistor1.n, capacitor1.p, ampermeter.n)
connect(resistor2.n, capacitor2.p, ampermeter.p)]

@named circuit_model = ODESystem(eqs, t,
systems = [
resistor1, resistor2, capacitor1, capacitor2,
source, input_signal, ground, ampermeter,
])
end

desauty_model = create_model()
sys = structural_simplify(desauty_model)


prob = ODEProblem(sys, [], (0.0, 0.1), guesses = [sys.resistor1.v => 1.])
iprob = prob.f.initialization_data.initializeprob
isys = iprob.f.sys

tunables, repack, aliases = SS.canonicalize(SS.Tunable(), parameter_values(iprob))

linsolve = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.QRFactorization)
sensealg = SciMLSensitivity.SteadyStateAdjoint(autojacvec = SciMLSensitivity.ZygoteVJP(), linsolve = linsolve)
igs, = Zygote.gradient(tunables) do p
iprob2 = remake(iprob, p = repack(p))
sol = solve(iprob2,
sensealg = sensealg
)
sum(Array(sol))
end

@test !iszero(sum(igs))


# tunable_parameters(isys) .=> gs

# gradient_unk1_idx = only(findfirst(x -> isequal(x, Initial(sys.capacitor1.v)), tunable_parameters(isys)))

# gs[gradient_unk1_idx]

# prob.f.initialization_data.update_initializeprob!(iprob, prob)
# prob.f.initialization_data.update_initializeprob!(iprob, ::Vector)
# prob.f.initialization_data.update_initializeprob!(iprob, gs)
6 changes: 6 additions & 0 deletions test/runtests.jl
@@ -102,6 +102,12 @@ end
end
end

if GROUP == "All" || GROUP == "Core8"
@testset "Core 8" begin
@time @safetestset "Initialization with MTK" include("desauty_dae_mwe.jl")
end
end

if GROUP == "All" || GROUP == "QA"
@time @safetestset "Quality Assurance" include("aqua.jl")
end
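A hedged note on running just the new test group locally; this assumes the repository's usual GROUP-based test selection, i.e. that test/runtests.jl reads the `GROUP` constant from the environment:

```julia
# Run only the new Core8 group, assuming GROUP is read via get(ENV, "GROUP", "All").
using Pkg
ENV["GROUP"] = "Core8"       # inherited by the test subprocess
Pkg.test("SciMLSensitivity")
```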