diff --git a/Project.toml b/Project.toml index 0375150..cf3032f 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,7 @@ version = "0.1.0" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" -DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/examples/componentarrays.jl b/examples/componentarrays.jl new file mode 100644 index 0000000..d142e61 --- /dev/null +++ b/examples/componentarrays.jl @@ -0,0 +1,39 @@ +using ComponentArrays, DualArrays + + +function denselayer(layer, x) + layer.weight * x + layer.bias +end + + + +# We want to different with respect to b the following function: + +function foo(b) + denselayer(ComponentVector(weight = [1 2; 3 4], bias = b), [5,6]) +end + +# This works as-is: + +foo([5,6]) +foo(DualVector([5,6], I(2))) + +b = DualVector([5,6], I(2)) +w = DualArray([1 2; 3 4], FourTensorIdentity((2,2), (2,2))) +ComponentVector(weight = w, bias = b) + + +denselayer(ComponentVector(weight = [1 2; 3 4], bias = b), [5,6]) + +f = x -> x * derivative(y -> x + y, 1) + + +g = (x,y) -> x + y +f = x -> x * g(x, 1) + + + +function f(a, b) + a + b +end + diff --git a/src/DualArrays.jl b/src/DualArrays.jl index e4074d5..fceeff4 100644 --- a/src/DualArrays.jl +++ b/src/DualArrays.jl @@ -2,16 +2,21 @@ module DualArrays export DualVector import Base: +, ==, getindex, size, broadcast, axes, broadcasted, show, sum, - vcat, convert, * + vcat, convert, *, promote_rule using LinearAlgebra, ArrayLayouts, BandedMatrices, FillArrays +import FillArrays: OneElementVector + struct Dual{T, Partials <: AbstractVector{T}} <: Real value::T partials::Partials end + ==(a::Dual, b::Dual) = a.value == b.value && a.partials == b.partials + + sparse_getindex(a...) = layout_getindex(a...) sparse_getindex(D::Diagonal, k::Integer, ::Colon) = OneElement(D.diag[k], k, size(D,2)) sparse_getindex(D::Diagonal, ::Colon, j::Integer) = OneElement(D.diag[j], j, size(D,1)) @@ -24,7 +29,10 @@ reprents a vector of duals given by For now the entries just return the values. """ -struct DualVector{T, M <: AbstractMatrix{T}} <: AbstractVector{Dual{T}} +partials_type(::Type{Matrix{T}}) where T = Vector{T} +partials_type(::Type{<:Diagonal{T}}) where T = OneElementVector{T} + +struct DualVector{T, M <: AbstractMatrix{T}, Partials} <: AbstractVector{Dual{T, Partials}} value::Vector{T} jacobian::M function DualVector(value::Vector{T},jacobian::M) where {T, M <: AbstractMatrix{T}} @@ -34,7 +42,7 @@ struct DualVector{T, M <: AbstractMatrix{T}} <: AbstractVector{Dual{T}} vector length: $x \n no. of jacobian rows: $y")) end - new{T,M}(value,jacobian) + new{T,M,partials_type(M)}(value,jacobian) end end @@ -44,6 +52,18 @@ function DualVector(value::AbstractVector, jacobian::AbstractMatrix) end +#### +# convert/promotion +#### + +# TODO: convert_eltype for M +promote_rule(::Type{Vector{T}}, ::Type{DualVector{V,M}}) where {T,V,M} = DualVector{promote_type(T,V),M} +convert(::Type{DualVector{V,M}}, x::AbstractVector{T}) where {T,V,M} = DualVector(convert(AbstractVector{V}, x), convert(M, Zeros(axes(x,1), axes(x,1)))) + + +##### +# getindex/setindex! +#### function getindex(x::DualVector, y::Int) Dual(x.value[y], sparse_getindex(x.jacobian,y,:)) diff --git a/test/runtests.jl b/test/runtests.jl index d1cba3c..fe14d88 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ -using DualArrays, Test, LinearAlgebra, ForwardDiff, BandedMatrices +using DualArrays, Test, LinearAlgebra, ForwardDiff, BandedMatrices, ComponentArrays using DualArrays: Dual @testset "DualArrays" begin @@ -39,4 +39,10 @@ using DualArrays: Dual @test vcat(x) == x @test vcat(x, x) == DualVector([1, 1], [1 2 3;1 2 3]) @test vcat(x, y) == DualVector([1, 2, 3], [1 2 3;4 5 6;7 8 9]) -end \ No newline at end of file + + @testset "promote" begin + x = DualVector([5,6], I(2)) + y = [7,8] + @test [x,y] isa Vector{<:DualVector} + end +end