Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ version = "0.1.0"
[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
39 changes: 39 additions & 0 deletions examples/componentarrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using ComponentArrays, DualArrays


function denselayer(layer, x)
layer.weight * x + layer.bias
end



# We want to different with respect to b the following function:

function foo(b)
denselayer(ComponentVector(weight = [1 2; 3 4], bias = b), [5,6])
end

# This works as-is:

foo([5,6])
foo(DualVector([5,6], I(2)))

b = DualVector([5,6], I(2))
w = DualArray([1 2; 3 4], FourTensorIdentity((2,2), (2,2)))
ComponentVector(weight = w, bias = b)


denselayer(ComponentVector(weight = [1 2; 3 4], bias = b), [5,6])

f = x -> x * derivative(y -> x + y, 1)


g = (x,y) -> x + y
f = x -> x * g(x, 1)



function f(a, b)
a + b
end

26 changes: 23 additions & 3 deletions src/DualArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@ module DualArrays
export DualVector

import Base: +, ==, getindex, size, broadcast, axes, broadcasted, show, sum,
vcat, convert, *
vcat, convert, *, promote_rule
using LinearAlgebra, ArrayLayouts, BandedMatrices, FillArrays

import FillArrays: OneElementVector

struct Dual{T, Partials <: AbstractVector{T}} <: Real
value::T
partials::Partials
end


==(a::Dual, b::Dual) = a.value == b.value && a.partials == b.partials



sparse_getindex(a...) = layout_getindex(a...)
sparse_getindex(D::Diagonal, k::Integer, ::Colon) = OneElement(D.diag[k], k, size(D,2))
sparse_getindex(D::Diagonal, ::Colon, j::Integer) = OneElement(D.diag[j], j, size(D,1))
Expand All @@ -24,7 +29,10 @@ reprents a vector of duals given by
For now the entries just return the values.
"""

struct DualVector{T, M <: AbstractMatrix{T}} <: AbstractVector{Dual{T}}
partials_type(::Type{Matrix{T}}) where T = Vector{T}
partials_type(::Type{<:Diagonal{T}}) where T = OneElementVector{T}

struct DualVector{T, M <: AbstractMatrix{T}, Partials} <: AbstractVector{Dual{T, Partials}}
value::Vector{T}
jacobian::M
function DualVector(value::Vector{T},jacobian::M) where {T, M <: AbstractMatrix{T}}
Expand All @@ -34,7 +42,7 @@ struct DualVector{T, M <: AbstractMatrix{T}} <: AbstractVector{Dual{T}}
vector length: $x \n
no. of jacobian rows: $y"))
end
new{T,M}(value,jacobian)
new{T,M,partials_type(M)}(value,jacobian)
end
end

Expand All @@ -44,6 +52,18 @@ function DualVector(value::AbstractVector, jacobian::AbstractMatrix)
end


####
# convert/promotion
####

# TODO: convert_eltype for M
promote_rule(::Type{Vector{T}}, ::Type{DualVector{V,M}}) where {T,V,M} = DualVector{promote_type(T,V),M}
convert(::Type{DualVector{V,M}}, x::AbstractVector{T}) where {T,V,M} = DualVector(convert(AbstractVector{V}, x), convert(M, Zeros(axes(x,1), axes(x,1))))


#####
# getindex/setindex!
####

function getindex(x::DualVector, y::Int)
Dual(x.value[y], sparse_getindex(x.jacobian,y,:))
Expand Down
10 changes: 8 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using DualArrays, Test, LinearAlgebra, ForwardDiff, BandedMatrices
using DualArrays, Test, LinearAlgebra, ForwardDiff, BandedMatrices, ComponentArrays
using DualArrays: Dual

@testset "DualArrays" begin
Expand Down Expand Up @@ -39,4 +39,10 @@ using DualArrays: Dual
@test vcat(x) == x
@test vcat(x, x) == DualVector([1, 1], [1 2 3;1 2 3])
@test vcat(x, y) == DualVector([1, 2, 3], [1 2 3;4 5 6;7 8 9])
end

@testset "promote" begin
x = DualVector([5,6], I(2))
y = [7,8]
@test [x,y] isa Vector{<:DualVector}
end
end