Skip to content

Turing Interface #105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/NestedSamplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Random: AbstractRNG, GLOBAL_RNG

using AbstractMCMC
using AbstractMCMC: AbstractSampler,
LogDensityModel,
AbstractModel,
samples,
save!!
Expand Down
6 changes: 5 additions & 1 deletion src/staticsampler.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Sampler and model implementations

struct Nested{B, P <: AbstractProposal} <: AbstractSampler
type::Type
ndims::Int
nactive::Int
bounds::B
Expand Down Expand Up @@ -53,6 +54,7 @@ function Nested(ndims,
enlarge = 1.25,
min_ncall=2nactive,
min_eff=0.10,
type = Float64,
kwargs...)

nactive < 2ndims && @warn "Using fewer than 2ndim ($(2ndims)) active points is discouraged"
Expand All @@ -72,7 +74,9 @@ function Nested(ndims,

update_interval_frac = get(kwargs, :update_interval, default_update_interval(proposal, ndims))
update_interval = round(Int, update_interval_frac * nactive)
return Nested(ndims,
return Nested(
type,
ndims,
nactive,
bounds,
enlarge,
Expand Down
88 changes: 73 additions & 15 deletions src/step.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,43 @@
#TODO it would be cool to make these parametric
struct NestedState
it::Int
ncall::Int
us
vs
logl
logl_dead
logz
logzerr
h
logvol
since_update::Int
has_bounds::Bool
active_bound
end

struct NestedTransition
u
v
logwt
logl
end

function step(rng, model, sampler::Nested; kwargs...)
function step(
rng::Random.AbstractRNG,
model,
sampler::Nested;
init_params = nothing,
kwargs...)
# Initialize particles
# us are in unit space, vs are in prior space
us, vs, logl = init_particles(rng, model, sampler)
if init_params != nothing && size(init_params) == (sampler.ndims,sampler.nactive)
us = init_params
else
@warn "Init params have wrong dimensions, sampling new ones!"
us = rand(rng, sampler.type, sampler.ndims, sampler.nactive)
end
us, vs, logl = init_particles(rng, us, model)

# Find least likely point
logl_dead, idx_dead = findmin(logl)
Expand Down Expand Up @@ -31,15 +66,20 @@ function step(rng, model, sampler::Nested; kwargs...)
logzerr = sqrt(h / sampler.nactive)
logvol -= 1 / sampler.nactive

sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead)
state = (it = 1, ncall = ncall, us = us, vs = vs, logl = logl, logl_dead = logl_dead,
logz = logz, logzerr = logzerr, h = h, logvol = logvol,
since_update = since_update, has_bounds = false, active_bound = nothing)
sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead)
#sample = NestedTransition(u_dead, v_dead, logwt, logl_dead)
state = NestedState(1, ncall, us, vs, logl, logl_dead,
logz, logzerr, h, logvol, since_update, false, nothing)

return sample, state
end

function step(rng, model, sampler, state; kwargs...)
function step(
rng::Random.AbstractRNG,
model,
sampler::Nested,
state::NestedState;
kwargs...)
## Update bounds
pointvol = exp(state.logvol) / sampler.nactive
# check if ready for first update
Expand Down Expand Up @@ -102,18 +142,18 @@ function step(rng, model, sampler, state; kwargs...)
logvol = state.logvol - 1 / sampler.nactive

## prepare returns
sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead)
state = (it = it, ncall = ncall, us = state.us, vs = state.vs, logl = state.logl, logl_dead = logl_dead,
logz = logz, logzerr = logzerr, h = h, logvol = logvol,
since_update = since_update, has_bounds = has_bounds, active_bound = active_bound)
sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead)
#sample = NestedTransition(u_dead, v_dead, logwt, logl_dead)
state = NestedState(it, ncall, state.us, state.vs, state.logl, logl_dead,
logz, logzerr, h, logvol, since_update, has_bounds, active_bound)

return sample, state
end

function bundle_samples(samples,
model::AbstractModel,
sampler::Nested,
state,
state::NestedState,
::Type{Chains};
add_live=true,
param_names=missing,
Expand Down Expand Up @@ -143,7 +183,7 @@ end
function bundle_samples(samples,
model::AbstractModel,
sampler::Nested,
state,
state::NestedState,
::Type{Array};
add_live=true,
check_wsum=true,
Expand Down Expand Up @@ -175,8 +215,7 @@ init_particles(rng, model, sampler) =

# loop and fill arrays, checking validity of points
# will retry 100 times before erroring
function init_particles(rng, T, ndims, nactive, model)
us = rand(rng, T, ndims, nactive)
function init_particles(rng::Random.AbstractRNG, us::Matrix, model::AbstractModel)
vs_and_logl = mapslices(
Base.Fix1(prior_transform_and_loglikelihood, model), us;
dims=1
Expand Down Expand Up @@ -204,6 +243,25 @@ function init_particles(rng, T, ndims, nactive, model)
return us, vs, logl
end

function init_particles(rng::Random.AbstractRNG, us::Matrix, model::LogDensityModel)
logdensity = model.logdensity
dims, nactive = size(us)
logls = [logdensity(us[:, i]) for i in 1:nactive]

ntries = 1
while true
any(isfinite, logls) && break
rand!(rng, us)
logls = [logdensity(us[:, i]) for i in 1:nactive]
ntries += 1
ntries > 100 && error("After 100 attempts, could not initialize any live points with finite loglikelihood. Please check your prior transform and loglikelihood methods.")
end

# force -Inf to be a finite but small number to keep estimators from breaking
@. logls[logls == -Inf] = -1e300

return us, us, logls
end

# add remaining live points to `samples`
function add_live_points(samples, model, sampler, state)
Expand Down