diff --git a/Project.toml b/Project.toml
index 3e80adf9f..d204b8fe4 100644
--- a/Project.toml
+++ b/Project.toml
@@ -28,7 +28,7 @@ LBFGSB = "0.4.1"
 LinearAlgebra = "1.10"
 Logging = "1.10"
 LoggingExtras = "0.4, 1"
-OptimizationBase = "2.0.1"
+OptimizationBase = "2.0.2"
 Printf = "1.10"
 ProgressLogging = "0.1"
 Reexport = "1.2"
diff --git a/lib/OptimizationManopt/src/OptimizationManopt.jl b/lib/OptimizationManopt/src/OptimizationManopt.jl
index fdee579d1..7ec58d004 100644
--- a/lib/OptimizationManopt/src/OptimizationManopt.jl
+++ b/lib/OptimizationManopt/src/OptimizationManopt.jl
@@ -13,6 +13,9 @@ internal state.
 abstract type AbstractManoptOptimizer end
 
 SciMLBase.supports_opt_cache_interface(opt::AbstractManoptOptimizer) = true
+SciMLBase.requiresgradient(opt::Union{GradientDescentOptimizer, ConjugateGradientDescentOptimizer, QuasiNewtonOptimizer, ConvexBundleOptimizer, FrankWolfeOptimizer}) = true
+SciMLBase.requireshessian(opt::Union{AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer}) = true
+
 function __map_optimizer_args!(cache::OptimizationCache,
         opt::AbstractManoptOptimizer;
diff --git a/lib/OptimizationOptimJL/test/runtests.jl b/lib/OptimizationOptimJL/test/runtests.jl
index a75e5987c..06c9c10dc 100644
--- a/lib/OptimizationOptimJL/test/runtests.jl
+++ b/lib/OptimizationOptimJL/test/runtests.jl
@@ -1,5 +1,5 @@
 using OptimizationOptimJL,
-    OptimizationOptimJL.Optim, Optimization, ForwardDiff, Zygote,
+    OptimizationOptimJL.Optim, Optimization, ForwardDiff, Zygote, ReverseDiff,
     Random, ModelingToolkit, Optimization.OptimizationBase.DifferentiationInterface
 using Test
 
diff --git a/src/sophia.jl b/src/sophia.jl
index 30e86c0ff..00b6b9ebe 100644
--- a/src/sophia.jl
+++ b/src/sophia.jl
@@ -10,6 +10,9 @@ struct Sophia
 end
 
 SciMLBase.supports_opt_cache_interface(opt::Sophia) = true
+SciMLBase.requiresgradient(opt::Sophia) = true
+SciMLBase.allowsfg(opt::Sophia) = true
+SciMLBase.requireshessian(opt::Sophia) = true
 
 function Sophia(; η = 1e-3, βs = (0.9, 0.999), ϵ = 1e-8, λ = 1e-1, k = 10,
         ρ = 0.04)
@@ -18,11 +21,10 @@ end
 
 clip(z, ρ) = max(min(z, ρ), -ρ)
 
-function SciMLBase.__init(prob::OptimizationProblem, opt::Sophia,
-        data = Optimization.DEFAULT_DATA;
+function SciMLBase.__init(prob::OptimizationProblem, opt::Sophia;
         maxiters::Number = 1000, callback = (args...) -> (false),
         progress = false, save_best = true, kwargs...)
-    return OptimizationCache(prob, opt, data; maxiters, callback, progress,
+    return OptimizationCache(prob, opt; maxiters, callback, progress,
         save_best, kwargs...)
 end
 
@@ -60,46 +62,46 @@ function SciMLBase.__solve(cache::OptimizationCache{
     λ = uType(cache.opt.λ)
     ρ = uType(cache.opt.ρ)
 
-    if cache.data != Optimization.DEFAULT_DATA
-        maxiters = length(cache.data)
-        data = cache.data
+    maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
+
+    if cache.p == SciMLBase.NullParameters()
+        data = OptimizationBase.DEFAULT_DATA
     else
-        maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
-        data = Optimization.take(cache.data, maxiters)
+        data = cache.p
     end
-    maxiters = Optimization._check_and_convert_maxiters(maxiters)
-
     f = cache.f
     θ = copy(cache.u0)
     gₜ = zero(θ)
     mₜ = zero(θ)
     hₜ = zero(θ)
-    for (i, d) in enumerate(data)
-        f.grad(gₜ, θ, d...)
-        x = cache.f(θ, cache.p, d...)
-        opt_state = Optimization.OptimizationState(; iter = i,
-            u = θ,
-            objective = first(x),
-            grad = gₜ,
-            original = nothing)
-        cb_call = cache.callback(θ, x...)
-        if !(cb_call isa Bool)
-            error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
-        elseif cb_call
-            break
-        end
-        mₜ = βs[1] .* mₜ + (1 - βs[1]) .* gₜ
+    for _ in 1:maxiters
+        for (i, d) in enumerate(data)
+            f.grad(gₜ, θ, d)
+            x = cache.f(θ, cache.p, d...)
+            opt_state = Optimization.OptimizationState(; iter = i,
+                u = θ,
+                objective = first(x),
+                grad = gₜ,
+                original = nothing)
+            cb_call = cache.callback(θ, x...)
+            if !(cb_call isa Bool)
+                error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
+            elseif cb_call
+                break
+            end
+            mₜ = βs[1] .* mₜ + (1 - βs[1]) .* gₜ
 
-        if i % cache.opt.k == 1
-            hₜ₋₁ = copy(hₜ)
-            u = randn(uType, length(θ))
-            f.hv(hₜ, θ, u, d...)
-            hₜ = βs[2] .* hₜ₋₁ + (1 - βs[2]) .* (u .* hₜ)
+            if i % cache.opt.k == 1
+                hₜ₋₁ = copy(hₜ)
+                u = randn(uType, length(θ))
+                f.hv(hₜ, θ, u, d)
+                hₜ = βs[2] .* hₜ₋₁ + (1 - βs[2]) .* (u .* hₜ)
+            end
+            θ = θ .- η * λ .* θ
+            θ = θ .-
+                η .* clip.(mₜ ./ max.(hₜ, Ref(ϵ)), Ref(ρ))
         end
-        θ = θ .- η * λ .* θ
-        θ = θ .-
-            η .* clip.(mₜ ./ max.(hₜ, Ref(ϵ)), Ref(ρ))
     end
 
     return SciMLBase.build_solution(cache, cache.opt,
diff --git a/test/ADtests.jl b/test/ADtests.jl
index 7243ad121..dca8ebf34 100644
--- a/test/ADtests.jl
+++ b/test/ADtests.jl
@@ -30,12 +30,6 @@ end
     sol = solve(prob, Optim.Newton())
    @test 10 * sol.objective < l1
    @test sol.retcode == ReturnCode.Success
-
-    sol = Optimization.solve(prob,
-        Optimization.Sophia(; η = 0.5,
-            λ = 0.0),
-        maxiters = 1000)
-    @test 10 * sol.objective < l1
 end
 
 @testset "No constraint" begin
diff --git a/test/minibatch.jl b/test/minibatch.jl
index 0c4e2393e..2a755e36f 100644
--- a/test/minibatch.jl
+++ b/test/minibatch.jl
@@ -58,8 +58,10 @@ optfun = OptimizationFunction(loss_adjoint, Optimization.AutoZygote())
 
 optprob = OptimizationProblem(optfun, pp, train_loader)
 
-res1 = Optimization.solve(optprob, Optimisers.Adam(0.05),
-    callback = callback, maxiters = numEpochs)
+res1 = Optimization.solve(optprob,
+    Optimization.Sophia(; η = 0.5,
+        λ = 0.0),
+    maxiters = 1000)
 @test 10res1.objective < l1
 
 optfun = OptimizationFunction(loss_adjoint,
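
A minimal usage sketch (not part of the diff) of the Sophia path after this change: with the positional `data` argument removed from `SciMLBase.__init`, minibatches travel through the problem's parameter slot, as exercised in the test/minibatch.jl hunk above. The toy data, `loss`, and `train_loader` names below are illustrative assumptions, the hyperparameters mirror the test values, and the sketch assumes a batch iterable such as `MLUtils.DataLoader` with AD via `Optimization.AutoZygote()`.

using Optimization, MLUtils

# Illustrative batched data; each DataLoader element is a (xb, yb) tuple.
x = rand(1, 128)
y = 2 .* x .+ 1
train_loader = MLUtils.DataLoader((x, y); batchsize = 16)

# The Sophia loop evaluates the objective as f(θ, p, batch...), so the loss
# accepts the batch arrays after the parameter slot.
loss(θ, p, xb, yb) = sum(abs2, θ[1] .* xb .+ θ[2] .- yb)

optf = OptimizationFunction(loss, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(2), train_loader)  # data rides in `p`

sol = Optimization.solve(prob, Optimization.Sophia(; η = 0.5, λ = 0.0),
    maxiters = 1000)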