Skip to content

Commit

Permalink
enzyme reverse mode in constraint jacobian
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Sep 7, 2024
1 parent 4ba2a4c commit a7c5a89
Showing 1 changed file with 25 additions and 26 deletions.
51 changes: 25 additions & 26 deletions ext/OptimizationEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
end

if cons !== nothing && cons_j == true && f.cons_j === nothing
if num_cons > length(x)
seeds = Enzyme.onehot(x)
Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x))
else
seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
Jaccache = Tuple(zero(x) for i in 1:num_cons)
end
# if num_cons > length(x)
seeds = Enzyme.onehot(x)
Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x))
# else
# seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
# Jaccache = Tuple(zero(x) for i in 1:num_cons)
# end

y = zeros(eltype(x), num_cons)

Expand All @@ -219,27 +219,26 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
Enzyme.make_zero!(Jaccache[i])
end
Enzyme.make_zero!(y)
if num_cons > length(θ)
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
BatchDuplicated(θ, seeds), Const(p))
for i in eachindex(θ)
if J isa Vector
J[i] = Jaccache[i][1]
else
copyto!(@view(J[:, i]), Jaccache[i])
end
end
else
Enzyme.autodiff(Enzyme.Reverse, f.cons, BatchDuplicated(y, seeds),
BatchDuplicated(θ, Jaccache), Const(p))
for i in 1:num_cons
if J isa Vector
J .= Jaccache[1]
else
copyto!(@view(J[i, :]), Jaccache[i])
end
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
BatchDuplicated(θ, seeds), Const(p))
for i in eachindex(θ)
if J isa Vector
J[i] = Jaccache[i][1]
else
copyto!(@view(J[:, i]), Jaccache[i])
end
end
# else
# Enzyme.autodiff(Enzyme.Reverse, f.cons, BatchDuplicated(y, seeds),
# BatchDuplicated(θ, Jaccache), Const(p))
# for i in 1:num_cons
# if J isa Vector
# J .= Jaccache[1]
# else
# J[i, :] = Jaccache[i]
# end
# end
# end
end
elseif cons_j == true && cons !== nothing
cons_j! = (J, θ) -> f.cons_j(J, θ, p)
Expand Down

0 comments on commit a7c5a89

Please sign in to comment.