From ca6808208604497890f4ecc835be0a51022030e4 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 17 Dec 2024 00:24:57 +0100 Subject: [PATCH 1/4] Fix von Mises-Fisher sampler --- src/samplers/vonmisesfisher.jl | 139 +++++++++++++++++++--------- src/univariate/continuous/gamma.jl | 1 - test/multivariate/vonmisesfisher.jl | 42 ++++++--- 3 files changed, 126 insertions(+), 56 deletions(-) diff --git a/src/samplers/vonmisesfisher.jl b/src/samplers/vonmisesfisher.jl index fd3eb2df08..53dc23d3a0 100644 --- a/src/samplers/vonmisesfisher.jl +++ b/src/samplers/vonmisesfisher.jl @@ -1,4 +1,6 @@ # Sampler for von Mises-Fisher +# Ref https://doi.org/10.18637/jss.v058.i10 +# Ref https://hal.science/hal-04004568v3 struct VonMisesFisherSampler <: Sampleable{Multivariate,Continuous} p::Int # the dimension κ::Float64 @@ -6,29 +8,32 @@ struct VonMisesFisherSampler <: Sampleable{Multivariate,Continuous} x0::Float64 c::Float64 v::Vector{Float64} + rotate::Bool # whether to rotate the samples end function VonMisesFisherSampler(μ::Vector{Float64}, κ::Float64) + # Step 1: Calculate b, x₀, and c p = length(μ) - b = _vmf_bval(p, κ) + b = (p - 1) / (2.0κ + sqrt(4 * abs2(κ) + abs2(p - 1))) x0 = (1.0 - b) / (1.0 + b) c = κ * x0 + (p - 1) * log1p(-abs2(x0)) - v = _vmf_householder_vec(μ) - VonMisesFisherSampler(p, κ, b, x0, c, v) + + # Compute Householder transformation, and whether it has to be applied + v, rotate = _vmf_householder_vec(μ) + + return VonMisesFisherSampler(p, κ, b, x0, c, v, rotate) end Base.length(s::VonMisesFisherSampler) = length(s.v) -@inline function _vmf_rot!(v::AbstractVector, x::AbstractVector) - # rotate - scale = 2.0 * (v' * x) - @. x -= (scale * v) - return x -end +function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector{<:Real}) + # TODO: Generalize to more general indices + Base.require_one_based_indexing(x) + # Sample angle `w` + w = _vmf_angle(rng, spl) -function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector) - w = _vmf_genw(rng, spl) + # Generate sample assuming `μ = (1, 0, 0, ..., 0)` p = spl.p x[1] = w s = 0.0 @@ -43,47 +48,81 @@ function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector) x[i] *= r end - return _vmf_rot!(spl.v, x) + # Rotate for general `μ` (if necessary) + return _vmf_rotate!(x, spl) end ### Core computation -_vmf_bval(p::Int, κ::Real) = (p - 1) / (2.0κ + sqrt(4 * abs2(κ) + abs2(p - 1))) - -function _vmf_genw3(rng::AbstractRNG, p, b, x0, c, κ) - ξ = rand(rng) - w = 1.0 + (log(ξ + (1.0 - ξ)*exp(-2κ))/κ) - return w::Float64 -end - -function _vmf_genwp(rng::AbstractRNG, p, b, x0, c, κ) - r = (p - 1) / 2.0 - betad = Beta(r, r) - z = rand(rng, betad) - w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z) - while κ * w + (p - 1) * log(1 - x0 * w) - c < log(rand(rng)) - z = rand(rng, betad) - w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z) - end - return w::Float64 -end +# Step 2: Sample angle W +function _vmf_angle(rng::AbstractRNG, spl::VonMisesFisherSampler) + p = spl.p + κ = spl.κ -# generate the W value -- the key step in simulating vMF -# -# following movMF's document for the p != 3 case -# and Wenzel Jakob's document for the p == 3 case -function _vmf_genw(rng::AbstractRNG, p, b, x0, c, κ) if p == 3 - return _vmf_genw3(rng, p, b, x0, c, κ) + _vmf_angle3(rng, κ) else - return _vmf_genwp(rng, p, b, x0, c, κ) + # General case: Rejection sampling + # Ref https://doi.org/10.18637/jss.v058.i10 + b = spl.b + c = spl.c + p = spl.p + κ = spl.κ + x0 = spl.x0 + pm1 = p - 1 + + if p == 2 + # In this case the distribution reduces to the von Mises distribution on the circle + # We exploit the fact that `Beta(1/2, 1/2) = Arcsine(0, 1)` + dist = Arcsine(zero(b), one(b)) + while true + z = rand(rng, dist) + w = (1 - (1 + b) * z) / (1 - (1 - b) * z) + if κ * w + pm1 * log1p(- x0 * w) >= c - randexp(rng) + return w::Float64 + end + end + else + # We sample from a `Beta((p - 1)/2, (p - 1)/2)` distribution, possibly repeatedly + # Therefore we construct a sampler + # To avoid the type instability of `sampler(Beta(...))` and `sampler(Gamma(...))` + # we directly construct the Gamma sampler for Gamma((p - 1)/2, 1) + # Since (p - 1)/2 > 1, we construct a `GammaMTSampler` + r = pm1 / 2 + gammasampler = GammaMTSampler(Gamma{typeof(r)}(r, one(r))) + while true + # w is supposed to be generated as + # z ~ Beta((p - 1)/ 2, (p - 1)/2) + # w = (1 - (1 + b) * z) / (1 - (1 - b) * z) + # We sample z as + # z1 ~ Gamma((p - 1) / 2, 1) + # z2 ~ Gamma((p - 1) / 2, 1) + # z = z1 / (z1 + z2) + # and rewrite the expression for w + # Cf. case p == 2 above + z1 = rand(rng, gammasampler) + z2 = rand(rng, gammasampler) + b_z1 = b * z1 + w = (z2 - b_z1) / (z2 + b_z1) + if κ * w + pm1 * log1p(- x0 * w) >= c - randexp(rng) + return w::Float64 + end + end + end end end +# Special case: 2-sphere +@inline function _vmf_angle3(rng::AbstractRNG, κ::Real) + # In this case, we can directly sample the angle + # Ref https://www.mitsuba-renderer.org/~wenzel/files/vmf.pdf + ξ = rand(rng) + w = 1.0 + (log(ξ + (1.0 - ξ)*exp(-2κ))/κ) + return w::Float64 +end -_vmf_genw(rng::AbstractRNG, s::VonMisesFisherSampler) = - _vmf_genw(rng, s.p, s.b, s.x0, s.c, s.κ) - +# Create Householder transformation to rotate samples for `μ = (1, 0, ..., 0)` +# to samples for general `μ` function _vmf_householder_vec(μ::Vector{Float64}) # assuming μ is a unit-vector (which it should be) # can compute v in a single pass over μ @@ -92,11 +131,27 @@ function _vmf_householder_vec(μ::Vector{Float64}) v = similar(μ) v[1] = μ[1] - 1.0 s = sqrt(-2*v[1]) + if iszero(s) + # In this case, μ is (approx.) (1, 0, ..., 0) + # Hence no rotation has to be performed and `v` is not used + return v, false + end + v[1] /= s @inbounds for i in 2:p v[i] = μ[i] / s end - return v + return v, true +end + +# Rotate samples for general `μ` (if needed) +@inline function _vmf_rotate!(x::AbstractVector{<:Real}, spl::VonMisesFisherSampler) + if spl.rotate + v = spl.v + scale = 2.0 * (v' * x) + @. x -= (scale * v) + end + return x end diff --git a/src/univariate/continuous/gamma.jl b/src/univariate/continuous/gamma.jl index 866255fb7d..8ba207d2c7 100644 --- a/src/univariate/continuous/gamma.jl +++ b/src/univariate/continuous/gamma.jl @@ -105,7 +105,6 @@ function rand(rng::AbstractRNG, d::Gamma) # TODO: shape(d) = 0.5 : use scaled chisq return rand(rng, GammaIPSampler(d)) elseif shape(d) == 1.0 - θ = return rand(rng, Exponential{partype(d)}(scale(d))) else return rand(rng, GammaMTSampler(d)) diff --git a/test/multivariate/vonmisesfisher.jl b/test/multivariate/vonmisesfisher.jl index cc45f41ed5..7dc6f90028 100644 --- a/test/multivariate/vonmisesfisher.jl +++ b/test/multivariate/vonmisesfisher.jl @@ -23,6 +23,7 @@ function gen_vmf_tdata(n::Int, p::Int, end function test_vmf_rot(p::Int, rng::Union{AbstractRNG, Missing} = missing) + # Random μ if ismissing(rng) μ = randn(p) x = randn(p) @@ -34,16 +35,24 @@ function test_vmf_rot(p::Int, rng::Union{AbstractRNG, Missing} = missing) μ = μ ./ κ s = Distributions.VonMisesFisherSampler(μ, κ) + @test s.rotate v = μ - vcat(1, zeros(p-1)) H = I - 2*v*v'/(v'*v) - @test Distributions._vmf_rot!(s.v, copy(x)) ≈ (H*x) - -end + @test Distributions._vmf_rotate!(copy(x), s) ≈ (H*x) + # Special case: μ = (1, 0, ..., 0) + # In this case no rotation is performed + μ = zeros(p) + μ[1] = 1 + s = Distributions.VonMisesFisherSampler(μ, κ) + @test !s.rotate + @test Distributions._vmf_rotate!(copy(x), s) == x + return nothing +end -function test_genw3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing) +function test_angle3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing) p = 3 if ismissing(rng) @@ -53,21 +62,20 @@ function test_genw3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missin end μ = μ ./ norm(μ) - s = Distributions.VonMisesFisherSampler(μ, float(κ)) + spl = Distributions.VonMisesFisherSampler(μ, float(κ)) + angle3_res = [Distributions._vmf_angle3(rng, spl.κ) for _ in 1:ns] + angle_res = [Distributions._vmf_angle(rng, spl) for _ in 1:ns] - genw3_res = [Distributions._vmf_genw3(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns] - genwp_res = [Distributions._vmf_genwp(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns] - - @test isapprox(mean(genw3_res), mean(genwp_res), atol=0.01) - @test isapprox(std(genw3_res), std(genwp_res), atol=0.01/κ) + @test mean(angle3_res) ≈ mean(angle_res) rtol=5e-2 + @test std(angle3_res) ≈ std(angle_res) rtol=1e-2 # test mean and stdev against analytical formulas coth_κ = coth(κ) mean_w = coth_κ - 1/κ var_w = 1 - coth_κ^2 + 1/κ^2 - @test isapprox(mean(genw3_res), mean_w, atol=0.01) - @test isapprox(std(genw3_res), sqrt(var_w), atol=0.01/κ) + @test mean(angle3_res) ≈ mean_w rtol=5e-2 + @test std(angle3_res) ≈ sqrt(var_w) rtol=1e-2 end @@ -178,7 +186,15 @@ ns = 10^6 if !ismissing(rng) @testset "Testing genw with $key at (3, $κ)" for κ in [0.1, 0.5, 1.0, 2.0, 5.0] - test_genw3(κ, ns, rng) + test_angle3(κ, ns, rng) end end end + +# issue #1423 +@testset "Special case: No rotation" begin + for n in 2:10 + d = VonMisesFisher(vcat(1, zeros(n - 1)), 1.0) + @test sum(abs2, rand(d)) ≈ 1 + end +end From ee6c463284b26aa28cb75211f442daa83c94a993 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 30 Mar 2025 00:20:01 +0100 Subject: [PATCH 2/4] Use `LinearAlgebra.reflector!` and `LinearAlgebra.reflectorApply!` --- src/samplers/vonmisesfisher.jl | 62 +++++++++-------------------- test/multivariate/vonmisesfisher.jl | 32 +-------------- 2 files changed, 20 insertions(+), 74 deletions(-) diff --git a/src/samplers/vonmisesfisher.jl b/src/samplers/vonmisesfisher.jl index 53dc23d3a0..31e6643ec1 100644 --- a/src/samplers/vonmisesfisher.jl +++ b/src/samplers/vonmisesfisher.jl @@ -7,8 +7,8 @@ struct VonMisesFisherSampler <: Sampleable{Multivariate,Continuous} b::Float64 x0::Float64 c::Float64 + τ::Float64 v::Vector{Float64} - rotate::Bool # whether to rotate the samples end function VonMisesFisherSampler(μ::Vector{Float64}, κ::Float64) @@ -18,10 +18,17 @@ function VonMisesFisherSampler(μ::Vector{Float64}, κ::Float64) x0 = (1.0 - b) / (1.0 + b) c = κ * x0 + (p - 1) * log1p(-abs2(x0)) - # Compute Householder transformation, and whether it has to be applied - v, rotate = _vmf_householder_vec(μ) + # Compute Householder transformation: + # `LinearAlgebra.reflector!` computes a Householder transformation H such that + # H μ = -copysign(|μ|₂, μ[1]) e₁ + # μ is a unit vector, and hence this implies that + # H e₁ = μ if μ[1] < 0 and H (-e₁) = μ otherwise + # Since `v[1] = flipsign(1, μ[1])`, the sign of `μ[1]` can be extracted from `v[1]` during sampling + v = similar(μ) + copyto!(v, μ) + τ = LinearAlgebra.reflector!(v) - return VonMisesFisherSampler(p, κ, b, x0, c, v, rotate) + return VonMisesFisherSampler(p, κ, b, x0, c, τ, v) end Base.length(s::VonMisesFisherSampler) = length(s.v) @@ -30,10 +37,14 @@ function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector{ # TODO: Generalize to more general indices Base.require_one_based_indexing(x) - # Sample angle `w` + # Sample angle `w` assuming mean direction `(1, 0, ..., 0)` w = _vmf_angle(rng, spl) + + # Transform to sample for mean direction `(flipsign(1.0, μ[1]), 0, ..., 0)` + v = spl.v + w = flipsign(w, v[1]) - # Generate sample assuming `μ = (1, 0, 0, ..., 0)` + # Generate sample assuming mean direction `(flipsign(1.0, μ[1]), 0, ..., 0)` p = spl.p x[1] = w s = 0.0 @@ -48,8 +59,8 @@ function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector{ x[i] *= r end - # Rotate for general `μ` (if necessary) - return _vmf_rotate!(x, spl) + # Apply Householder transformation to mean direction `μ` + return LinearAlgebra.reflectorApply!(v, spl.τ, x) end ### Core computation @@ -120,38 +131,3 @@ end w = 1.0 + (log(ξ + (1.0 - ξ)*exp(-2κ))/κ) return w::Float64 end - -# Create Householder transformation to rotate samples for `μ = (1, 0, ..., 0)` -# to samples for general `μ` -function _vmf_householder_vec(μ::Vector{Float64}) - # assuming μ is a unit-vector (which it should be) - # can compute v in a single pass over μ - - p = length(μ) - v = similar(μ) - v[1] = μ[1] - 1.0 - s = sqrt(-2*v[1]) - if iszero(s) - # In this case, μ is (approx.) (1, 0, ..., 0) - # Hence no rotation has to be performed and `v` is not used - return v, false - end - - v[1] /= s - - @inbounds for i in 2:p - v[i] = μ[i] / s - end - - return v, true -end - -# Rotate samples for general `μ` (if needed) -@inline function _vmf_rotate!(x::AbstractVector{<:Real}, spl::VonMisesFisherSampler) - if spl.rotate - v = spl.v - scale = 2.0 * (v' * x) - @. x -= (scale * v) - end - return x -end diff --git a/test/multivariate/vonmisesfisher.jl b/test/multivariate/vonmisesfisher.jl index 7dc6f90028..59f87e332c 100644 --- a/test/multivariate/vonmisesfisher.jl +++ b/test/multivariate/vonmisesfisher.jl @@ -22,36 +22,6 @@ function gen_vmf_tdata(n::Int, p::Int, return X end -function test_vmf_rot(p::Int, rng::Union{AbstractRNG, Missing} = missing) - # Random μ - if ismissing(rng) - μ = randn(p) - x = randn(p) - else - μ = randn(rng, p) - x = randn(rng, p) - end - κ = norm(μ) - μ = μ ./ κ - - s = Distributions.VonMisesFisherSampler(μ, κ) - @test s.rotate - v = μ - vcat(1, zeros(p-1)) - H = I - 2*v*v'/(v'*v) - - @test Distributions._vmf_rotate!(copy(x), s) ≈ (H*x) - - # Special case: μ = (1, 0, ..., 0) - # In this case no rotation is performed - μ = zeros(p) - μ[1] = 1 - s = Distributions.VonMisesFisherSampler(μ, κ) - @test !s.rotate - @test Distributions._vmf_rotate!(copy(x), s) == x - - return nothing -end - function test_angle3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing) p = 3 @@ -181,7 +151,6 @@ ns = 10^6 (2, 2), (2, 1000)] # test with large κ test_vonmisesfisher(p, κ, n, ns, rng) - test_vmf_rot(p, rng) end if !ismissing(rng) @@ -196,5 +165,6 @@ end for n in 2:10 d = VonMisesFisher(vcat(1, zeros(n - 1)), 1.0) @test sum(abs2, rand(d)) ≈ 1 + @test normalize!(mean(rand(d) for _ in 1:1_000_000)) ≈ vcat(1, zeros(n - 1)) rtol = 1e-2 end end From 2bfca1ec3210a1597d645cc3503bc344b699be1c Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 30 Mar 2025 00:39:46 +0100 Subject: [PATCH 3/4] Change test --- test/multivariate/vonmisesfisher.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/multivariate/vonmisesfisher.jl b/test/multivariate/vonmisesfisher.jl index 59f87e332c..193eeabb40 100644 --- a/test/multivariate/vonmisesfisher.jl +++ b/test/multivariate/vonmisesfisher.jl @@ -165,6 +165,7 @@ end for n in 2:10 d = VonMisesFisher(vcat(1, zeros(n - 1)), 1.0) @test sum(abs2, rand(d)) ≈ 1 - @test normalize!(mean(rand(d) for _ in 1:1_000_000)) ≈ vcat(1, zeros(n - 1)) rtol = 1e-2 + d_est = fit_mle(VonMisesFisher, rand(d, 100_000)) + @test d_est.μ ≈ meandir(d) rtol=5e-2 end end From 66ae6caf717a58e1ad700eb59e57b1a59c6e6678 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 30 Mar 2025 00:56:17 +0100 Subject: [PATCH 4/4] Update test/multivariate/vonmisesfisher.jl --- test/multivariate/vonmisesfisher.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/multivariate/vonmisesfisher.jl b/test/multivariate/vonmisesfisher.jl index 193eeabb40..c4f8b2859d 100644 --- a/test/multivariate/vonmisesfisher.jl +++ b/test/multivariate/vonmisesfisher.jl @@ -166,6 +166,6 @@ end d = VonMisesFisher(vcat(1, zeros(n - 1)), 1.0) @test sum(abs2, rand(d)) ≈ 1 d_est = fit_mle(VonMisesFisher, rand(d, 100_000)) - @test d_est.μ ≈ meandir(d) rtol=5e-2 + @test meandir(d_est) ≈ meandir(d) rtol=5e-2 end end