Skip to content

Commit

Permalink
Remove data from cache
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Sep 10, 2024
1 parent a7c5a89 commit f0a527b
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 13 deletions.
12 changes: 5 additions & 7 deletions src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ struct AnalysisResults
constraints::Union{Nothing, Vector{AnalysisResult}}
end

struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C, M} <:
struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, P, C, M} <:
SciMLBase.AbstractOptimizationCache
f::F
reinit_cache::RC
Expand All @@ -15,15 +15,14 @@ struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C, M} <:
ucons::UC
sense::S
opt::O
data::D
progress::P
callback::C
manifold::M
analysis_results::AnalysisResults
solver_args::NamedTuple
end

function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data = DEFAULT_DATA;
function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt;
callback = DEFAULT_CALLBACK,
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
Expand Down Expand Up @@ -150,21 +149,20 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data = DEFA

return OptimizationCache(f, reinit_cache, prob.lb, prob.ub, prob.lcons,
prob.ucons, prob.sense,
opt, data, progress, callback, manifold, AnalysisResults(obj_res, cons_res),
opt, progress, callback, manifold, AnalysisResults(obj_res, cons_res),
merge((; maxiters, maxtime, abstol, reltol),
NamedTuple(kwargs)))
end

function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt,
data = DEFAULT_DATA;
function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt;
callback = DEFAULT_CALLBACK,
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
abstol::Union{Number, Nothing} = nothing,
reltol::Union{Number, Nothing} = nothing,
progress = false,
kwargs...)
return OptimizationCache(prob, opt, data; maxiters, maxtime, abstol, callback,
return OptimizationCache(prob, opt; maxiters, maxtime, abstol, callback,
reltol, progress,
kwargs...)
end
Expand Down
82 changes: 76 additions & 6 deletions src/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,87 @@ end
function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, x, ::SciMLBase.NoAD,
p, num_cons = 0; kwargs...)
grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...)
fg = f.fg === nothing ? nothing : (G, x, args...) -> f.fg(G, x, p, args...)
hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...)
fgh = f.fgh === nothing ? nothing : (G, H, x, args...) -> f.fgh(G, H, x, p, args...)
hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...)
if f.grad === nothing
grad = nothing
else
function grad(G, x)
return f.grad(G, x, p)
end
if p != SciMLBase.NullParameters()
function grad(G, x, p)
return f.grad(G, x, p)
end
end
end
if f.fg === nothing
fg = nothing
else
function fg(G, x)
return f.fg(G, x, p)
end
if p != SciMLBase.NullParameters()
function fg(G, x, p)
return f.fg(G, x, p)
end
end
end
if f.hess === nothing
hess = nothing
else
function hess(H, x)
return f.hess(H, x, p)
end
if p != SciMLBase.NullParameters()
function hess(H, x, p)
return f.hess(H, x, p)
end
end
end

if f.fgh === nothing
fgh = nothing
else
function fgh(G, H, x)
return f.fgh(G, H, x, p)
end
if p != SciMLBase.NullParameters()
function fgh(G, H, x, p)
return f.fgh(G, H, x, p)
end
end
end

if f.hv === nothing
hv = nothing
else
function hv(H, x, v)
return f.hv(H, x, v, p)
end
if p != SciMLBase.NullParameters()
function hv(H, x, v, p)
return f.hv(H, x, v, p)
end
end
end

cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p)
cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p)
cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, p)
cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, p)
cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p)
lag_h = f.lag_h === nothing ? nothing : (res, x) -> f.lag_h(res, x, p)

if f.lag_h === nothing
lag_h = nothing
else
function lag_h(res, x)
return f.lag_h(res, x, p)
end
if p != SciMLBase.NullParameters()
function lag_h(res, x, p)
return f.lag_h(res, x, p)
end
end
end
hess_prototype = f.hess_prototype === nothing ? nothing :
convert.(eltype(x), f.hess_prototype)
cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :
Expand Down

0 comments on commit f0a527b

Please sign in to comment.