diff --git a/src/combinators/smart-constructors.jl b/src/combinators/smart-constructors.jl index 2db4724c..59c0a06a 100644 --- a/src/combinators/smart-constructors.jl +++ b/src/combinators/smart-constructors.jl @@ -62,8 +62,8 @@ function productmeasure(f::Returns{FB}, param_maps, pars) where {FB<:FactoredBas end function productmeasure(f::Returns{W}, ::typeof(identity), pars) where {W<:WeightedMeasure} - ℓ = f.value.logweight - base = f.value.base + ℓ = _logweight(f.value) + base = basemeasure(f.value) newbase = productmeasure(Returns(base), identity, pars) weightedmeasure(length(pars) * ℓ, newbase) end @@ -102,7 +102,7 @@ function weightedmeasure(ℓ::R, b::M) where {R,M} end function weightedmeasure(ℓ, b::WeightedMeasure) - weightedmeasure(ℓ + b.logweight, b.base) + weightedmeasure(ℓ + _logweight(b), b.base) end ############################################################################### diff --git a/src/combinators/weighted.jl b/src/combinators/weighted.jl index 827d30fd..b364da7f 100644 --- a/src/combinators/weighted.jl +++ b/src/combinators/weighted.jl @@ -9,13 +9,17 @@ export WeightedMeasure, AbstractWeightedMeasure abstract type AbstractWeightedMeasure <: AbstractMeasure end -logweight(μ::AbstractWeightedMeasure) = μ.logweight -basemeasure(μ::AbstractWeightedMeasure) = μ.base +# By default the weight for all measure is 1 +_logweight(::AbstractMeasure) = 0 -@inline function logdensity_def(d::AbstractWeightedMeasure, x) +@inline function logdensity_def(d::AbstractWeightedMeasure, _) d.logweight end +function Base.rand(rng::AbstractRNG, ::Type{T}, μ::AbstractWeightedMeasure) where {T} + rand(rng, T, basemeasure(μ)) +end + ############################################################################### struct WeightedMeasure{R,M} <: AbstractWeightedMeasure @@ -23,6 +27,9 @@ struct WeightedMeasure{R,M} <: AbstractWeightedMeasure base::M end +_logweight(μ::WeightedMeasure) = μ.logweight +basemeasure(μ::AbstractWeightedMeasure) = μ.base + function Base.show(io::IO, μ::WeightedMeasure) io = IOContext(io, :compact => true) print(io, exp(μ.logweight), " * ", μ.base) diff --git a/test/combinators/weighted.jl b/test/combinators/weighted.jl new file mode 100644 index 00000000..5d9d48af --- /dev/null +++ b/test/combinators/weighted.jl @@ -0,0 +1,20 @@ +using Random: MersenneTwister +using Test + +using MeasureBase +using MeasureBase: _logweight, weightedmeasure, WeightedMeasure + +@testset "weighted" begin + @test iszero(_logweight(Lebesgue(ℝ))) + μ₀ = Dirac(0.0) + w = 2.0 + μ = @inferred w * μ₀ + @test μ == WeightedMeasure(log(w), μ₀) == weightedmeasure(log(w), μ₀) + @test μ isa WeightedMeasure + @test _logweight(μ) == log(w) + @test _logweight(w * μ) == 2 * log(w) + @test rand(MersenneTwister(123), μ) == rand(MersenneTwister(123), μ₀) + x = rand() + @test logdensity_def(μ, x) == log(w) + @test logdensityof(μ, x) == logdensityof(μ₀, x) +end diff --git a/test/runtests.jl b/test/runtests.jl index eab6722b..d8ab3534 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -212,3 +212,5 @@ end # @test f(x) ≈ x^2 # end end + +include("combinators/weighted.jl")