Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend gradlogpdf to MixtureModels #1827

Draft
wants to merge 15 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions docs/src/mixture.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ var(::UnivariateMixture)
length(::MultivariateMixture)
pdf(::AbstractMixtureModel, ::Any)
logpdf(::AbstractMixtureModel, ::Any)
gradlogpdf(::AbstractMixtureModel, ::Any)
rand(::AbstractMixtureModel)
rand!(::AbstractMixtureModel, ::AbstractArray)
```
Expand Down
1 change: 1 addition & 0 deletions docs/src/truncate.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ are defined for all truncated univariate distributions:
- [`insupport(::UnivariateDistribution, x::Any)`](@ref)
- [`pdf(::UnivariateDistribution, ::Real)`](@ref)
- [`logpdf(::UnivariateDistribution, ::Real)`](@ref)
- [`gradlogpdf(::UnivariateDistribution, ::Real)`](@ref)
- [`cdf(::UnivariateDistribution, ::Real)`](@ref)
- [`logcdf(::UnivariateDistribution, ::Real)`](@ref)
- [`logdiffcdf(::UnivariateDistribution, ::T, ::T) where {T <: Real}`](@ref)
Expand Down
1 change: 1 addition & 0 deletions docs/src/univariate.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pdfsquaredL2norm
insupport(::UnivariateDistribution, x::Any)
pdf(::UnivariateDistribution, ::Real)
logpdf(::UnivariateDistribution, ::Real)
gradlogpdf(::UnivariateDistribution, ::Real)
loglikelihood(::UnivariateDistribution, ::AbstractArray)
cdf(::UnivariateDistribution, ::Real)
logcdf(::UnivariateDistribution, ::Real)
Expand Down
55 changes: 55 additions & 0 deletions src/mixtures/mixturemodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ Here, `x` can be a single sample or an array of multiple samples.
"""
logpdf(d::AbstractMixtureModel, x::Any)

"""
gradlogpdf(d::Union{UnivariateMixture, MultivariateMixture}, x)

Evaluate the gradient of the logarithm of the (mixed) probability density function over `x`.
Here, `x` can be a single sample or an array of multiple samples.
"""
gradlogpdf(d::AbstractMixtureModel, x::Any)

"""
rand(d::Union{UnivariateMixture, MultivariateMixture})

Expand Down Expand Up @@ -362,6 +370,29 @@ end
pdf(d::UnivariateMixture, x::Real) = _mixpdf1(d, x)
logpdf(d::UnivariateMixture, x::Real) = _mixlogpdf1(d, x)

function gradlogpdf(d::UnivariateMixture, x::Real)
cp = components(d)
pr = probs(d)
pdfx1 = pdf(cp[1], x)
pdfx = pr[1] * pdfx1
_glp = pdfx * gradlogpdf(cp[1], x)
glp = (!iszero(pr[1])) && (!iszero(pdfx)) ? _glp : zero(_glp)
@inbounds for i in Iterators.drop(eachindex(pr, cp), 1)
rmsrosa marked this conversation as resolved.
Show resolved Hide resolved
if !iszero(pr[i])
pdfxi = pdf(cp[i], x)
if !iszero(pdfxi)
pipdfxi = pr[i] * pdfxi
pdfx += pipdfxi
glp += pipdfxi * gradlogpdf(cp[i], x)
end
end
end
if !iszero(pdfx) # else glp is already zero
glp /= pdfx
end
Comment on lines +398 to +400
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it correct to return a gradlogpdf of zero if x is not in the support of the mixture distribution?

Copy link
Author

@rmsrosa rmsrosa Jan 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wondered about that, but decided to follow the current behavior already implemented. For example:

julia> insupport(Beta(0.5, 0.5), -1)
false

julia> logpdf(Beta(0.5, 0.5), -1)
-Inf

julia> gradlogpdf(Beta(0.5, 0.5), -1)
0.0

I don't know. If it is constant -Inf, then the derivative is zero (except that (-Inf) - (-Inf) is not defined, but what matters is that the rate of change is zero...)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should make this a separate issue and hopefully standardize the behavior.

return glp
end

_pdf!(r::AbstractArray{<:Real}, d::UnivariateMixture{Discrete}, x::UnitRange) = _mixpdf!(r, d, x)
_pdf!(r::AbstractArray{<:Real}, d::UnivariateMixture, x::AbstractArray{<:Real}) = _mixpdf!(r, d, x)
_logpdf!(r::AbstractArray{<:Real}, d::UnivariateMixture, x::AbstractArray{<:Real}) = _mixlogpdf!(r, d, x)
Expand All @@ -371,6 +402,30 @@ _logpdf(d::MultivariateMixture, x::AbstractVector{<:Real}) = _mixlogpdf1(d, x)
_pdf!(r::AbstractArray{<:Real}, d::MultivariateMixture, x::AbstractMatrix{<:Real}) = _mixpdf!(r, d, x)
_logpdf!(r::AbstractArray{<:Real}, d::MultivariateMixture, x::AbstractMatrix{<:Real}) = _mixlogpdf!(r, d, x)

function gradlogpdf(d::MultivariateMixture, x::AbstractVector{<:Real})
cp = components(d)
pr = probs(d)
pdfx1 = pdf(cp[1], x)
pdfx = pr[1] * pdfx1
glp = pdfx * gradlogpdf(cp[1], x)
if ( iszero(pr[1]) || iszero(pdfx) )
glp .= zero(eltype(glp))
rmsrosa marked this conversation as resolved.
Show resolved Hide resolved
end
@inbounds for i in Iterators.drop(eachindex(pr, cp), 1)
if !iszero(pr[i])
pdfxi = pdf(cp[i], x)
if !iszero(pdfxi)
pipdfxi = pr[i] * pdfxi
pdfx += pipdfxi
glp .+= pipdfxi * gradlogpdf(cp[i], x)
end
end
end
if !iszero(pdfx) # else glp is already zero
glp ./= pdfx
end
return glp
end

## component-wise pdf and logpdf

Expand Down
44 changes: 44 additions & 0 deletions test/gradlogpdf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,47 @@ using Test
[0.191919191919192, 1.080808080808081] ,atol=1.0e-8)
@test isapprox(gradlogpdf(MvTDist(5., [1., 2.], [1. 0.1; 0.1 1.]), [0.7, 0.9]),
[0.2150711513583442, 1.2111901681759383] ,atol=1.0e-8)

# Test for gradlogpdf on univariate mixture distributions against centered finite-difference on logpdf

x = [-0.2, 0.3, 0.8, 1.0, 1.3, 10.5]
delta = 0.0001

for d in (
MixtureModel([Normal(-4.5, 2.0)], [1.0]),
MixtureModel([Exponential(2.0)], [1.0]),
MixtureModel([Uniform(-1.0, 1.0)], [1.0]),
MixtureModel([Normal(1//1, 2//1), Beta(2//1, 3//1), Exponential(3//2)], [3//10, 4//10, 3//10]),
MixtureModel([Normal(-2.0, 3.5), Normal(-4.5, 2.0)], [0.0, 1.0]),
MixtureModel([Beta(1.5, 3.0), Chi(5.0), Chisq(7.0)], [0.4, 0.3, 0.3]),
MixtureModel([Exponential(2.0), Gamma(9.0, 0.5), Gumbel(3.5, 1.0), Laplace(7.0)], [0.3, 0.2, 0.4, 0.1]),
MixtureModel([Logistic(-6.0), LogNormal(5.5), TDist(8.0), Weibull(2.0)], [0.3, 0.2, 0.4, 0.1])
)
xs = filter(s -> all(insupport.(d, [s - delta, s, s + delta])), x)
glp1 = gradlogpdf.(d, xs)
glp2 = ( logpdf.(d, xs .+ delta) - logpdf.(d, xs .- delta) ) ./ 2delta
@info "Testing `gradlogpdf` on $d"
@test isapprox(glp1, glp2, atol = 0.01)
end

# Test for gradlogpdf on multivariate mixture distributions against centered finite-difference on logpdf

x = [[0.2, 0.3], [0.8, 1.3], [-1.0, 10.5]]
delta = 0.001

for d in (
MixtureModel([MvNormal([1., 2.], [1. 0.1; 0.1 1.])], [1.0]),
MixtureModel([MvNormal([1.0, 2.0], [0.4 0.2; 0.2 0.5]), MvNormal([2.0, 1.0], [0.3 0.1; 0.1 0.4])], [0.4, 0.6]),
MixtureModel([MvNormal([1.0, 2.0], [0.4 0.2; 0.2 0.5]), MvNormal([2.0, 1.0], [0.3 0.1; 0.1 0.4])], [1.0, 0.0]),
MixtureModel([MvTDist(5., [1., 2.], [1. 0.1; 0.1 1.])], [1.0]),
MixtureModel([MvNormal([1.0, 2.0], [0.4 0.2; 0.2 0.5]), MvTDist(5., [1., 2.], [1. 0.1; 0.1 1.])], [0.4, 0.6])
)
xs = filter(s -> insupport(d, s), x)
for xi in xs
glp = gradlogpdf(d, xi)
glpx = ( logpdf(d, xi .+ [delta, 0]) - logpdf(d, xi .- [delta, 0]) ) ./ 2delta
glpy = ( logpdf(d, xi .+ [0, delta]) - logpdf(d, xi .- [0, delta]) ) ./ 2delta
@test isapprox(glp[1], glpx, atol=delta)
@test isapprox(glp[2], glpy, atol=delta)
end
end
Loading