From a7c5a89ed9a2666f9776570ae74adc47ce94072d Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Sat, 7 Sep 2024 09:55:55 -0400 Subject: [PATCH] enzyme reverse mode in constraint jacobian --- ext/OptimizationEnzymeExt.jl | 51 ++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index c90a22e..5bc6c2e 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -204,13 +204,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end if cons !== nothing && cons_j == true && f.cons_j === nothing - if num_cons > length(x) - seeds = Enzyme.onehot(x) - Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x)) - else - seeds = Enzyme.onehot(zeros(eltype(x), num_cons)) - Jaccache = Tuple(zero(x) for i in 1:num_cons) - end + # if num_cons > length(x) + seeds = Enzyme.onehot(x) + Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x)) + # else + # seeds = Enzyme.onehot(zeros(eltype(x), num_cons)) + # Jaccache = Tuple(zero(x) for i in 1:num_cons) + # end y = zeros(eltype(x), num_cons) @@ -219,27 +219,26 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Enzyme.make_zero!(Jaccache[i]) end Enzyme.make_zero!(y) - if num_cons > length(θ) - Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache), - BatchDuplicated(θ, seeds), Const(p)) - for i in eachindex(θ) - if J isa Vector - J[i] = Jaccache[i][1] - else - copyto!(@view(J[:, i]), Jaccache[i]) - end - end - else - Enzyme.autodiff(Enzyme.Reverse, f.cons, BatchDuplicated(y, seeds), - BatchDuplicated(θ, Jaccache), Const(p)) - for i in 1:num_cons - if J isa Vector - J .= Jaccache[1] - else - copyto!(@view(J[i, :]), Jaccache[i]) - end + Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache), + BatchDuplicated(θ, seeds), Const(p)) + for i in eachindex(θ) + if J isa Vector + J[i] = Jaccache[i][1] + else + copyto!(@view(J[:, i]), Jaccache[i]) end end + # else + # Enzyme.autodiff(Enzyme.Reverse, f.cons, BatchDuplicated(y, seeds), + # BatchDuplicated(θ, Jaccache), Const(p)) + # for i in 1:num_cons + # if J isa Vector + # J .= Jaccache[1] + # else + # J[i, :] = Jaccache[i] + # end + # end + # end end elseif cons_j == true && cons !== nothing cons_j! = (J, θ) -> f.cons_j(J, θ, p)