Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix von Mises-Fisher sampler #1930

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 88 additions & 57 deletions src/samplers/vonmisesfisher.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,50 @@
# Sampler for von Mises-Fisher
# Ref https://doi.org/10.18637/jss.v058.i10
# Ref https://hal.science/hal-04004568v3
struct VonMisesFisherSampler <: Sampleable{Multivariate,Continuous}
p::Int # the dimension
κ::Float64
b::Float64
x0::Float64
c::Float64
τ::Float64
v::Vector{Float64}
end

function VonMisesFisherSampler(μ::Vector{Float64}, κ::Float64)
# Step 1: Calculate b, x₀, and c
p = length(μ)
b = _vmf_bval(p, κ)
b = (p - 1) / (2.0κ + sqrt(4 * abs2(κ) + abs2(p - 1)))
x0 = (1.0 - b) / (1.0 + b)
c = κ * x0 + (p - 1) * log1p(-abs2(x0))
v = _vmf_householder_vec(μ)
VonMisesFisherSampler(p, κ, b, x0, c, v)

# Compute Householder transformation:
# `LinearAlgebra.reflector!` computes a Householder transformation H such that
# H μ = -copysign(|μ|₂, μ[1]) e₁
# μ is a unit vector, and hence this implies that
# H e₁ = μ if μ[1] < 0 and H (-e₁) = μ otherwise
# Since `v[1] = flipsign(1, μ[1])`, the sign of `μ[1]` can be extracted from `v[1]` during sampling
v = similar(μ)
copyto!(v, μ)
τ = LinearAlgebra.reflector!(v)

return VonMisesFisherSampler(p, κ, b, x0, c, τ, v)
end

Base.length(s::VonMisesFisherSampler) = length(s.v)

@inline function _vmf_rot!(v::AbstractVector, x::AbstractVector)
# rotate
scale = 2.0 * (v' * x)
@. x -= (scale * v)
return x
end
function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector{<:Real})
# TODO: Generalize to more general indices
Base.require_one_based_indexing(x)

# Sample angle `w` assuming mean direction `(1, 0, ..., 0)`
w = _vmf_angle(rng, spl)

# Transform to sample for mean direction `(flipsign(1.0, μ[1]), 0, ..., 0)`
v = spl.v
w = flipsign(w, v[1])

function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector)
w = _vmf_genw(rng, spl)
# Generate sample assuming mean direction `(flipsign(1.0, μ[1]), 0, ..., 0)`
p = spl.p
x[1] = w
s = 0.0
Expand All @@ -43,60 +59,75 @@ function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector)
x[i] *= r
end

return _vmf_rot!(spl.v, x)
# Apply Householder transformation to mean direction `μ`
return LinearAlgebra.reflectorApply!(v, spl.τ, x)
end

### Core computation

_vmf_bval(p::Int, κ::Real) = (p - 1) / (2.0κ + sqrt(4 * abs2(κ) + abs2(p - 1)))

function _vmf_genw3(rng::AbstractRNG, p, b, x0, c, κ)
ξ = rand(rng)
w = 1.0 + (log(ξ + (1.0 - ξ)*exp(-2κ))/κ)
return w::Float64
end

function _vmf_genwp(rng::AbstractRNG, p, b, x0, c, κ)
r = (p - 1) / 2.0
betad = Beta(r, r)
z = rand(rng, betad)
w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z)
while κ * w + (p - 1) * log(1 - x0 * w) - c < log(rand(rng))
z = rand(rng, betad)
w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z)
end
return w::Float64
end
# Step 2: Sample angle W
function _vmf_angle(rng::AbstractRNG, spl::VonMisesFisherSampler)
p = spl.p
κ = spl.κ

# generate the W value -- the key step in simulating vMF
#
# following movMF's document for the p != 3 case
# and Wenzel Jakob's document for the p == 3 case
function _vmf_genw(rng::AbstractRNG, p, b, x0, c, κ)
if p == 3
return _vmf_genw3(rng, p, b, x0, c, κ)
_vmf_angle3(rng, κ)
else
return _vmf_genwp(rng, p, b, x0, c, κ)
# General case: Rejection sampling
# Ref https://doi.org/10.18637/jss.v058.i10
b = spl.b
c = spl.c
p = spl.p
κ = spl.κ
x0 = spl.x0
pm1 = p - 1

if p == 2
# In this case the distribution reduces to the von Mises distribution on the circle
# We exploit the fact that `Beta(1/2, 1/2) = Arcsine(0, 1)`
dist = Arcsine(zero(b), one(b))
while true
z = rand(rng, dist)
w = (1 - (1 + b) * z) / (1 - (1 - b) * z)
if κ * w + pm1 * log1p(- x0 * w) >= c - randexp(rng)
return w::Float64
end
end
else
# We sample from a `Beta((p - 1)/2, (p - 1)/2)` distribution, possibly repeatedly
# Therefore we construct a sampler
# To avoid the type instability of `sampler(Beta(...))` and `sampler(Gamma(...))`
# we directly construct the Gamma sampler for Gamma((p - 1)/2, 1)
# Since (p - 1)/2 > 1, we construct a `GammaMTSampler`
r = pm1 / 2
gammasampler = GammaMTSampler(Gamma{typeof(r)}(r, one(r)))
while true
# w is supposed to be generated as
# z ~ Beta((p - 1)/ 2, (p - 1)/2)
# w = (1 - (1 + b) * z) / (1 - (1 - b) * z)
# We sample z as
# z1 ~ Gamma((p - 1) / 2, 1)
# z2 ~ Gamma((p - 1) / 2, 1)
# z = z1 / (z1 + z2)
# and rewrite the expression for w
# Cf. case p == 2 above
z1 = rand(rng, gammasampler)
z2 = rand(rng, gammasampler)
b_z1 = b * z1
w = (z2 - b_z1) / (z2 + b_z1)
if κ * w + pm1 * log1p(- x0 * w) >= c - randexp(rng)
return w::Float64
end
end
end
end
end


_vmf_genw(rng::AbstractRNG, s::VonMisesFisherSampler) =
_vmf_genw(rng, s.p, s.b, s.x0, s.c, s.κ)

function _vmf_householder_vec(μ::Vector{Float64})
# assuming μ is a unit-vector (which it should be)
# can compute v in a single pass over μ

p = length(μ)
v = similar(μ)
v[1] = μ[1] - 1.0
s = sqrt(-2*v[1])
v[1] /= s

@inbounds for i in 2:p
v[i] = μ[i] / s
end

return v
# Special case: 2-sphere
@inline function _vmf_angle3(rng::AbstractRNG, κ::Real)
# In this case, we can directly sample the angle
# Ref https://www.mitsuba-renderer.org/~wenzel/files/vmf.pdf
ξ = rand(rng)
w = 1.0 + (log(ξ + (1.0 - ξ)*exp(-2κ))/κ)
return w::Float64
end
1 change: 0 additions & 1 deletion src/univariate/continuous/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ function rand(rng::AbstractRNG, d::Gamma)
# TODO: shape(d) = 0.5 : use scaled chisq
return rand(rng, GammaIPSampler(d))
elseif shape(d) == 1.0
θ =
return rand(rng, Exponential{partype(d)}(scale(d)))
else
return rand(rng, GammaMTSampler(d))
Expand Down
51 changes: 19 additions & 32 deletions test/multivariate/vonmisesfisher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,7 @@ function gen_vmf_tdata(n::Int, p::Int,
return X
end

function test_vmf_rot(p::Int, rng::Union{AbstractRNG, Missing} = missing)
if ismissing(rng)
μ = randn(p)
x = randn(p)
else
μ = randn(rng, p)
x = randn(rng, p)
end
κ = norm(μ)
μ = μ ./ κ

s = Distributions.VonMisesFisherSampler(μ, κ)
v = μ - vcat(1, zeros(p-1))
H = I - 2*v*v'/(v'*v)

@test Distributions._vmf_rot!(s.v, copy(x)) ≈ (H*x)

end



function test_genw3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing)
function test_angle3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing)
p = 3

if ismissing(rng)
Expand All @@ -53,21 +32,20 @@ function test_genw3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missin
end
μ = μ ./ norm(μ)

s = Distributions.VonMisesFisherSampler(μ, float(κ))
spl = Distributions.VonMisesFisherSampler(μ, float(κ))
angle3_res = [Distributions._vmf_angle3(rng, spl.κ) for _ in 1:ns]
angle_res = [Distributions._vmf_angle(rng, spl) for _ in 1:ns]

genw3_res = [Distributions._vmf_genw3(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns]
genwp_res = [Distributions._vmf_genwp(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns]

@test isapprox(mean(genw3_res), mean(genwp_res), atol=0.01)
@test isapprox(std(genw3_res), std(genwp_res), atol=0.01/κ)
@test mean(angle3_res) ≈ mean(angle_res) rtol=5e-2
@test std(angle3_res) ≈ std(angle_res) rtol=1e-2

# test mean and stdev against analytical formulas
coth_κ = coth(κ)
mean_w = coth_κ - 1/κ
var_w = 1 - coth_κ^2 + 1/κ^2

@test isapprox(mean(genw3_res), mean_w, atol=0.01)
@test isapprox(std(genw3_res), sqrt(var_w), atol=0.01/κ)
@test mean(angle3_res) ≈ mean_w rtol=5e-2
@test std(angle3_res) ≈ sqrt(var_w) rtol=1e-2
end


Expand Down Expand Up @@ -173,12 +151,21 @@ ns = 10^6
(2, 2),
(2, 1000)] # test with large κ
test_vonmisesfisher(p, κ, n, ns, rng)
test_vmf_rot(p, rng)
end

if !ismissing(rng)
@testset "Testing genw with $key at (3, $κ)" for κ in [0.1, 0.5, 1.0, 2.0, 5.0]
test_genw3(κ, ns, rng)
test_angle3(κ, ns, rng)
end
end
end

# issue #1423
@testset "Special case: No rotation" begin
for n in 2:10
d = VonMisesFisher(vcat(1, zeros(n - 1)), 1.0)
@test sum(abs2, rand(d)) ≈ 1
d_est = fit_mle(VonMisesFisher, rand(d, 100_000))
@test meandir(d_est) ≈ meandir(d) rtol=5e-2
end
end
Loading