Commit 7cb4541: updates for CI
Vaibhavdixit02 committed Sep 11, 2024 (1 parent: 4a9737c)
Showing 5 changed files with 43 additions and 42 deletions.
3 changes: 3 additions & 0 deletions lib/OptimizationManopt/src/OptimizationManopt.jl
@@ -13,6 +13,9 @@ internal state.
 abstract type AbstractManoptOptimizer end
 
 SciMLBase.supports_opt_cache_interface(opt::AbstractManoptOptimizer) = true
+SciMLBase.requiresgradient(opt::Union{GradientDescentOptimizer, ConjugateGradientDescentOptimizer, QuasiNewtonOptimizer, ConvexBundleOptimizer, FrankWolfeOptimizer}) = true
+SciMLBase.requireshessian(opt::Union{AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer}) = true
+
 
 function __map_optimizer_args!(cache::OptimizationCache,
         opt::AbstractManoptOptimizer;
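These trait declarations tell the shared SciMLBase machinery which derivative oracles each Manopt wrapper needs before it runs. A minimal dispatch sketch, assuming only that the SciMLBase trait functions default to false; the MyManoptLike type is illustrative and not part of this commit:

using SciMLBase

struct MyManoptLike end                          # hypothetical optimizer type
SciMLBase.requiresgradient(::MyManoptLike) = true

opt = MyManoptLike()
if SciMLBase.requiresgradient(opt)
    # the problem must carry an AD backend so a gradient can be instantiated
    println("gradient oracle required")
end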
2 changes: 1 addition & 1 deletion lib/OptimizationOptimJL/test/runtests.jl
@@ -1,5 +1,5 @@
 using OptimizationOptimJL,
-      OptimizationOptimJL.Optim, Optimization, ForwardDiff, Zygote,
+      OptimizationOptimJL.Optim, Optimization, ForwardDiff, Zygote, ReverseDiff,
       Random, ModelingToolkit, Optimization.OptimizationBase.DifferentiationInterface
 using Test
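ReverseDiff joins the imports because the suite exercises it as an AD backend for optimizers that declare requiresgradient. A hedged, self-contained example of that pattern; the Rosenbrock setup follows the standard Optimization.jl documentation rather than this test file:

using Optimization, OptimizationOptimJL, OptimizationOptimJL.Optim, ReverseDiff

rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2
optf = OptimizationFunction(rosenbrock, Optimization.AutoReverseDiff())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])
sol = solve(prob, Optim.BFGS())   # BFGS needs gradients; ReverseDiff supplies them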
68 changes: 35 additions & 33 deletions src/sophia.jl
@@ -10,6 +10,9 @@ struct Sophia
 end
 
 SciMLBase.supports_opt_cache_interface(opt::Sophia) = true
+SciMLBase.requiresgradient(opt::Sophia) = true
+SciMLBase.allowsfg(opt::Sophia) = true
+SciMLBase.requireshessian(opt::Sophia) = true
 
 function Sophia(; η = 1e-3, βs = (0.9, 0.999), ϵ = 1e-8, λ = 1e-1, k = 10,
         ρ = 0.04)
@@ -18,11 +21,10 @@ end
 
 clip(z, ρ) = max(min(z, ρ), -ρ)
 
-function SciMLBase.__init(prob::OptimizationProblem, opt::Sophia,
-        data = Optimization.DEFAULT_DATA;
+function SciMLBase.__init(prob::OptimizationProblem, opt::Sophia;
         maxiters::Number = 1000, callback = (args...) -> (false),
         progress = false, save_best = true, kwargs...)
-    return OptimizationCache(prob, opt, data; maxiters, callback, progress,
+    return OptimizationCache(prob, opt; maxiters, callback, progress,
         save_best, kwargs...)
 end
 
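With the positional data argument gone, minibatches reach Sophia through the problem's parameter object p; the __solve hunk below draws them from cache.p. A usage sketch under that assumption, reusing the optfun, pp, and train_loader names from the test/minibatch.jl change at the bottom of this commit:

optprob = OptimizationProblem(optfun, pp, train_loader)   # the loader rides along as `p`
res1 = Optimization.solve(optprob, Optimization.Sophia(), maxiters = 1000)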
@@ -60,46 +62,46 @@ function SciMLBase.__solve(cache::OptimizationCache{
     λ = uType(cache.opt.λ)
     ρ = uType(cache.opt.ρ)
 
-    if cache.data != Optimization.DEFAULT_DATA
-        maxiters = length(cache.data)
-        data = cache.data
+    maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
+
+    if cache.p == SciMLBase.NullParameters()
+        data = OptimizationBase.DEFAULT_DATA
     else
-        maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
-        data = Optimization.take(cache.data, maxiters)
+        data = cache.p
     end
 
-    maxiters = Optimization._check_and_convert_maxiters(maxiters)
-
     f = cache.f
     θ = copy(cache.u0)
     gₜ = zero(θ)
     mₜ = zero(θ)
     hₜ = zero(θ)
-    for (i, d) in enumerate(data)
-        f.grad(gₜ, θ, d...)
-        x = cache.f(θ, cache.p, d...)
-        opt_state = Optimization.OptimizationState(; iter = i,
-            u = θ,
-            objective = first(x),
-            grad = gₜ,
-            original = nothing)
-        cb_call = cache.callback(θ, x...)
-        if !(cb_call isa Bool)
-            error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
-        elseif cb_call
-            break
-        end
-        mₜ = βs[1] .* mₜ + (1 - βs[1]) .* gₜ
-
-        if i % cache.opt.k == 1
-            hₜ₋₁ = copy(hₜ)
-            u = randn(uType, length(θ))
-            f.hv(hₜ, θ, u, d...)
-            hₜ = βs[2] .* hₜ₋₁ + (1 - βs[2]) .* (u .* hₜ)
-        end
-        θ = θ .- η * λ .* θ
-        θ = θ .-
-            η .* clip.(mₜ ./ max.(hₜ, Ref(ϵ)), Ref(ρ))
-    end
+    for _ in 1:maxiters
+        for (i, d) in enumerate(data)
+            f.grad(gₜ, θ, d)
+            x = cache.f(θ, cache.p, d...)
+            opt_state = Optimization.OptimizationState(; iter = i,
+                u = θ,
+                objective = first(x),
+                grad = gₜ,
+                original = nothing)
+            cb_call = cache.callback(θ, x...)
+            if !(cb_call isa Bool)
+                error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
+            elseif cb_call
+                break
+            end
+            mₜ = βs[1] .* mₜ + (1 - βs[1]) .* gₜ
+
+            if i % cache.opt.k == 1
+                hₜ₋₁ = copy(hₜ)
+                u = randn(uType, length(θ))
+                f.hv(hₜ, θ, u, d)
+                hₜ = βs[2] .* hₜ₋₁ + (1 - βs[2]) .* (u .* hₜ)
+            end
+            θ = θ .- η * λ .* θ
+            θ = θ .-
+                η .* clip.(mₜ ./ max.(hₜ, Ref(ϵ)), Ref(ρ))
+        end
+    end
 
     return SciMLBase.build_solution(cache, cache.opt,
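For reference, each inner iteration above reduces to decoupled weight decay followed by a clipped, Hessian-preconditioned step. A minimal, dependency-free sketch with made-up numbers; scalar constants broadcast directly, so the Ref wrappers from the source are not needed here:

clip(z, ρ) = max(min(z, ρ), -ρ)

θ  = [1.0, -2.0]            # parameters
mₜ = [0.3, -0.9]            # EMA of gradients
hₜ = [0.5, 0.01]            # EMA of diagonal Hessian estimates
η, λ, ϵ, ρ = 1e-3, 1e-1, 1e-8, 4e-2

θ = θ .- η * λ .* θ                              # decoupled weight decay
θ = θ .- η .* clip.(mₜ ./ max.(hₜ, ϵ), ρ)        # clipped preconditioned step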
6 changes: 0 additions & 6 deletions test/ADtests.jl
@@ -30,12 +30,6 @@ end
 sol = solve(prob, Optim.Newton())
 @test 10 * sol.objective < l1
 @test sol.retcode == ReturnCode.Success
-
-sol = Optimization.solve(prob,
-    Optimization.Sophia(; η = 0.5,
-        λ = 0.0),
-    maxiters = 1000)
-@test 10 * sol.objective < l1
 end
 
 @testset "No constraint" begin
6 changes: 4 additions & 2 deletions test/minibatch.jl
@@ -58,8 +58,10 @@ optfun = OptimizationFunction(loss_adjoint,
     Optimization.AutoZygote())
 optprob = OptimizationProblem(optfun, pp, train_loader)
 
-res1 = Optimization.solve(optprob, Optimisers.Adam(0.05),
-    callback = callback, maxiters = numEpochs)
+res1 = Optimization.solve(optprob,
+    Optimization.Sophia(; η = 0.5,
+        λ = 0.0),
+    maxiters = 1000)
 @test 10res1.objective < l1
 
 optfun = OptimizationFunction(loss_adjoint,
