|
| 1 | +module DynamicPPLMarginalLogDensitiesExt |
| 2 | + |
| 3 | +using DynamicPPL: DynamicPPL, LogDensityProblems, VarName |
| 4 | +using MarginalLogDensities: MarginalLogDensities |
| 5 | + |
| 6 | +_to_varname(n::Symbol) = VarName{n}() |
| 7 | +_to_varname(n::VarName) = n |
| 8 | + |
| 9 | +function DynamicPPL.marginalize( |
| 10 | + model::DynamicPPL.Model, |
| 11 | + varnames::AbstractVector{<:Union{Symbol,<:VarName}}, |
| 12 | + getlogprob=DynamicPPL.getlogjoint, |
| 13 | + method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); |
| 14 | + kwargs..., |
| 15 | +) |
| 16 | + # Determine the indices for the variables to marginalise out. |
| 17 | + varinfo = DynamicPPL.typed_varinfo(model) |
| 18 | + vns = map(_to_varname, varnames) |
| 19 | + varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns)) |
| 20 | + # Construct the marginal log-density model. |
| 21 | + # Use linked `varinfo` to that we're working in unconstrained space |
| 22 | + varinfo_linked = DynamicPPL.link(varinfo, model) |
| 23 | + |
| 24 | + f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo_linked) |
| 25 | + mdl = MarginalLogDensities.MarginalLogDensity( |
| 26 | + (x, _) -> LogDensityProblems.logdensity(f, x), |
| 27 | + varinfo_linked[:], |
| 28 | + varindices, |
| 29 | + (), |
| 30 | + method; |
| 31 | + kwargs..., |
| 32 | + ) |
| 33 | + return mdl |
| 34 | +end |
| 35 | + |
| 36 | +end |
0 commit comments