diff --git a/src/cache.jl b/src/cache.jl index 0980a0b..e7722f5 100644 --- a/src/cache.jl +++ b/src/cache.jl @@ -5,7 +5,7 @@ struct AnalysisResults constraints::Union{Nothing, Vector{AnalysisResult}} end -struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C, M} <: +struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, P, C, M} <: SciMLBase.AbstractOptimizationCache f::F reinit_cache::RC @@ -15,7 +15,6 @@ struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C, M} <: ucons::UC sense::S opt::O - data::D progress::P callback::C manifold::M @@ -23,7 +22,7 @@ struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C, M} <: solver_args::NamedTuple end -function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data = DEFAULT_DATA; +function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt; callback = DEFAULT_CALLBACK, maxiters::Union{Number, Nothing} = nothing, maxtime::Union{Number, Nothing} = nothing, @@ -150,13 +149,12 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data = DEFA return OptimizationCache(f, reinit_cache, prob.lb, prob.ub, prob.lcons, prob.ucons, prob.sense, - opt, data, progress, callback, manifold, AnalysisResults(obj_res, cons_res), + opt, progress, callback, manifold, AnalysisResults(obj_res, cons_res), merge((; maxiters, maxtime, abstol, reltol), NamedTuple(kwargs))) end -function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt, - data = DEFAULT_DATA; +function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt; callback = DEFAULT_CALLBACK, maxiters::Union{Number, Nothing} = nothing, maxtime::Union{Number, Nothing} = nothing, @@ -164,7 +162,7 @@ function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt, reltol::Union{Number, Nothing} = nothing, progress = false, kwargs...) - return OptimizationCache(prob, opt, data; maxiters, maxtime, abstol, callback, + return OptimizationCache(prob, opt; maxiters, maxtime, abstol, callback, reltol, progress, kwargs...) end diff --git a/src/function.jl b/src/function.jl index f93dda7..c5d3e94 100644 --- a/src/function.jl +++ b/src/function.jl @@ -114,17 +114,87 @@ end function OptimizationBase.instantiate_function( f::OptimizationFunction{true}, x, ::SciMLBase.NoAD, p, num_cons = 0; kwargs...) - grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...) - fg = f.fg === nothing ? nothing : (G, x, args...) -> f.fg(G, x, p, args...) - hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...) - fgh = f.fgh === nothing ? nothing : (G, H, x, args...) -> f.fgh(G, H, x, p, args...) - hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...) + if f.grad === nothing + grad = nothing + else + function grad(G, x) + return f.grad(G, x, p) + end + if p != SciMLBase.NullParameters() + function grad(G, x, p) + return f.grad(G, x, p) + end + end + end + if f.fg === nothing + fg = nothing + else + function fg(G, x) + return f.fg(G, x, p) + end + if p != SciMLBase.NullParameters() + function fg(G, x, p) + return f.fg(G, x, p) + end + end + end + if f.hess === nothing + hess = nothing + else + function hess(H, x) + return f.hess(H, x, p) + end + if p != SciMLBase.NullParameters() + function hess(H, x, p) + return f.hess(H, x, p) + end + end + end + + if f.fgh === nothing + fgh = nothing + else + function fgh(G, H, x) + return f.fgh(G, H, x, p) + end + if p != SciMLBase.NullParameters() + function fgh(G, H, x, p) + return f.fgh(G, H, x, p) + end + end + end + + if f.hv === nothing + hv = nothing + else + function hv(H, x, v) + return f.hv(H, x, v, p) + end + if p != SciMLBase.NullParameters() + function hv(H, x, v, p) + return f.hv(H, x, v, p) + end + end + end + cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p) cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p) cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, p) cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, p) cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p) - lag_h = f.lag_h === nothing ? nothing : (res, x) -> f.lag_h(res, x, p) + + if f.lag_h === nothing + lag_h = nothing + else + function lag_h(res, x) + return f.lag_h(res, x, p) + end + if p != SciMLBase.NullParameters() + function lag_h(res, x, p) + return f.lag_h(res, x, p) + end + end + end hess_prototype = f.hess_prototype === nothing ? nothing : convert.(eltype(x), f.hess_prototype) cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :