Skip to content
Merged
29 changes: 21 additions & 8 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,27 +174,40 @@ end

"""
Diagonal(α, β)
Diagonal(size::Integer...)
Diagonal(size::Integer...; bias = true, init = ones32)

Create an element-wise linear layer, which performs

y = α .* x .+ β

The learnable arrays are initialised `α = ones(Float32, size)` and
`β = zeros(Float32, size)`.
if `bias` is true, and

y = α .* x

otherwise. The learnable arrays are initialised `α = ones(Float32, size)` and
`β = zeros(Float32, size)`. If `init` is specified, the function given to it is
called and used to initialise α. The weight matrix and/or the bias vector
(with the same size as x) may also be provided explicitly.

Used by [`LayerNorm`](@ref).
"""
struct Diagonal{T}
α::T
β::T
struct Diagonal{A, B}
α::A
β::B
function Diagonal(W::M, bias = true) where M<:AbstractArray
b = create_bias(W, bias, size(W)...)
new{M, typeof(b)}(W, b)
end
end

Diagonal(sz::Integer...) = Diagonal(ones32(sz...), zeros32(sz...))
Diagonal(sz::Integer...; bias = true, init = ones32) = Diagonal(init(sz...), bias)

@functor Diagonal

(a::Diagonal)(x) = a.α .* x .+ a.β
function (a::Diagonal)(x)
x = a.α .* x
x = x .+ a.β
end

function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", join(size(l.α), ", "), ")")
Expand Down
7 changes: 4 additions & 3 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,17 @@ import Flux: activations
@test length(Flux.Diagonal(10)(randn(10))) == 10
@test length(Flux.Diagonal(10)(1)) == 10
@test length(Flux.Diagonal(10)(randn(1))) == 10
@test length(Flux.Diagonal(10; bias = false)(randn(10))) == 10
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))

@test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
@test Flux.Diagonal(2)([1,2]) == [1,2]
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
@test Flux.Diagonal(2; bias = false)([1 2; 3 4]) == [1 2; 3 4]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does these tests need bias=false?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They don't really, I kept those just to check that bias=false doesn't trip anything


@test Flux.Diagonal(2)(rand(2,3,4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2,3)(rand(2,3,4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2,3,4)(rand(2,3,4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2,3)(rand(2,1,4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2, 3, 4; bias = false)(rand(2,3,4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2, 3; bias = false)(rand(2,1,4)) |> size == (2, 3, 4)
end

@testset "Maxout" begin
Expand Down