Work with stochastic optimizers too
Vaibhavdixit02 committed Oct 26, 2024
1 parent 3ca61cf commit b4287a1
Showing 4 changed files with 60 additions and 34 deletions.
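
In short: the augmented Lagrangian solver now hands its inner subproblems to any Optimization.jl optimizer, so stochastic methods such as Adam can be used on mini-batched problems. A condensed usage sketch mirroring the new test added below (the DataLoader batching, bounds, and hyperparameters are taken from that test, not additional API):

```julia
using Optimization, OptimizationOptimisers, MLUtils

# Fit a degree-4 polynomial to sin(x) on mini-batches, subject to one equality constraint.
x0 = -pi:0.001:pi
y0 = sin.(x0)
data = MLUtils.DataLoader((x0, y0), batchsize = 100)

loss(coeffs, data) = sum(abs2, [evalpoly(data[1][i], coeffs) for i in eachindex(data[1])] .- data[2])
cons1(res, coeffs, p = nothing) = (res[1] = coeffs[1] * coeffs[5] - 1; nothing)

optf = OptimizationFunction(loss, AutoSparseForwardDiff(), cons = cons1)
prob = OptimizationProblem(optf, rand(5), data, lcons = [0.0], ucons = [0.0],
    lb = fill(-10.0, 5), ub = fill(10.0, 5))

# AugLag handles the constraints; the stochastic inner optimizer (Adam) solves each subproblem.
sol = solve(prob, Optimization.AugLag(; inner = Adam()), maxiters = 500)
```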
1 change: 1 addition & 0 deletions src/Optimization.jl
@@ -24,6 +24,7 @@ include("utils.jl")
include("state.jl")
include("lbfgsb.jl")
include("sophia.jl")
include("auglag.jl")

export solve

63 changes: 30 additions & 33 deletions src/auglag.jl
@@ -1,15 +1,21 @@

SciMLBase.supports_opt_cache_interface(::LBFGS) = true
SciMLBase.allowsbounds(::LBFGS) = true
SciMLBase.requiresgradient(::LBFGS) = true
SciMLBase.allowsconstraints(::LBFGS) = true
SciMLBase.requiresconsjac(::LBFGS) = true

function task_message_to_string(task::Vector{UInt8})
return String(task)
@kwdef struct AugLag
inner
τ = 0.5
γ = 10.0
λmin = -1e20
λmax = 1e20
μmin = 0.0
μmax = 1e20
ϵ = 1e-8
end
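
For orientation, the new AugLag wrapper is a plain keyword struct; `inner` is the only field without a default and holds the optimizer used for the unconstrained subproblems. A construction sketch follows; the comments on the remaining fields are my reading of the update loop further down, not documentation from this diff:

```julia
using Optimization, OptimizationOptimisers

opt = Optimization.AugLag(
    inner = Adam(),             # optimizer for each inner (unconstrained) subproblem
    τ = 0.5,                    # required shrink factor of the constraint violation per outer iteration (inferred)
    γ = 10.0,                   # growth factor applied to the penalty ρ when progress stalls (inferred)
    λmin = -1e20, λmax = 1e20,  # clamps on the equality multipliers λ
    μmin = 0.0, μmax = 1e20,    # clamps on the inequality multipliers μ
    ϵ = 1e-8)                   # tolerance on the constraint violation (inferred)
```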

function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::LBFGS;
SciMLBase.supports_opt_cache_interface(::AugLag) = true
SciMLBase.allowsbounds(::AugLag) = true
SciMLBase.requiresgradient(::AugLag) = true
SciMLBase.allowsconstraints(::AugLag) = true
SciMLBase.requiresconsjac(::AugLag) = true

function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::AugLag;
callback = nothing,
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
@@ -62,7 +68,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
UC,
S,
O <:
LBFGS,
AugLag,
D,
P,
C
@@ -90,10 +96,10 @@ if !isnothing(cache.f.cons)

cons_tmp = zeros(eltype(cache.u0), length(cache.lcons))
cache.f.cons(cons_tmp, cache.u0)
ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, cache.p))) / norm(cons_tmp)))
ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, iterate(cache.p)[1]))) / norm(cons_tmp)))

_loss = function (θ)
x = cache.f(θ, cache.p)
_loss = function (θ, p = cache.p)
x = cache.f(θ, p)
cons_tmp .= zero(eltype(θ))
cache.f.cons(cons_tmp, θ)
cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds]
@@ -114,8 +120,8 @@ if !isnothing(cache.f.cons)
ineqidxs = [ineq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
eqidxs = eqidxs[eqidxs .!= nothing]
ineqidxs = ineqidxs[ineqidxs .!= nothing]
function aug_grad(G, θ)
cache.f.grad(G, θ)
function aug_grad(G, θ, p)
cache.f.grad(G, θ, p)
if !isnothing(cache.f.cons_jac_prototype)
J = Float64.(cache.f.cons_jac_prototype)
else
@@ -139,23 +145,15 @@ if !isnothing(cache.f.cons)
opt_ret = ReturnCode.MaxIters
n = length(cache.u0)

sol = solve(....)
augprob = OptimizationProblem(OptimizationFunction(_loss; grad = aug_grad), cache.u0, cache.p)

solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing))

for i in 1:maxiters
for i in 1:(maxiters/10)
prev_eqcons .= cons_tmp[eq_inds] .- cache.lcons[eq_inds]
prevβ .= copy(β)
res = optimizer(_loss, aug_grad, θ, bounds; solver_kwargs...,
m = cache.opt.m, pgtol = sqrt(ϵ), maxiter = maxiters / 100)
# @show res[2]
# @show res[1]
# @show cons_tmp
# @show λ
# @show β
# @show μ
# @show ρ
θ = res[2]
res = solve(augprob, cache.opt.inner, maxiters = maxiters / 10)
θ = res.u
cons_tmp .= 0.0
cache.f.cons(cons_tmp, θ)
λ = max.(min.(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache.lcons[eq_inds])), λmin)
@@ -172,11 +170,10 @@ if !isnothing(cache.f.cons)
break
end
end
end

stats = Optimization.OptimizationStats(; iterations = maxiters,
stats = Optimization.OptimizationStats(; iterations = maxiters,
time = 0.0, fevals = maxiters, gevals = maxiters)
return SciMLBase.build_solution(
cache, cache.opt, res[2], cache.f(res[2], cache.p)[1],
return SciMLBase.build_solution(
cache, cache.opt, θ, x,
stats = stats, retcode = opt_ret)
end
end
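
For context, the loop above follows the textbook augmented Lagrangian recipe: solve an unconstrained subproblem with the inner optimizer, then update the multipliers and the penalty. Below is a minimal sketch of one outer update for equality constraints h(θ) = 0, assuming the conventional rule for growing ρ; the inequality handling via μ follows the same pattern. This is my reading of the code, not an authoritative restatement:

```julia
using LinearAlgebra

# One outer augmented-Lagrangian update after an inner solve has returned θ.
# h_val = h(θ) is the equality-constraint residual; prev_viol is ‖h‖ from the previous iteration.
function auglag_update(λ, ρ, h_val, prev_viol; τ = 0.5, γ = 10.0, λmin = -1e20, λmax = 1e20)
    λ = clamp.(λ .+ ρ .* h_val, λmin, λmax)   # first-order multiplier update
    viol = norm(h_val)
    if viol > τ * prev_viol                   # violation did not shrink enough:
        ρ = γ * ρ                             # tighten the penalty
    end
    return λ, ρ, viol                         # stop the outer loop once viol < ϵ
end
```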
2 changes: 1 addition & 1 deletion src/lbfgsb.jl
@@ -45,7 +45,7 @@ function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::LBFGS;
@warn "common abstol is currently not used by $(opt)"
end
if !isnothing(maxtime)
@warn "common abstol is currently not used by $(opt)"
@warn "common maxtime is currently not used by $(opt)"
end

mapped_args = (;)
28 changes: 28 additions & 0 deletions test/lbfgsb.jl
@@ -25,3 +25,31 @@ prob = OptimizationProblem(optf, x0, lcons = [1.0, -Inf],
ub = [1.0, 1.0])
@time res = solve(prob, Optimization.LBFGS(), maxiters = 100)
@test res.retcode == SciMLBase.ReturnCode.Success

using MLUtils, OptimizationOptimisers

x0 = -pi:0.001:pi
y0 = sin.(x0)
data = MLUtils.DataLoader((x0, y0), batchsize = 100)
function loss(coeffs, data)
ypred = [evalpoly(data[1][i], coeffs) for i in eachindex(data[1])]
return sum(abs2, ypred .- data[2])
end

function cons1(res, coeffs, p = nothing)
res[1] = coeffs[1] * coeffs[5] - 1
return nothing
end

optf = OptimizationFunction(loss, AutoSparseForwardDiff(), cons = cons1)
callback = (st, l) -> (@show l; return false)

prob = OptimizationProblem(optf, rand(5), (x0, y0), lcons = [-0.5], ucons = [0.5], lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
opt1 = solve(prob, Optimization.LBFGS(), maxiters = 1000, callback = callback)

prob = OptimizationProblem(optf, rand(5), data, lcons = [0.0], ucons = [0.0], lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
opt = solve(prob, Optimization.AugLag(; inner = Adam()), maxiters = 500, callback = callback)

optf1 = OptimizationFunction(loss, AutoSparseForwardDiff())
prob1 = OptimizationProblem(optf1, rand(5), data)
sol1 = solve(prob1, OptimizationOptimisers.Adam(), maxiters = 1000, callback = callback)
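
A quick sanity check one could append here (illustrative, not part of the commit): the AugLag run should return coefficients that satisfy the equality constraint coeffs[1] * coeffs[5] ≈ 1, whereas the unconstrained Adam run need not.

```julia
res = zeros(1)
cons1(res, opt.u)                  # constraint residual at the AugLag solution
@show res[1]                       # expect ≈ 0 within the solver tolerance
@show loss(opt.u, (x0, y0))        # constrained fit evaluated on the full data
@show loss(sol1.u, (x0, y0))       # unconstrained Adam fit, for comparison
```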
