Merge pull request #100 from SciML/optjlintegration
Changes for getting SciML/Optimization.jl#789 passing
Vaibhavdixit02 authored Sep 10, 2024
2 parents 8f0a067 + f0a527b commit 9f59f80
Showing 7 changed files with 185 additions and 123 deletions.
59 changes: 29 additions & 30 deletions ext/OptimizationEnzymeExt.jl
@@ -204,13 +204,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
end

if cons !== nothing && cons_j == true && f.cons_j === nothing
if num_cons > length(x)
seeds = Enzyme.onehot(x)
Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x))
else
seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
Jaccache = Tuple(zero(x) for i in 1:num_cons)
end
# if num_cons > length(x)
seeds = Enzyme.onehot(x)
Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x))
# else
# seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
# Jaccache = Tuple(zero(x) for i in 1:num_cons)
# end

y = zeros(eltype(x), num_cons)

@@ -219,27 +219,26 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
Enzyme.make_zero!(Jaccache[i])
end
Enzyme.make_zero!(y)
if num_cons > length(θ)
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
BatchDuplicated(θ, seeds), Const(p))
for i in eachindex(θ)
if J isa Vector
J[i] = Jaccache[i][1]
else
copyto!(@view(J[:, i]), Jaccache[i])
end
end
else
Enzyme.autodiff(Enzyme.Reverse, f.cons, BatchDuplicated(y, seeds),
BatchDuplicated(θ, Jaccache), Const(p))
for i in 1:num_cons
if J isa Vector
J .= Jaccache[1]
else
copyto!(@view(J[i, :]), Jaccache[i])
end
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
BatchDuplicated(θ, seeds), Const(p))
for i in eachindex(θ)
if J isa Vector
J[i] = Jaccache[i][1]
else
copyto!(@view(J[:, i]), Jaccache[i])
end
end
# else
# Enzyme.autodiff(Enzyme.Reverse, f.cons, BatchDuplicated(y, seeds),
# BatchDuplicated(θ, Jaccache), Const(p))
# for i in 1:num_cons
# if J isa Vector
# J .= Jaccache[1]
# else
# J[i, :] = Jaccache[i]
# end
# end
# end
end
elseif cons_j == true && cons !== nothing
cons_j! = (J, θ) -> f.cons_j(J, θ, p)
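Note: the two hunks above replace the reverse-mode branch with a single forward-mode batched sweep. `Enzyme.onehot(x)` builds one one-hot shadow per input coordinate, and `BatchDuplicated` propagates all of them through one `autodiff` call, so `Jaccache[i]` comes back holding column `i` of the constraint Jacobian. A minimal standalone sketch of that pattern with a toy constraint function (all names below are illustrative, not package API):

```julia
using Enzyme

# Toy in-place constraint function: writes cons(x) into y.
function cons!(y, x, p)
    y[1] = x[1]^2 + x[2]
    y[2] = x[1] * x[2]
    return nothing
end

x = [1.0, 2.0]
num_cons = 2
seeds = Enzyme.onehot(x)                               # one one-hot shadow per input
Jaccache = Tuple(zeros(num_cons) for _ in 1:length(x)) # one output shadow per seed
y = zeros(num_cons)

Enzyme.autodiff(Enzyme.Forward, cons!, BatchDuplicated(y, Jaccache),
    BatchDuplicated(x, seeds), Const(nothing))

# Jaccache[i] now holds the i-th Jacobian column ∂cons/∂x[i]:
J = reduce(hcat, collect(Jaccache))                    # [2.0 1.0; 2.0 1.0]
```

This is also why the `J isa Vector` case reads `J[i] = Jaccache[i][1]`: with a single constraint, each column is a length-1 vector.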
@@ -397,11 +396,11 @@ end
function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
cache::OptimizationBase.ReInitCache,
adtype::AutoEnzyme,
num_cons = 0)
num_cons = 0; kwargs...)
p = cache.p
x = cache.u0

return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons)
return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons; kwargs...)
end

function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x,
@@ -676,11 +675,11 @@ end
function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
cache::OptimizationBase.ReInitCache,
adtype::AutoEnzyme,
num_cons = 0)
num_cons = 0; kwargs...)
p = cache.p
x = cache.u0

return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons)
return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons; kwargs...)
end

end
6 changes: 3 additions & 3 deletions ext/OptimizationMTKExt.jl
@@ -57,7 +57,7 @@ end

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
adtype::AutoSparse{<:AutoSymbolics}, num_cons = 0,
adtype::AutoSparse{<:AutoSymbolics}, num_cons = 0;
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
lag_h = false)
@@ -107,7 +107,7 @@ end

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, x, adtype::AutoSymbolics, p,
num_cons = 0, g = false, h = false, hv = false, fg = false, fgh = false,
num_cons = 0; g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
lag_h = false)
p = isnothing(p) ? SciMLBase.NullParameters() : p
@@ -155,7 +155,7 @@ end

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
adtype::AutoSymbolics, num_cons = 0,
adtype::AutoSymbolics, num_cons = 0;
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
lag_h = false)
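The only change in these three MTK signatures is the comma after `num_cons = 0` becoming a semicolon. In Julia, everything after `;` in a method signature is a keyword argument, so `g`, `h`, `cons_j`, and the rest can now actually be passed by name; with the comma they were optional positional arguments, and a by-name call threw a `MethodError`. A toy illustration (hypothetical functions, not from the package):

```julia
# Comma: flags are optional *positional* arguments.
fpos(x, num_cons = 0, g = false) = (num_cons, g)
# Semicolon: flags are *keyword* arguments.
fkw(x, num_cons = 0; g = false) = (num_cons, g)

fpos(1.0, 2, true)    # (2, true) — must be supplied in order
fkw(1.0; g = true)    # (0, true) — passed by name
# fpos(1.0; g = true) # MethodError: fpos accepts no keyword `g`
```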
51 changes: 24 additions & 27 deletions ext/OptimizationZygoteExt.jl
@@ -220,7 +220,7 @@ function OptimizationBase.instantiate_function(
if f.lag_h === nothing && cons !== nothing && lag_h == true
lag_extras = prepare_hessian(
lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
lag_hess_prototype = zeros(Bool, length(x), length(x))
lag_hess_prototype = zeros(Bool, length(x) + num_cons + 1, length(x) + num_cons + 1)

function lag_h!(H::AbstractMatrix, θ, σ, λ)
if σ == zero(eltype(θ))
@@ -232,13 +232,11 @@ function OptimizationBase.instantiate_function(
end
end

function lag_h!(h, θ, σ, λ)
H = eltype(θ).(lag_hess_prototype)
hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras)
function lag_h!(h::AbstractVector, θ, σ, λ)
H = hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)
k = 0
rows, cols, _ = findnz(H)
for (i, j) in zip(rows, cols)
if i <= j
for i in 1:length(θ)
for j in 1:i
k += 1
h[k] = H[i, j]
end
@@ -256,7 +254,7 @@ function OptimizationBase.instantiate_function(
1:length(θ), 1:length(θ)])
end
end

function lag_h!(h::AbstractVector, θ, σ, λ, p)
global _p = p
H = hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)
@@ -294,21 +292,20 @@ end

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
adtype::ADTypes.AutoZygote, num_cons = 0;
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false)
adtype::ADTypes.AutoZygote, num_cons = 0; kwargs...)
x = cache.u0
p = cache.p

return OptimizationBase.instantiate_function(
f, x, adtype, p, num_cons; g, h, hv, fg, fgh, cons_j, cons_vjp, cons_jvp, cons_h)
f, x, adtype, p, num_cons; kwargs...)
end

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AutoZygote},
p = SciMLBase.NullParameters(), num_cons = 0;
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false)
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
lag_h = false)
function _f(θ)
return f.f(θ, p)[1]
end
@@ -335,7 +332,7 @@ function OptimizationBase.instantiate_function(
grad = nothing
end

if fg == true && f.fg !== nothing
if fg == true && f.fg === nothing
if g == false
extras_grad = prepare_gradient(_f, adtype.dense_ad, x)
end
@@ -361,7 +358,7 @@

hess_sparsity = f.hess_prototype
hess_colors = f.hess_colorvec
if f.hess === nothing
if h == true && f.hess === nothing
extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better
function hess(res, θ)
hessian!(_f, res, soadtype, θ, extras_hess)
@@ -384,7 +381,7 @@ function OptimizationBase.instantiate_function(
hess = nothing
end

if fgh == true && f.fgh !== nothing
if fgh == true && f.fgh === nothing
function fgh!(G, H, θ)
(y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess)
return y
@@ -406,7 +403,7 @@ function OptimizationBase.instantiate_function(
fgh! = nothing
end

if hv == true && f.hv !== nothing
if hv == true && f.hv === nothing
extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x)))
function hv!(H, θ, v)
hvp!(_f, H, soadtype.dense_ad, θ, v, extras_hvp)
@@ -443,7 +440,7 @@ function OptimizationBase.instantiate_function(
θ = augvars[1:length(x)]
σ = augvars[length(x) + 1]
λ = augvars[(length(x) + 2):end]
return σ * _f(θ) + dot(λ, cons(θ))
return σ * _f(θ) + dot(λ, cons_oop(θ))
end
end

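The one-line fix above makes the Lagrangian closure call `cons_oop` — the out-of-place wrapper defined earlier in this method — rather than the in-place `cons`, which takes `(res, θ)` and would not accept a single argument. The closure itself is the augmented-variables trick: `(θ, σ, λ)` are folded into one flat vector so a single Hessian preparation differentiates through the objective and the constraints together, and the top-left `length(θ) × length(θ)` block of the augmented Hessian is the Lagrangian Hessian — which is also why the prototypes in this PR grow to `length(x) + num_cons + 1` per side. A hedged sketch of the idea with a toy objective and constraint:

```julia
using LinearAlgebra

objective(θ) = sum(abs2, θ)
constraints(θ) = [θ[1] * θ[2]]

# augvars = vcat(θ, σ, λ): one flat argument, so a single AD pass covers
# σ * f(θ) + λ ⋅ cons(θ); the θ–θ block of its Hessian is ∇²L.
function lagrangian(augvars, n)
    θ = augvars[1:n]
    σ = augvars[n + 1]
    λ = augvars[(n + 2):end]
    return σ * objective(θ) + dot(λ, constraints(θ))
end

lagrangian([1.0, 2.0, 1.0, 0.5], 2)   # 1.0 * 5.0 + 0.5 * 2.0 == 6.0
```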
@@ -466,7 +463,8 @@ function OptimizationBase.instantiate_function(
end

if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
extras_pullback = prepare_pullback(cons_oop, adtype, x)
extras_pullback = prepare_pullback(
cons_oop, adtype.dense_ad, x, ones(eltype(x), num_cons))
function cons_vjp!(J, θ, v)
pullback!(cons_oop, J, adtype.dense_ad, θ, v, extras_pullback)
end
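Both seed fixes here (and in the pushforward hunk just below) prepare the operator with the dense inner backend `adtype.dense_ad` and an explicit seed vector — length `num_cons` for the pullback, `length(x)` for the pushforward. What `cons_vjp!` ultimately computes is a vector–Jacobian product `Jᵀv` without materializing `J`; a minimal sketch of that operation using Zygote directly (the extension routes it through DifferentiationInterface's `pullback!` instead):

```julia
import Zygote

cons(x) = [x[1]^2 + x[2], x[1] * x[2]]
x = [1.0, 2.0]
v = ones(2)                # one cotangent weight per constraint

_, back = Zygote.pullback(cons, x)
vjp = only(back(v))        # Jᵀv == [4.0, 2.0]; J is never formed
```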
@@ -477,7 +475,8 @@ function OptimizationBase.instantiate_function(
end

if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
extras_pushforward = prepare_pushforward(cons_oop, adtype, x)
extras_pushforward = prepare_pushforward(
cons_oop, adtype.dense_ad, x, ones(eltype(x), length(x)))
function cons_jvp!(J, θ, v)
pushforward!(cons_oop, J, adtype.dense_ad, θ, v, extras_pushforward)
end
@@ -510,10 +509,11 @@ function OptimizationBase.instantiate_function(
end

lag_hess_prototype = f.lag_hess_prototype
if cons !== nothing && cons_h == true && f.lag_h === nothing
lag_hess_colors = f.lag_hess_colorvec
if cons !== nothing && f.lag_h === nothing && lag_h == true
lag_extras = prepare_hessian(
lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
lag_hess_prototype = lag_extras.coloring_result.S[1:length(θ), 1:length(θ)]
lag_hess_prototype = lag_extras.coloring_result.S[1:length(x), 1:length(x)]
lag_hess_colors = lag_extras.coloring_result.color

function lag_h!(H::AbstractMatrix, θ, σ, λ)
@@ -587,14 +587,11 @@ end

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
adtype::ADTypes.AutoSparse{<:AutoZygote}, num_cons = 0;
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false)
adtype::ADTypes.AutoSparse{<:AutoZygote}, num_cons = 0; kwargs...)
x = cache.u0
p = cache.p

return OptimizationBase.instantiate_function(
f, x, adtype, p, num_cons; g, h, hv, fg, fgh, cons_j, cons_vjp, cons_jvp, cons_h)
return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons; kwargs...)
end

end
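Finally, both `ReInitCache` convenience methods in this file now accept `kwargs...` and splat them through to the main method instead of enumerating every flag by hand, so newly introduced flags such as `lag_h` propagate without touching these wrappers again. The forwarding pattern in miniature (toy names):

```julia
inner(x, p; g = false, lag_h = false) = (g, lag_h)
outer(cache; kwargs...) = inner(cache.u0, cache.p; kwargs...)

cache = (u0 = [1.0], p = nothing)
outer(cache; lag_h = true)   # (false, true) — the flag reaches `inner` intact
```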
24 changes: 10 additions & 14 deletions src/OptimizationDIExt.jl
@@ -104,7 +104,7 @@ function instantiate_function(
hess = nothing
end

if fgh == true && f.fgh !== nothing
if fgh == true && f.fgh === nothing
function fgh!(G, H, θ)
(y, _, _) = value_derivative_and_second_derivative!(
_f, G, H, soadtype, θ, extras_hess)
@@ -229,7 +229,7 @@ function instantiate_function(
if cons !== nothing && lag_h == true && f.lag_h === nothing
lag_extras = prepare_hessian(
lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
lag_hess_prototype = zeros(Bool, length(x), length(x))
lag_hess_prototype = zeros(Bool, length(x) + num_cons + 1, length(x) + num_cons + 1)

function lag_h!(H::AbstractMatrix, θ, σ, λ)
if σ == zero(eltype(θ))
Expand Down Expand Up @@ -263,7 +263,7 @@ function instantiate_function(
1:length(θ), 1:length(θ)])
end
end

function lag_h!(h::AbstractVector, θ, σ, λ, p)
global _p = p
H = hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)
@@ -301,16 +301,12 @@ end

function instantiate_function(
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
adtype::ADTypes.AbstractADType, num_cons = 0,
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
lag_h = false)
adtype::ADTypes.AbstractADType, num_cons = 0;
kwargs...)
x = cache.u0
p = cache.p

return instantiate_function(f, x, adtype, p, num_cons; g = g, h = h, hv = hv,
fg = fg, fgh = fgh, cons_j = cons_j, cons_vjp = cons_vjp, cons_jvp = cons_jvp,
cons_h = cons_h, lag_h = lag_h)
return instantiate_function(f, x, adtype, p, num_cons; kwargs...)
end

function instantiate_function(
@@ -392,7 +388,7 @@ function instantiate_function(
hess = nothing
end

if fgh == true && f.fgh !== nothing
if fgh == true && f.fgh === nothing
function fgh!(θ)
(y, G, H) = value_derivative_and_second_derivative(_f, adtype, θ, extras_hess)
return y, G, H
@@ -511,7 +507,7 @@ function instantiate_function(
if cons !== nothing && lag_h == true && f.lag_h === nothing
lag_extras = prepare_hessian(
lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
lag_hess_prototype = zeros(Bool, length(x), length(x))
lag_hess_prototype = zeros(Bool, length(x) + num_cons + 1, length(x) + num_cons + 1)

function lag_h!(θ, σ, λ)
if σ == zero(eltype(θ))
@@ -558,9 +554,9 @@ end

function instantiate_function(
f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache,
adtype::ADTypes.AbstractADType, num_cons = 0)
adtype::ADTypes.AbstractADType, num_cons = 0; kwargs...)
x = cache.u0
p = cache.p

return instantiate_function(f, x, adtype, p, num_cons)
return instantiate_function(f, x, adtype, p, num_cons; kwargs...)
end