Skip to content

Commit 4b820f5

Browse files
Transfer MarginalLogDensities extension from Turing
See: TuringLang/Turing.jl#2664 Co-authored-by: Sam Urmy <[email protected]>
1 parent 7249158 commit 4b820f5

File tree

6 files changed

+71
-4
lines changed

6 files changed

+71
-4
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3434
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
3535
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
3636
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
37+
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
3738
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
3839

3940
[extensions]
4041
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
4142
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
4243
DynamicPPLForwardDiffExt = ["ForwardDiff"]
4344
DynamicPPLJETExt = ["JET"]
45+
DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"]
4446
DynamicPPLMCMCChainsExt = ["MCMCChains"]
4547
DynamicPPLMooncakeExt = ["Mooncake"]
4648

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
module DynamicPPLMarginalLogDensitiesExt
2+
3+
using DynamicPPL: DynamicPPL, LogDensityProblems, VarName
4+
using MarginalLogDensities: MarginalLogDensities
5+
6+
_to_varname(n::Symbol) = VarName{n}()
7+
_to_varname(n::VarName) = n
8+
9+
function DynamicPPL.marginalize(
10+
model::DynamicPPL.Model,
11+
varnames::AbstractVector{<:Union{Symbol,<:VarName}},
12+
getlogprob=DynamicPPL.getlogjoint,
13+
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox();
14+
kwargs...,
15+
)
16+
# Determine the indices for the variables to marginalise out.
17+
varinfo = DynamicPPL.typed_varinfo(model)
18+
vns = map(_to_varname, varnames)
19+
varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns))
20+
# Construct the marginal log-density model.
21+
# Use linked `varinfo` to that we're working in unconstrained space
22+
varinfo_linked = DynamicPPL.link(varinfo, model)
23+
24+
f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo_linked)
25+
mdl = MarginalLogDensities.MarginalLogDensity(
26+
(x, _) -> LogDensityProblems.logdensity(f, x),
27+
varinfo_linked[:],
28+
varindices,
29+
(),
30+
method;
31+
kwargs...,
32+
)
33+
return mdl
34+
end
35+
36+
end

src/DynamicPPL.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ export AbstractVarInfo,
122122
fix,
123123
unfix,
124124
predict,
125+
marginalize,
125126
prefix,
126127
returned,
127128
to_submodel,
@@ -199,10 +200,6 @@ include("test_utils.jl")
199200
include("experimental.jl")
200201
include("deprecated.jl")
201202

202-
if !isdefined(Base, :get_extension)
203-
using Requires
204-
end
205-
206203
# Better error message if users forget to load JET
207204
if isdefined(Base.Experimental, :register_error_hint)
208205
function __init__()
@@ -247,4 +244,7 @@ end
247244
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
248245
struct DynamicPPLTag end
249246

247+
# Extended in MarginalLogDensitiesExt
248+
function marginalize end
249+
250250
end # module

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1818
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1919
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
2020
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
21+
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
2122
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2223
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2324
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
module MarginalLogDensitiesExtTests
2+
3+
using DynamicPPL, Distributions, Test
4+
using MarginalLogDensities
5+
using ADTypes: AutoForwardDiff
6+
7+
@testset "MarginalLogDensities" begin
8+
# Simple test case.
9+
@model function demo()
10+
x ~ MvNormal(zeros(2), [1, 1])
11+
return y ~ Normal(0, 1)
12+
end
13+
model = demo()
14+
# Marginalize out `x`.
15+
16+
for vn in [@varname(x), :x]
17+
for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint]
18+
marginalized = marginalize(
19+
model, [vn], getlogprob; hess_adtype=AutoForwardDiff()
20+
)
21+
# Compute the marginal log-density of `y = 0.0`.
22+
@test marginalized([0.0]) logpdf(Normal(0, 1), 0.0) atol = 1e-5
23+
end
24+
end
25+
end
26+
27+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ include("test_util.jl")
8080
@testset "extensions" begin
8181
include("ext/DynamicPPLMCMCChainsExt.jl")
8282
include("ext/DynamicPPLJETExt.jl")
83+
include("ext/DynamicPPLMarginalLogDensitiesExt.jl")
8384
end
8485
@testset "ad" begin
8586
include("ext/DynamicPPLForwardDiffExt.jl")

0 commit comments

Comments
 (0)