From 8d7cd3acf4cdbe7bd6441d5eb2e9bfb7154875d7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 15:49:56 -0400 Subject: [PATCH 1/5] feat: make MLUtils into a weakdep & suppport MLDataDevices --- lib/OptimizationOptimisers/Project.toml | 12 ++++++++++-- .../ext/OptimizationOptimisersMLDataDevicesExt.jl | 8 ++++++++ .../ext/OptimizationOptimisersMLUtilsExt.jl | 8 ++++++++ .../src/OptimizationOptimisers.jl | 7 +++++-- 4 files changed, 31 insertions(+), 4 deletions(-) create mode 100644 lib/OptimizationOptimisers/ext/OptimizationOptimisersMLDataDevicesExt.jl create mode 100644 lib/OptimizationOptimisers/ext/OptimizationOptimisersMLUtilsExt.jl diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml index dbc5aecd2..b7356ce20 100644 --- a/lib/OptimizationOptimisers/Project.toml +++ b/lib/OptimizationOptimisers/Project.toml @@ -1,17 +1,25 @@ name = "OptimizationOptimisers" uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1" authors = ["Vaibhav Dixit and contributors"] -version = "0.3.0" +version = "0.3.1" [deps] -MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +[extensions] +OptimizationOptimisersMLDataDevicesExt = "MLDataDevices" +OptimizationOptimisersMLUtilsExt = "MLUtils" + +[weakdeps] +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" + [compat] +MLDataDevices = "1.1" MLUtils = "0.4.4" Optimisers = "0.2, 0.3" Optimization = "4" diff --git a/lib/OptimizationOptimisers/ext/OptimizationOptimisersMLDataDevicesExt.jl b/lib/OptimizationOptimisers/ext/OptimizationOptimisersMLDataDevicesExt.jl new file mode 100644 index 000000000..545f73c6c --- /dev/null +++ b/lib/OptimizationOptimisers/ext/OptimizationOptimisersMLDataDevicesExt.jl @@ -0,0 +1,8 @@ +module OptimizationOptimisersMLDataDevicesExt + +using MLDataDevices +using OptimizationOptimisers + +OptimizationOptimisers.isa_dataiterator(::DeviceIterator) = true + +end diff --git a/lib/OptimizationOptimisers/ext/OptimizationOptimisersMLUtilsExt.jl b/lib/OptimizationOptimisers/ext/OptimizationOptimisersMLUtilsExt.jl new file mode 100644 index 000000000..1790d7aea --- /dev/null +++ b/lib/OptimizationOptimisers/ext/OptimizationOptimisersMLUtilsExt.jl @@ -0,0 +1,8 @@ +module OptimizationOptimisersMLUtilsExt + +using MLUtils +using OptimizationOptimisers + +OptimizationOptimisers.isa_dataiterator(::MLUtils.DataLoader) = true + +end diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index b3811bbd7..12b021da3 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -2,7 +2,7 @@ module OptimizationOptimisers using Reexport, Printf, ProgressLogging @reexport using Optimisers, Optimization -using Optimization.SciMLBase, MLUtils +using Optimization.SciMLBase SciMLBase.supports_opt_cache_interface(opt::AbstractRule) = true SciMLBase.requiresgradient(opt::AbstractRule) = true @@ -16,6 +16,8 @@ function SciMLBase.__init( kwargs...) end +isa_dataiterator(data) = false + function SciMLBase.__solve(cache::OptimizationCache{ F, RC, @@ -57,13 +59,14 @@ function SciMLBase.__solve(cache::OptimizationCache{ throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg.")) end - if cache.p isa MLUtils.DataLoader + if isa_dataiterator(cache.p) data = cache.p dataiterate = true else data = [cache.p] dataiterate = false end + opt = cache.opt θ = copy(cache.u0) G = copy(θ) From 5a76a9e6fb8d39525a8e4df63829f96dcc7f9414 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Sun, 22 Sep 2024 00:03:57 -0400 Subject: [PATCH 2/5] Add minibatching tests --- lib/OptimizationOptimisers/Project.toml | 8 ++-- .../OptimizationOptimisersMLDataDevicesExt.jl | 2 +- .../src/OptimizationOptimisers.jl | 2 +- lib/OptimizationOptimisers/test/runtests.jl | 41 +++++++++++++++++++ 4 files changed, 47 insertions(+), 6 deletions(-) diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml index b7356ce20..2c4c8cc5e 100644 --- a/lib/OptimizationOptimisers/Project.toml +++ b/lib/OptimizationOptimisers/Project.toml @@ -10,14 +10,14 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -[extensions] -OptimizationOptimisersMLDataDevicesExt = "MLDataDevices" -OptimizationOptimisersMLUtilsExt = "MLUtils" - [weakdeps] MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +[extensions] +OptimizationOptimisersMLDataDevicesExt = "MLDataDevices" +OptimizationOptimisersMLUtilsExt = "MLUtils" + [compat] MLDataDevices = "1.1" MLUtils = "0.4.4" diff --git a/lib/OptimizationOptimisers/ext/OptimizationOptimisersMLDataDevicesExt.jl b/lib/OptimizationOptimisers/ext/OptimizationOptimisersMLDataDevicesExt.jl index 545f73c6c..ed5020daa 100644 --- a/lib/OptimizationOptimisers/ext/OptimizationOptimisersMLDataDevicesExt.jl +++ b/lib/OptimizationOptimisers/ext/OptimizationOptimisersMLDataDevicesExt.jl @@ -3,6 +3,6 @@ module OptimizationOptimisersMLDataDevicesExt using MLDataDevices using OptimizationOptimisers -OptimizationOptimisers.isa_dataiterator(::DeviceIterator) = true +OptimizationOptimisers.isa_dataiterator(::DeviceIterator) = (@show "dkjht"; true) end diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 12b021da3..ea2ef9202 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -117,7 +117,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ opt = min_opt x = min_err θ = min_θ - cache.f.grad(G, θ, d...) + cache.f.grad(G, θ, d) opt_state = Optimization.OptimizationState(iter = i, u = θ, objective = x[1], diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl index ddee2ea4c..867b03b36 100644 --- a/lib/OptimizationOptimisers/test/runtests.jl +++ b/lib/OptimizationOptimisers/test/runtests.jl @@ -68,3 +68,44 @@ using Zygote @test_throws ArgumentError sol=solve(prob, Optimisers.Adam()) end + +@testset "Minibatching" begin + using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Statistics, Plots, + Random, ComponentArrays + + x = rand(10000) + y = sin.(x) + data = MLUtils.DataLoader((x, y), batchsize = 100) + + # Define the neural network + model = Chain(Dense(1, 32, tanh), Dense(32, 1)) + ps, st = Lux.setup(Random.default_rng(), model) + ps_ca = ComponentArray(ps) + smodel = StatefulLuxLayer{true}(model, nothing, st) + + function callback(state, l) + state.iter % 25 == 1 && @show "Iteration: %5d, Loss: %.6e\n" state.iter l + return l < 1e-4 + end + + function loss(ps, data) + ypred = [smodel([data[1][i]], ps)[1] for i in eachindex(data[1])] + return sum(abs2, ypred .- data[2]) + end + + optf = OptimizationFunction(loss, AutoZygote()) + prob = OptimizationProblem(optf, ps_ca, data) + + res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 100) + + @test res.objective < 1e-4 + + using MLDataDevices + data = CPUDevice()(data) + optf = OptimizationFunction(loss, AutoZygote()) + prob = OptimizationProblem(optf, ps_ca, data) + + res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 100) + + @test res.objective < 1e-4 +end From 6e55f158a110515985b84b201970c828d0c74fc8 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Sun, 22 Sep 2024 00:28:45 -0400 Subject: [PATCH 3/5] add test deps needed --- lib/OptimizationOptimisers/Project.toml | 5 ++++- lib/OptimizationOptimisers/test/runtests.jl | 3 +-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml index 2c4c8cc5e..8e1585a0b 100644 --- a/lib/OptimizationOptimisers/Project.toml +++ b/lib/OptimizationOptimisers/Project.toml @@ -28,9 +28,12 @@ Reexport = "1.2" julia = "1" [extras] +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ForwardDiff", "Test", "Zygote"] +test = ["ComponentArrays", "ForwardDiff", "Lux", "Random", "Test", "Zygote"] diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl index 867b03b36..418b1547c 100644 --- a/lib/OptimizationOptimisers/test/runtests.jl +++ b/lib/OptimizationOptimisers/test/runtests.jl @@ -70,8 +70,7 @@ using Zygote end @testset "Minibatching" begin - using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Statistics, Plots, - Random, ComponentArrays + using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Random, ComponentArrays x = rand(10000) y = sin.(x) From f05d0510022fd694068642afb32e9ea07dbb19a7 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Sun, 22 Sep 2024 00:53:02 -0400 Subject: [PATCH 4/5] more test deps --- lib/OptimizationOptimisers/Project.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml index 8e1585a0b..371467455 100644 --- a/lib/OptimizationOptimisers/Project.toml +++ b/lib/OptimizationOptimisers/Project.toml @@ -31,9 +31,11 @@ julia = "1" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ComponentArrays", "ForwardDiff", "Lux", "Random", "Test", "Zygote"] +test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote"] From 1f4cba3995df4ed0f0f1a7c6c845153d03687416 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Sun, 22 Sep 2024 01:23:46 -0400 Subject: [PATCH 5/5] bump epochs --- lib/OptimizationOptimisers/test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl index 418b1547c..02b764df2 100644 --- a/lib/OptimizationOptimisers/test/runtests.jl +++ b/lib/OptimizationOptimisers/test/runtests.jl @@ -95,7 +95,7 @@ end optf = OptimizationFunction(loss, AutoZygote()) prob = OptimizationProblem(optf, ps_ca, data) - res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 100) + res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000) @test res.objective < 1e-4 @@ -104,7 +104,7 @@ end optf = OptimizationFunction(loss, AutoZygote()) prob = OptimizationProblem(optf, ps_ca, data) - res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 100) + res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000) @test res.objective < 1e-4 end