diff --git a/src/DiffEqFlux.jl b/src/DiffEqFlux.jl
index dc1b9d2a2..5056fd519 100644
--- a/src/DiffEqFlux.jl
+++ b/src/DiffEqFlux.jl
@@ -32,13 +32,13 @@ fixed_state_type(::Layers.HamiltonianNN{False}) = false
 
 include("ffjord.jl")
 include("neural_de.jl")
-
+include("otflow.jl")
 include("collocation.jl")
 include("multiple_shooting.jl")
 
 export NeuralODE, NeuralDSDE, NeuralSDE, NeuralCDDE, NeuralDAE, AugmentedNDELayer,
        NeuralODEMM
 export FFJORD, FFJORDDistribution
+export OTFlow
 export DimMover
 
 export EpanechnikovKernel, UniformKernel, TriangularKernel, QuarticKernel, TriweightKernel,
diff --git a/src/otflow.jl b/src/otflow.jl
new file mode 100644
index 000000000..9b9de8d56
--- /dev/null
+++ b/src/otflow.jl
@@ -0,0 +1,126 @@
+# OTFlow-style continuous normalizing flow layer (cf. Onken et al. (2021),
+# "OT-Flow: Fast and Accurate Continuous Normalizing Flows via Optimal
+# Transport"). The velocity field is the negative spatial gradient of the
+# scalar potential
+#     Φ(x, t) = (w + c)' N(s) + s' (A' A) s / 2 + b' s,    s = [x; t],
+# where N is a two-layer residual network, d is the state dimension, m the
+# hidden width, and r the rank of the quadratic term.
+struct OTFlow <: AbstractLuxLayer
+    d::Int
+    m::Int
+    r::Int
+end
+
+OTFlow(d::Int, m::Int; r::Int = min(10, d)) = OTFlow(d, m, r)
+
+function Lux.initialparameters(rng::AbstractRNG, l::OTFlow)
+    w = randn(rng, Float32, l.m) .* 0.01f0
+    A = randn(rng, Float32, l.r, l.d + 1) .* 0.01f0
+    b = zeros(Float32, l.d + 1)
+    c = randn(rng, Float32, l.m) .* 0.01f0
+    K0 = randn(rng, Float32, l.m, l.d + 1) .* 0.01f0
+    K1 = randn(rng, Float32, l.m, l.m) .* 0.01f0
+    b0 = zeros(Float32, l.m)
+    b1 = zeros(Float32, l.m)
+    return (; w, A, b, c, K0, K1, b0, b1)
+end
+
+# Antiderivative of tanh, log(exp(x) + exp(-x)), written as
+# abs(x) + log1p(exp(-2abs(x))) so that it does not overflow for large |x|.
+sigma(x) = abs(x) + log1p(exp(-2 * abs(x)))
+sigma_prime(x) = tanh(x)
+sigma_double_prime(x) = one(x) - tanh(x)^2
+
+function resnet_forward(x::AbstractVector, t::Real, ps)
+    s = vcat(x, t)
+    u0 = sigma.(ps.K0 * s .+ ps.b0)
+    u1 = u0 .+ sigma.(ps.K1 * u0 .+ ps.b1)
+    return u1
+end
+
+function potential(x::AbstractVector, t::Real, ps)
+    s = vcat(x, t)
+    N = resnet_forward(x, t, ps)
+    # Dividing by 2 (rather than multiplying by 0.5) preserves the eltype of s.
+    quadratic_term = (s' * (ps.A' * ps.A) * s) / 2
+    linear_term = ps.b' * s
+    neural_term = sum((ps.w .+ ps.c) .* N)
+    return neural_term + quadratic_term + linear_term
+end
+
+# Closed-form ∇ₓΦ(x, t): backpropagate through the ResNet by hand and keep
+# the first d components of the gradient with respect to s = [x; t].
+function gradient(x::AbstractVector, t::Real, ps, d::Int)
+    s = vcat(x, t)
+    u0 = sigma.(ps.K0 * s .+ ps.b0)
+    z1 = (ps.w .+ ps.c) .+ ps.K1' * (sigma_prime.(ps.K1 * u0 .+ ps.b1) .* (ps.w .+ ps.c))
+    z0 = ps.K0' * (sigma_prime.(ps.K0 * s .+ ps.b0) .* z1)
+    grad = z0 + (ps.A' * ps.A) * s + ps.b
+    return grad[1:d]
+end
+
+# Exact trace of the spatial Hessian, tr(E' ∇²Φ E), computed without ever
+# forming the Hessian.
+function trace(x::AbstractVector, t::Real, ps, d::Int)
+    s = vcat(x, t)
+    u0 = sigma.(ps.K0 * s .+ ps.b0)
+    z1 = (ps.w .+ ps.c) .+ ps.K1' * (sigma_prime.(ps.K1 * u0 .+ ps.b1) .* (ps.w .+ ps.c))
+    K0_E = ps.K0[:, 1:d]
+    A_E = ps.A[:, 1:d]
+    t0 = sum(sigma_double_prime.(ps.K0 * s .+ ps.b0) .* z1 .* (K0_E .^ 2))
+    J = Diagonal(sigma_prime.(ps.K0 * s .+ ps.b0)) * K0_E
+    t1 = sum(sigma_double_prime.(ps.K1 * u0 .+ ps.b1) .* (ps.w .+ ps.c) .* (ps.K1 * J) .^ 2)
+    trace_A = tr(A_E' * A_E)
+    return t0 + t1 + trace_A
+end
+
+# The layer maps (x, t) to the velocity v = -∇ₓΦ and the negated Hessian
+# trace needed for the log-density update of a continuous normalizing flow.
+function (l::OTFlow)(xt::Tuple{AbstractVector, Real}, ps, st)
+    x, t = xt
+    v = -gradient(x, t, ps, l.d)
+    trJ = -trace(x, t, ps, l.d)
+    return (v, trJ), st
+end
+
+# Per-sample toy objective: kinetic term ‖v‖² / 2 minus the trace output.
+function simple_loss(x::AbstractVector, t::Real, l::OTFlow, ps)
+    (v, trJ), _ = l((x, t), ps, NamedTuple())
+    return sum(abs2, v) / 2 - trJ
+end
+
+# Analytic gradients of the potential Φ with respect to the output-layer
+# parameters only; the ResNet parameters (K0, K1, b0, b1) are returned as
+# zero placeholders.
+function manual_gradient(x::AbstractVector, t::Real, l::OTFlow, ps)
+    s = vcat(x, t)
+    u0 = sigma.(ps.K0 * s .+ ps.b0)
+    u1 = u0 .+ sigma.(ps.K1 * u0 .+ ps.b1)
+    grad_w = u1              # ∂Φ/∂w = N(s)
+    grad_c = u1              # ∂Φ/∂c = N(s)
+    grad_A = (ps.A * s) * s' # ∂Φ/∂A = A s s'
+    grad_b = s               # ∂Φ/∂b = s
+    return (w = grad_w, A = grad_A, b = grad_b, c = grad_c,
+        K0 = zeros(Float32, l.m, l.d + 1), K1 = zeros(Float32, l.m, l.m),
+        b0 = zeros(Float32, l.m), b1 = zeros(Float32, l.m))
+end
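+
+# The test suite exercises `simple_ode_solve`, which the original file did not
+# define; a minimal fixed-step forward-Euler sketch is assumed here. For real
+# problems, prefer a solver from OrdinaryDiffEq.
+function simple_ode_solve(l::OTFlow, x0::AbstractVector, tspan, ps, st; dt = 0.01f0)
+    ts = collect(tspan[1]:dt:tspan[2])
+    xs = similar(x0, length(x0), length(ts))
+    xs[:, 1] .= x0
+    for i in 2:length(ts)
+        # One Euler step along the learned velocity field v = -∇ₓΦ.
+        (v, _), _ = l((xs[:, i - 1], ts[i - 1]), ps, st)
+        xs[:, i] .= xs[:, i - 1] .+ dt .* v
+    end
+    return xs, ts
+end
+
+# Example usage (illustrative only):
+#     using Lux, Random, ComponentArrays
+#     l = OTFlow(2, 8)
+#     ps, st = Lux.setup(Xoshiro(0), l)
+#     (v, trJ), _ = l((Float32[0.5, -0.3], 0.0f0), ComponentArray(ps), st)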
diff --git a/test/collocation_tests.jl b/test/collocation_tests.jl
index 756e77724..d25e800ba 100644
--- a/test/collocation_tests.jl
+++ b/test/collocation_tests.jl
@@ -1,56 +1,56 @@
-@testitem "Collocation" tags=[:layers] begin
-    using OrdinaryDiffEq
+@testitem "Collocation" tags = [:layers] begin
+    using OrdinaryDiffEq
 
-    bounded_support_kernels = [EpanechnikovKernel(), UniformKernel(), TriangularKernel(),
-        QuarticKernel(), TriweightKernel(), TricubeKernel(), CosineKernel()]
+    bounded_support_kernels = [EpanechnikovKernel(), UniformKernel(), TriangularKernel(),
+        QuarticKernel(), TriweightKernel(), TricubeKernel(), CosineKernel()]
 
-    unbounded_support_kernels = [
-        GaussianKernel(), LogisticKernel(), SigmoidKernel(), SilvermanKernel()]
+    unbounded_support_kernels = [
+        GaussianKernel(), LogisticKernel(), SigmoidKernel(), SilvermanKernel()]
 
-    @testset "Kernel Functions" begin
-        ts = collect(-5.0:0.1:5.0)
-        @testset "Kernels with support from -1 to 1" begin
-            minus_one_index = findfirst(x -> ==(x, -1.0), ts)
-            plus_one_index = findfirst(x -> ==(x, 1.0), ts)
-            @testset "$kernel" for (kernel, x0) in zip(bounded_support_kernels,
-                [0.75, 0.50, 1.0, 15.0 / 16.0, 35.0 / 32.0, 70.0 / 81.0, pi / 4.0])
-                ws = DiffEqFlux.calckernel.((kernel,), ts)
-                # t < -1
-                @test all(ws[1:(minus_one_index - 1)] .== 0.0)
-                # t > 1
-                @test all(ws[(plus_one_index + 1):end] .== 0.0)
-                # -1 < t <1
-                @test all(ws[(minus_one_index + 1):(plus_one_index - 1)] .> 0.0)
-                # t = 0
-                @test DiffEqFlux.calckernel(kernel, 0.0) == x0
-            end
-        end
-        @testset "Kernels with unbounded support" begin
-            @testset "$kernel" for (kernel, x0) in zip(unbounded_support_kernels,
-                [1 / (sqrt(2 * pi)), 0.25, 1 / pi, 1 / (2 * sqrt(2))])
-                # t = 0
-                @test DiffEqFlux.calckernel(kernel, 0.0) == x0
-            end
-        end
-    end
+    @testset "Kernel Functions" begin
+        ts = collect(-5.0:0.1:5.0)
+        @testset "Kernels with support from -1 to 1" begin
+            minus_one_index = findfirst(x -> ==(x, -1.0), ts)
+            plus_one_index = findfirst(x -> ==(x, 1.0), ts)
+            @testset "$kernel" for (kernel, x0) in zip(bounded_support_kernels,
+                [0.75, 0.50, 1.0, 15.0 / 16.0, 35.0 / 32.0, 70.0 / 81.0, pi / 4.0])
+                ws = DiffEqFlux.calckernel.((kernel,), ts)
+                # t < -1
+                @test all(ws[1:(minus_one_index-1)] .== 0.0)
+                # t > 1
+                @test all(ws[(plus_one_index+1):end] .== 0.0)
+                # -1 < t <1
+                @test all(ws[(minus_one_index+1):(plus_one_index-1)] .> 0.0)
+                # t = 0
+                @test DiffEqFlux.calckernel(kernel, 0.0) == x0
+            end
+        end
+        @testset "Kernels with unbounded support" begin
+            @testset "$kernel" for (kernel, x0) in zip(unbounded_support_kernels,
+                [1 / (sqrt(2 * pi)), 0.25, 1 / pi, 1 / (2 * sqrt(2))])
+                # t = 0
+                @test DiffEqFlux.calckernel(kernel, 0.0) == x0
+            end
+        end
+    end
 
-    @testset "Collocation of data" begin
-        f(u, p, t) = p .* u
-        rc = 2
-        ps = repeat([-0.001], rc)
-        tspan = (0.0, 50.0)
-        u0 = 3.4 .+ ones(rc)
-        t = collect(range(minimum(tspan); stop = maximum(tspan), length = 1000))
-        prob = ODEProblem(f, u0, tspan, ps)
-        data = Array(solve(prob, Tsit5(); saveat = t, abstol = 1e-12, reltol = 1e-12))
-        @testset "$kernel" for kernel in [
-            bounded_support_kernels..., unbounded_support_kernels...]
-            u′, u = collocate_data(data, t, kernel, 0.003)
-            @test sum(abs2, u - data) < 1e-8
-        end
-        @testset "$kernel" for kernel in [bounded_support_kernels...]
-            # Errors out as the bandwidth is too low
-            @test_throws ErrorException collocate_data(data, t, kernel, 0.001)
-        end
-    end
+    @testset "Collocation of data" begin
+        f(u, p, t) = p .* u
+        rc = 2
+        ps = repeat([-0.001], rc)
+        tspan = (0.0, 50.0)
+        u0 = 3.4 .+ ones(rc)
+        t = collect(range(minimum(tspan); stop = maximum(tspan), length = 1000))
+        prob = ODEProblem(f, u0, tspan, ps)
+        data = Array(solve(prob, Tsit5(); saveat = t, abstol = 1e-12, reltol = 1e-12))
+        @testset "$kernel" for kernel in [
+            bounded_support_kernels..., unbounded_support_kernels...]
+            u′, u = collocate_data(data, t, kernel, 0.003)
+            @test sum(abs2, u - data) < 1e-8
+        end
+        @testset "$kernel" for kernel in [bounded_support_kernels...]
+            # Errors out as the bandwidth is too low
+            @test_throws ErrorException collocate_data(data, t, kernel, 0.001)
+        end
+    end
 end
diff --git a/test/neural_dae_tests.jl b/test/neural_dae_tests.jl
index ffc812a5a..e33645b4f 100644
--- a/test/neural_dae_tests.jl
+++ b/test/neural_dae_tests.jl
@@ -69,4 +69,4 @@
     optprob = Optimization.OptimizationProblem(optfunc, p)
     res = Optimization.solve(optprob, BFGS(; initial_stepnorm = 0.0001))
 end
-end
+end
\ No newline at end of file
diff --git a/test/neural_de_tests.jl b/test/neural_de_tests.jl
index 8bbd35a23..25ea98c86 100644
--- a/test/neural_de_tests.jl
+++ b/test/neural_de_tests.jl
@@ -1,326 +1,326 @@
-@testitem "NeuralODE" tags=[:basicneuralde] begin
-    using ComponentArrays, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Random
-    import Flux
-
-    rng = Xoshiro(0)
-
-    @testset "$(nnlib)" for nnlib in ("Flux", "Lux")
-        mp = Float32[0.1, 0.1]
-        x = Float32[2.0; 0.0]
-        xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0]))
-        tspan = (0.0f0, 1.0f0)
-
-        dudt = if nnlib == "Flux"
-            Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2))
-        else
-            Chain(Dense(2 => 50, tanh), Dense(50 => 2))
-        end
-
-        aug_dudt = if nnlib == "Flux"
-            Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4))
-        else
-            Chain(Dense(4 => 50, tanh), Dense(50 => 4))
-        end
-
-        @testset "u0: $(typeof(u0))" for u0 in (x, xs)
-            @testset "kwargs: $(kwargs))" for kwargs in (
-                (; save_everystep = false, save_start = false),
-                (; abstol = 1e-12, reltol = 1e-12,
-                    save_everystep = false, save_start = false),
-                (; save_everystep = false, save_start = false, sensealg = TrackerAdjoint()),
-                (; save_everystep = false, save_start = false,
-                    sensealg = BacksolveAdjoint()),
-                (; saveat = 0.0f0:0.1f0:1.0f0),
-                (; saveat = 0.1f0),
-                (; saveat = 0.0f0:0.1f0:1.0f0, sensealg = TrackerAdjoint()),
-                (; saveat = 0.1f0, sensealg = TrackerAdjoint()))
-                node = NeuralODE(dudt, tspan, Tsit5(); kwargs...)
-                pd, st = Lux.setup(rng, node)
-                pd = ComponentArray(pd)
-                grads = Zygote.gradient(sum ∘ first ∘ node, u0, pd, st)
-                @test !iszero(grads[1])
-                @test !iszero(grads[2])
-
-                anode = AugmentedNDELayer(NeuralODE(aug_dudt, tspan, Tsit5(); kwargs...), 2)
-                pd, st = Lux.setup(rng, anode)
-                pd = ComponentArray(pd)
-                grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st)
-                @test !iszero(grads[1])
-                @test !iszero(grads[2])
-            end
-        end
-    end
+@testitem "NeuralODE" tags = [:basicneuralde] begin
+    using ComponentArrays, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Random
+    using Flux: Flux
+
+    rng = Xoshiro(0)
+
+    @testset "$(nnlib)" for nnlib in ("Flux", "Lux")
+        mp = Float32[0.1, 0.1]
+        x = Float32[2.0; 0.0]
+        xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0]))
+        tspan = (0.0f0, 1.0f0)
+
+        dudt = if nnlib == "Flux"
+            Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2))
+        else
+            Chain(Dense(2 => 50, tanh), Dense(50 => 2))
+        end
+
+        aug_dudt = if nnlib == "Flux"
+            Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4))
+        else
+            Chain(Dense(4 => 50, tanh), Dense(50 => 4))
+        end
+
+        @testset "u0: $(typeof(u0))" for u0 in (x, xs)
+            @testset "kwargs: $(kwargs)" for kwargs in (
+                (; save_everystep = false, save_start = false),
+                (; abstol = 1e-12, reltol = 1e-12,
+                    save_everystep = false, save_start = false),
+                (; save_everystep = false, save_start = false, sensealg = TrackerAdjoint()),
+                (; save_everystep = false, save_start = false,
+                    sensealg = BacksolveAdjoint()),
+                (; saveat = 0.0f0:0.1f0:1.0f0),
+                (; saveat = 0.1f0),
+                (; saveat = 0.0f0:0.1f0:1.0f0, sensealg = TrackerAdjoint()),
+                (; saveat = 0.1f0, sensealg = TrackerAdjoint()))
+                node = NeuralODE(dudt, tspan, Tsit5(); kwargs...)
+                pd, st = Lux.setup(rng, node)
+                pd = ComponentArray(pd)
+                grads = Zygote.gradient(sum ∘ first ∘ node, u0, pd, st)
+                @test !iszero(grads[1])
+                @test !iszero(grads[2])
+
+                anode = AugmentedNDELayer(NeuralODE(aug_dudt, tspan, Tsit5(); kwargs...), 2)
+                pd, st = Lux.setup(rng, anode)
+                pd = ComponentArray(pd)
+                grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st)
+                @test !iszero(grads[1])
+                @test !iszero(grads[2])
+            end
+        end
+    end
 end
 
-@testitem "NeuralDSDE" tags=[:basicneuralde] begin
-    using ComponentArrays, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Random
-    import Flux
-
-    rng = Xoshiro(0)
-
-    @testset "$(nnlib)" for nnlib in ("Flux", "Lux")
-        mp = Float32[0.1, 0.1]
-        x = Float32[2.0; 0.0]
-        xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0]))
-        tspan = (0.0f0, 1.0f0)
-
-        dudt = if nnlib == "Flux"
-            Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2))
-        else
-            Chain(Dense(2 => 50, tanh), Dense(50 => 2))
-        end
-
-        aug_dudt = if nnlib == "Flux"
-            Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4))
-        else
-            Chain(Dense(4 => 50, tanh), Dense(50 => 4))
-        end
-
-        diffusion = if nnlib == "Flux"
-            Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2))
-        else
-            Chain(Dense(2 => 50, tanh), Dense(50 => 2))
-        end
-
-        aug_diffusion = if nnlib == "Flux"
-            Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4))
-        else
-            Chain(Dense(4 => 50, tanh), Dense(50 => 4))
-        end
-
-        tspan = (0.0f0, 0.1f0)
-        @testset "u0: $(typeof(u0)), solver: $(solver)" for u0 in (x, xs),
-            solver in (EulerHeun(), LambaEM(), SOSRI())
-
-            sode = NeuralDSDE(
-                dudt, diffusion, tspan, solver; saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0)
-            pd, st = Lux.setup(rng, sode)
-            pd = ComponentArray(pd)
-
-            grads = Zygote.gradient(sum ∘ first ∘ sode, u0, pd, st)
-            @test !iszero(grads[1])
-            @test !iszero(grads[2])
-            @test !iszero(grads[2][end])
-
-            sode = NeuralDSDE(aug_dudt, aug_diffusion, tspan, solver;
-                saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0)
-            anode = AugmentedNDELayer(sode, 2)
-            pd, st = Lux.setup(rng, anode)
-            pd = ComponentArray(pd)
-
-            grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st)
-            @test !iszero(grads[1])
-            @test !iszero(grads[2])
-            @test !iszero(grads[2][end])
-        end
-    end
+@testitem "NeuralDSDE" tags = [:basicneuralde] begin
+    using ComponentArrays, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Random
+    using Flux: Flux
+
+    rng = Xoshiro(0)
+
+    @testset "$(nnlib)" for nnlib in ("Flux", "Lux")
+        mp = Float32[0.1, 0.1]
+        x = Float32[2.0; 0.0]
+        xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0]))
+        tspan = (0.0f0, 1.0f0)
+
+        dudt = if nnlib == "Flux"
+            Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2))
+        else
+            Chain(Dense(2 => 50, tanh), Dense(50 => 2))
+        end
+
+        aug_dudt = if nnlib == "Flux"
+            Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4))
+        else
+            Chain(Dense(4 => 50, tanh), Dense(50 => 4))
+        end
+
+        diffusion = if nnlib == "Flux"
+            Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2))
+        else
+            Chain(Dense(2 => 50, tanh), Dense(50 => 2))
+        end
+
+        aug_diffusion = if nnlib == "Flux"
+            Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4))
+        else
+            Chain(Dense(4 => 50, tanh), Dense(50 => 4))
+        end
+
+        tspan = (0.0f0, 0.1f0)
+        @testset "u0: $(typeof(u0)), solver: $(solver)" for u0 in (x, xs),
+            solver in (EulerHeun(), LambaEM(), SOSRI())
+
+            sode = NeuralDSDE(
+                dudt, diffusion, tspan, solver; saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0)
+            pd, st = Lux.setup(rng, sode)
+            pd = ComponentArray(pd)
+
+            grads = Zygote.gradient(sum ∘ first ∘ sode, u0, pd, st)
+            @test !iszero(grads[1])
+            @test !iszero(grads[2])
+            @test !iszero(grads[2][end])
+
+            sode = NeuralDSDE(aug_dudt, aug_diffusion, tspan, solver;
+                saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0)
+            anode = AugmentedNDELayer(sode, 2)
+            pd, st = Lux.setup(rng, anode)
+            pd = ComponentArray(pd)
+
+            grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st)
+            @test !iszero(grads[1])
+            @test !iszero(grads[2])
+            @test !iszero(grads[2][end])
+        end
+    end
 end
 
-@testitem "NeuralSDE" tags=[:basicneuralde] begin
-    using ComponentArrays, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Random
-    import Flux
-
-    rng = Xoshiro(0)
-
-    @testset "$(nnlib)" for nnlib in ("Flux", "Lux")
-        mp = Float32[0.1, 0.1]
-        x = Float32[2.0; 0.0]
-        xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0]))
-        tspan = (0.0f0, 1.0f0)
-
-        dudt = if nnlib == "Flux"
-            Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2))
-        else
-            Chain(Dense(2 => 50, tanh), Dense(50 => 2))
-        end
-
-        aug_dudt = if nnlib == "Flux"
-            Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4))
-        else
-            Chain(Dense(4 => 50, tanh), Dense(50 => 4))
-        end
-
-        diffusion_sde = if nnlib == "Flux"
-            Flux.Chain(
-                Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 4), x -> reshape(x, 2, 2))
-        else
-            Chain(Dense(2 => 50, tanh), Dense(50 => 4), x -> reshape(x, 2, 2))
-        end
-
-        aug_diffusion_sde = if nnlib == "Flux"
-            Flux.Chain(
-                Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 16), x -> reshape(x, 4, 4))
-        else
-            Chain(Dense(4 => 50, tanh), Dense(50 => 16), x -> reshape(x, 4, 4))
-        end
-
-        @testset "u0: $(typeof(u0)), solver: $(solver)" for u0 in (x,),
-            solver in (EulerHeun(), LambaEM())
-
-            sode = NeuralSDE(dudt, diffusion_sde, tspan, 2, solver;
-                saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0)
-            pd, st = Lux.setup(rng, sode)
-            pd = ComponentArray(pd)
-
-            grads = Zygote.gradient(sum ∘ first ∘ sode, u0, pd, st)
-            @test !iszero(grads[1])
-            @test !iszero(grads[2])
-            @test !iszero(grads[2][end])
-
-            sode = NeuralSDE(aug_dudt, aug_diffusion_sde, tspan, 4, solver;
-                saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0)
-            anode = AugmentedNDELayer(sode, 2)
-            pd, st = Lux.setup(rng, anode)
-            pd = ComponentArray(pd)
-
-            grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st)
-            @test !iszero(grads[1])
-            @test !iszero(grads[2])
-            @test !iszero(grads[2][end])
-        end
-    end
+@testitem "NeuralSDE" tags = [:basicneuralde] begin
+    using ComponentArrays, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Random
+    using Flux: Flux
+
+    rng = Xoshiro(0)
+
+    @testset "$(nnlib)" for nnlib in ("Flux", "Lux")
+        mp = Float32[0.1, 0.1]
+        x = Float32[2.0; 0.0]
+        xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0]))
+        tspan = (0.0f0, 1.0f0)
+
+        dudt = if nnlib == "Flux"
+            Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2))
+        else
+            Chain(Dense(2 => 50, tanh), Dense(50 => 2))
+        end
+
+        aug_dudt = if nnlib == "Flux"
+            Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4))
+        else
+            Chain(Dense(4 => 50, tanh), Dense(50 => 4))
+        end
+
+        diffusion_sde = if nnlib == "Flux"
+            Flux.Chain(
+                Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 4), x -> reshape(x, 2, 2))
+        else
+            Chain(Dense(2 => 50, tanh), Dense(50 => 4), x -> reshape(x, 2, 2))
+        end
+
+        aug_diffusion_sde = if nnlib == "Flux"
+            Flux.Chain(
+                Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 16), x -> reshape(x, 4, 4))
+        else
+            Chain(Dense(4 => 50, tanh), Dense(50 => 16), x -> reshape(x, 4, 4))
+        end
+
+        @testset "u0: $(typeof(u0)), solver: $(solver)" for u0 in (x,),
+            solver in (EulerHeun(), LambaEM())
+
+            sode = NeuralSDE(dudt, diffusion_sde, tspan, 2, solver;
+                saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0)
+            pd, st = Lux.setup(rng, sode)
+            pd = ComponentArray(pd)
+
+            grads = Zygote.gradient(sum ∘ first ∘ sode, u0, pd, st)
+            @test !iszero(grads[1])
+            @test !iszero(grads[2])
+            @test !iszero(grads[2][end])
+
+            sode = NeuralSDE(aug_dudt, aug_diffusion_sde, tspan, 4, solver;
+                saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0)
+            anode = AugmentedNDELayer(sode, 2)
+            pd, st = Lux.setup(rng, anode)
+            pd = ComponentArray(pd)
+
+            grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st)
+            @test !iszero(grads[1])
+            @test !iszero(grads[2])
+            @test !iszero(grads[2][end])
+        end
+    end
 end
 
-@testitem "NeuralCDDE" tags=[:basicneuralde] begin
-    using ComponentArrays, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Random
-    import Flux
-
-    rng = Xoshiro(0)
-
-    @testset "$(nnlib)" for nnlib in ("Flux", "Lux")
-        mp = Float32[0.1, 0.1]
-        x = Float32[2.0; 0.0]
-        xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0]))
-        tspan = (0.0f0, 1.0f0)
-
-        dudt = if nnlib == "Flux"
-            Flux.Chain(Flux.Dense(6 => 50, tanh), Flux.Dense(50 => 2))
-        else
-            Chain(Dense(6 => 50, tanh), Dense(50 => 2))
-        end
-
-        aug_dudt = if nnlib == "Flux"
-            Flux.Chain(Flux.Dense(12 => 50, tanh), Flux.Dense(50 => 4))
-        else
-            Chain(Dense(12 => 50, tanh), Dense(50 => 4))
-        end
-
-        @testset "NeuralCDDE u0: $(typeof(u0))" for u0 in (x, xs)
-            dode = NeuralCDDE(dudt, (0.0f0, 2.0f0), (u, p, t) -> zero(u), (0.1f0, 0.2f0),
-                MethodOfSteps(Tsit5()); saveat = 0.0f0:0.1f0:2.0f0)
-            pd, st = Lux.setup(rng, dode)
-            pd = ComponentArray(pd)
-
-            grads = Zygote.gradient(sum ∘ first ∘ dode, u0, pd, st)
-            @test !iszero(grads[1])
-            @test !iszero(grads[2])
-
-            dode = NeuralCDDE(
-                aug_dudt, (0.0f0, 2.0f0), (u, p, t) -> zero(u), (0.1f0, 0.2f0),
-                MethodOfSteps(Tsit5()); saveat = 0.0f0:0.1f0:2.0f0)
-            anode = AugmentedNDELayer(dode, 2)
-            pd, st = Lux.setup(rng, anode)
-            pd = ComponentArray(pd)
-
-            grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st)
-            @test !iszero(grads[1])
-            @test !iszero(grads[2])
-        end
-    end
+@testitem "NeuralCDDE" tags = [:basicneuralde] begin
+    using ComponentArrays, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Random
+    using Flux: Flux
+
+    rng = Xoshiro(0)
+
+    @testset "$(nnlib)" for nnlib in ("Flux", "Lux")
+        mp = Float32[0.1, 0.1]
+        x = Float32[2.0; 0.0]
+        xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0]))
+        tspan = (0.0f0, 1.0f0)
+
+        dudt = if nnlib == "Flux"
+            Flux.Chain(Flux.Dense(6 => 50, tanh), Flux.Dense(50 => 2))
+        else
+            Chain(Dense(6 => 50, tanh), Dense(50 => 2))
+        end
+
+        aug_dudt = if nnlib == "Flux"
+            Flux.Chain(Flux.Dense(12 => 50, tanh), Flux.Dense(50 => 4))
+        else
+            Chain(Dense(12 => 50, tanh), Dense(50 => 4))
+        end
+
+        @testset "NeuralCDDE u0: $(typeof(u0))" for u0 in (x, xs)
+            dode = NeuralCDDE(dudt, (0.0f0, 2.0f0), (u, p, t) -> zero(u), (0.1f0, 0.2f0),
+                MethodOfSteps(Tsit5()); saveat = 0.0f0:0.1f0:2.0f0)
+            pd, st = Lux.setup(rng, dode)
+            pd = ComponentArray(pd)
+
+            grads = Zygote.gradient(sum ∘ first ∘ dode, u0, pd, st)
+            @test !iszero(grads[1])
+            @test !iszero(grads[2])
+
+            dode = NeuralCDDE(
+                aug_dudt, (0.0f0, 2.0f0), (u, p, t) -> zero(u), (0.1f0, 0.2f0),
+                MethodOfSteps(Tsit5()); saveat = 0.0f0:0.1f0:2.0f0)
+            anode = AugmentedNDELayer(dode, 2)
+            pd, st = Lux.setup(rng, anode)
+            pd = ComponentArray(pd)
+
+            grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st)
+            @test !iszero(grads[1])
+            @test !iszero(grads[2])
+        end
+    end
 end
 
-@testitem "DimMover" tags=[:basicneuralde] begin
-    using Random
+@testitem "DimMover" tags = [:basicneuralde] begin
+    using Random
 
-    rng = Xoshiro(0)
-    r = rand(2, 3, 4, 5)
-    layer = DimMover()
-    ps, st = Lux.setup(rng, layer)
+    rng = Xoshiro(0)
+    r = rand(2, 3, 4, 5)
+    layer = DimMover()
+    ps, st = Lux.setup(rng, layer)
 
-    @test first(layer(r, ps, st))[:, :, :, 1] == r[:, :, 1, :]
+    @test first(layer(r, ps, st))[:, :, :, 1] == r[:, :, 1, :]
 end
 
-@testitem "Neural DE CUDA" tags=[:cuda] skip=:(using LuxCUDA; !LuxCUDA.functional()) begin
-    using LuxCUDA, Zygote, OrdinaryDiffEq, StochasticDiffEq, Test, Random, ComponentArrays
-    import Flux
-
-    CUDA.allowscalar(false)
-
-    rng = Xoshiro(0)
-
-    const gdev = gpu_device()
-    const cdev = cpu_device()
-
-    @testset "Neural DE" begin
-        mp = Float32[0.1, 0.1] |> gdev
-        x = Float32[2.0; 0.0] |> gdev
-        xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) |> gdev
-        tspan = (0.0f0, 1.0f0)
-
-        dudt = Chain(Dense(2 => 50, tanh), Dense(50 => 2))
-        aug_dudt = Chain(Dense(4 => 50, tanh), Dense(50 => 4))
-
-        @testset "Neural ODE" begin
-            @testset "u0: $(typeof(u0))" for u0 in (x, xs)
-                @testset "kwargs: $(kwargs))" for kwargs in (
-                    (; save_everystep = false, save_start = false),
-                    (; save_everystep = false, save_start = false,
-                        sensealg = TrackerAdjoint()),
-                    (; save_everystep = false, save_start = false,
-                        sensealg = BacksolveAdjoint()),
-                    (; saveat = 0.0f0:0.1f0:1.0f0),
-                    (; saveat = 0.1f0),
-                    (; saveat = 0.0f0:0.1f0:1.0f0, sensealg = TrackerAdjoint()),
-                    (; saveat = 0.1f0, sensealg = TrackerAdjoint()))
-                    node = NeuralODE(dudt, tspan, Tsit5(); kwargs...)
-                    pd, st = Lux.setup(rng, node)
-                    pd = ComponentArray(pd) |> gdev
-                    st = st |> gdev
-                    broken = hasfield(typeof(kwargs), :sensealg) &&
-                             ndims(u0) == 2 &&
-                             kwargs.sensealg isa TrackerAdjoint
-                    @test begin
-                        grads = Zygote.gradient(sum ∘ last ∘ first ∘ node, u0, pd, st)
-                        CUDA.@allowscalar begin
-                            !iszero(grads[1]) && !iszero(grads[2])
-                        end
-                    end broken=broken
-
-                    anode = AugmentedNDELayer(
-                        NeuralODE(aug_dudt, tspan, Tsit5(); kwargs...), 2)
-                    pd, st = Lux.setup(rng, anode)
-                    pd = ComponentArray(pd) |> gdev
-                    st = st |> gdev
-                    @test begin
-                        grads = Zygote.gradient(sum ∘ last ∘ first ∘ anode, u0, pd, st)
-                        CUDA.@allowscalar begin
-                            !iszero(grads[1]) && !iszero(grads[2])
-                        end
-                    end broken=broken
-                end
-            end
-        end
-
-        diffusion = Chain(Dense(2 => 50, tanh), Dense(50 => 2))
-        aug_diffusion = Chain(Dense(4 => 50, tanh), Dense(50 => 4))
-
-        tspan = (0.0f0, 0.1f0)
-        @testset "NeuralDSDE u0: $(typeof(u0)), solver: $(solver)" for u0 in (xs,),
-            solver in (SOSRI(),)
-            # CuVector seems broken on CI but I can't reproduce the failure locally
-
-            sode = NeuralDSDE(
-                dudt, diffusion, tspan, solver; saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0)
-            pd, st = Lux.setup(rng, sode)
-            pd = ComponentArray(pd) |> gdev
-            st = st |> gdev
-
-            @test_broken begin
-                grads = Zygote.gradient(sum ∘ last ∘ first ∘ sode, u0, pd, st)
-                CUDA.@allowscalar begin
-                    !iszero(grads[1]) && !iszero(grads[2]) && !iszero(grads[2][end])
-                end
-            end
-        end
-    end
+@testitem "Neural DE CUDA" tags = [:cuda] skip = :(using LuxCUDA; !LuxCUDA.functional()) begin
+    using LuxCUDA, Zygote, OrdinaryDiffEq, StochasticDiffEq, Test, Random, ComponentArrays
+    using Flux: Flux
+
+    CUDA.allowscalar(false)
+
+    rng = Xoshiro(0)
+
+    const gdev = gpu_device()
+    const cdev = cpu_device()
+
+    @testset "Neural DE" begin
+        mp = Float32[0.1, 0.1] |> gdev
+        x = Float32[2.0; 0.0] |> gdev
+        xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) |> gdev
+        tspan = (0.0f0, 1.0f0)
+
+        dudt = Chain(Dense(2 => 50, tanh), Dense(50 => 2))
+        aug_dudt = Chain(Dense(4 => 50, tanh), Dense(50 => 4))
+
+        @testset "Neural ODE" begin
+            @testset "u0: $(typeof(u0))" for u0 in (x, xs)
+                @testset "kwargs: $(kwargs)" for kwargs in (
+                    (; save_everystep = false, save_start = false),
+                    (; save_everystep = false, save_start = false,
+                        sensealg = TrackerAdjoint()),
+                    (; save_everystep = false, save_start = false,
+                        sensealg = BacksolveAdjoint()),
+                    (; saveat = 0.0f0:0.1f0:1.0f0),
+                    (; saveat = 0.1f0),
+                    (; saveat = 0.0f0:0.1f0:1.0f0, sensealg = TrackerAdjoint()),
+                    (; saveat = 0.1f0, sensealg = TrackerAdjoint()))
+                    node = NeuralODE(dudt, tspan, Tsit5(); kwargs...)
+                    pd, st = Lux.setup(rng, node)
+                    pd = ComponentArray(pd) |> gdev
+                    st = st |> gdev
+                    broken = hasfield(typeof(kwargs), :sensealg) &&
+                             ndims(u0) == 2 &&
+                             kwargs.sensealg isa TrackerAdjoint
+                    @test begin
+                        grads = Zygote.gradient(sum ∘ last ∘ first ∘ node, u0, pd, st)
+                        CUDA.@allowscalar begin
+                            !iszero(grads[1]) && !iszero(grads[2])
+                        end
+                    end broken = broken
+
+                    anode = AugmentedNDELayer(
+                        NeuralODE(aug_dudt, tspan, Tsit5(); kwargs...), 2)
+                    pd, st = Lux.setup(rng, anode)
+                    pd = ComponentArray(pd) |> gdev
+                    st = st |> gdev
+                    @test begin
+                        grads = Zygote.gradient(sum ∘ last ∘ first ∘ anode, u0, pd, st)
+                        CUDA.@allowscalar begin
+                            !iszero(grads[1]) && !iszero(grads[2])
+                        end
+                    end broken = broken
+                end
+            end
+        end
+
+        diffusion = Chain(Dense(2 => 50, tanh), Dense(50 => 2))
+        aug_diffusion = Chain(Dense(4 => 50, tanh), Dense(50 => 4))
+
+        tspan = (0.0f0, 0.1f0)
+        @testset "NeuralDSDE u0: $(typeof(u0)), solver: $(solver)" for u0 in (xs,),
+            solver in (SOSRI(),)
+            # CuVector seems broken on CI but I can't reproduce the failure locally
+
+            sode = NeuralDSDE(
+                dudt, diffusion, tspan, solver; saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0)
+            pd, st = Lux.setup(rng, sode)
+            pd = ComponentArray(pd) |> gdev
+            st = st |> gdev
+
+            @test_broken begin
+                grads = Zygote.gradient(sum ∘ last ∘ first ∘ sode, u0, pd, st)
+                CUDA.@allowscalar begin
+                    !iszero(grads[1]) && !iszero(grads[2]) && !iszero(grads[2][end])
+                end
+            end
+        end
+    end
 end
diff --git a/test/otflow_tests.jl b/test/otflow_tests.jl
new file mode 100644
index 000000000..d131941bc
--- /dev/null
+++ b/test/otflow_tests.jl
@@ -0,0 +1,86 @@
+@testitem "Tests for OTFlow Layer Functionality" begin
+    using Lux, LuxCore, Random, LinearAlgebra, Test, ComponentArrays, DiffEqFlux
+    rng = Xoshiro(0)
+    d = 2
+    m = 4
+    r = 2
+    otflow = OTFlow(d, m; r = r)
+    ps, st = Lux.setup(rng, otflow)
+    ps = ComponentArray(ps)
+
+    x = Float32[1.0, 2.0]
+    t = 0.5f0
+
+    @testset "Forward Pass" begin
+        (v, trJ), st_new = otflow((x, t), ps, st)
+        @test length(v) == d
+        @test isa(trJ, Float32)
+        @test st_new == st
+    end
+
+    @testset "Potential Function" begin
+        phi = DiffEqFlux.potential(x, t, ps)
+        @test isa(phi, Float32)
+    end
+
+    @testset "Gradient Consistency" begin
+        grad = DiffEqFlux.gradient(x, t, ps, d)
+        (v, _), _ = otflow((x, t), ps, st)
+        @test length(grad) == d
+        @test grad ≈ -v atol = 1e-5 # v = -∇Φ
+    end
+
+    @testset "Trace Consistency" begin
+        tr_manual = DiffEqFlux.trace(x, t, ps, d)
+        (_, tr_forward), _ = otflow((x, t), ps, st)
+        @test tr_manual ≈ -tr_forward atol = 1e-5
+    end
+
+    @testset "ODE Integration" begin
+        x0 = Float32[1.0, 1.0]
+        tspan = (0.0f0, 1.0f0)
+        x_traj, t_vec = DiffEqFlux.simple_ode_solve(otflow, x0, tspan, ps, st; dt = 0.01f0)
+        @test size(x_traj) == (d, length(t_vec))
+        @test all(isfinite, x_traj)
+        @test x_traj[:, end] != x0
+    end
+
+    @testset "Loss Function" begin
+        loss_val = DiffEqFlux.simple_loss(x, t, otflow, ps)
+        @test isa(loss_val, Float32)
+        @test isfinite(loss_val)
+    end
+
+    @testset "Manual Gradient" begin
+        grads = DiffEqFlux.manual_gradient(x, t, otflow, ps)
+        @test haskey(grads, :w) && length(grads.w) == m
+        @test haskey(grads, :A) && size(grads.A) == (r, d + 1)
+        @test haskey(grads, :b) && length(grads.b) == d + 1
+        @test haskey(grads, :c) && length(grads.c) == m
+        @test haskey(grads, :K0) && size(grads.K0) == (m, d + 1)
+        @test haskey(grads, :K1) && size(grads.K1) == (m, m)
+        @test haskey(grads, :b0) && length(grads.b0) == m
+        @test haskey(grads, :b1) && length(grads.b1) == m
+    end
+end
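+
+# Finite-difference sanity check of the closed-form spatial gradient against
+# the scalar potential; the probe point, step h, and tolerance are illustrative.
+@testitem "OTFlow Finite Difference Gradient" begin
+    using Lux, Random, ComponentArrays, DiffEqFlux
+    rng = Xoshiro(0)
+    d, m = 2, 4
+    otflow = OTFlow(d, m)
+    ps, st = Lux.setup(rng, otflow)
+    ps = ComponentArray(ps)
+    x = Float32[0.3, -0.7]
+    t = 0.25f0
+    h = 1.0f-3
+    fd = map(1:d) do i
+        e = zeros(Float32, d)
+        e[i] = h
+        (DiffEqFlux.potential(x .+ e, t, ps) - DiffEqFlux.potential(x .- e, t, ps)) /
+        (2 * h)
+    end
+    @test fd ≈ DiffEqFlux.gradient(x, t, ps, d) atol = 1.0f-3
+end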