diff --git a/src/NestedSamplers.jl b/src/NestedSamplers.jl index 1c631fa8..829619ca 100644 --- a/src/NestedSamplers.jl +++ b/src/NestedSamplers.jl @@ -6,6 +6,7 @@ using Random: AbstractRNG, GLOBAL_RNG using AbstractMCMC using AbstractMCMC: AbstractSampler, + LogDensityModel, AbstractModel, samples, save!! diff --git a/src/staticsampler.jl b/src/staticsampler.jl index 1c7be396..f9e2384e 100644 --- a/src/staticsampler.jl +++ b/src/staticsampler.jl @@ -1,6 +1,7 @@ # Sampler and model implementations struct Nested{B, P <: AbstractProposal} <: AbstractSampler + type::Type ndims::Int nactive::Int bounds::B @@ -53,6 +54,7 @@ function Nested(ndims, enlarge = 1.25, min_ncall=2nactive, min_eff=0.10, + type = Float64, kwargs...) nactive < 2ndims && @warn "Using fewer than 2ndim ($(2ndims)) active points is discouraged" @@ -72,7 +74,9 @@ function Nested(ndims, update_interval_frac = get(kwargs, :update_interval, default_update_interval(proposal, ndims)) update_interval = round(Int, update_interval_frac * nactive) - return Nested(ndims, + return Nested( + type, + ndims, nactive, bounds, enlarge, diff --git a/src/step.jl b/src/step.jl index 8041bb17..e0d4c8d7 100644 --- a/src/step.jl +++ b/src/step.jl @@ -1,8 +1,43 @@ +#TODO it would be cool to make these parametric +struct NestedState + it::Int + ncall::Int + us + vs + logl + logl_dead + logz + logzerr + h + logvol + since_update::Int + has_bounds::Bool + active_bound +end + +struct NestedTransition + #θ + u + v + logwt + logl +end -function step(rng, model, sampler::Nested; kwargs...) +function step( + rng::Random.AbstractRNG, + model, + sampler::Nested; + init_params = nothing, + kwargs...) # Initialize particles # us are in unit space, vs are in prior space - us, vs, logl = init_particles(rng, model, sampler) + if init_params != nothing && size(init_params) == (sampler.ndims,sampler.nactive) + us = init_params + else + @warn "Init params have wrong dimensions, sampling new ones!" + us = rand(rng, sampler.type, sampler.ndims, sampler.nactive) + end + us, vs, logl = init_particles(rng, us, model) # Find least likely point logl_dead, idx_dead = findmin(logl) @@ -31,15 +66,20 @@ function step(rng, model, sampler::Nested; kwargs...) logzerr = sqrt(h / sampler.nactive) logvol -= 1 / sampler.nactive - sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead) - state = (it = 1, ncall = ncall, us = us, vs = vs, logl = logl, logl_dead = logl_dead, - logz = logz, logzerr = logzerr, h = h, logvol = logvol, - since_update = since_update, has_bounds = false, active_bound = nothing) + sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead) + #sample = NestedTransition(u_dead, v_dead, logwt, logl_dead) + state = NestedState(1, ncall, us, vs, logl, logl_dead, + logz, logzerr, h, logvol, since_update, false, nothing) return sample, state end -function step(rng, model, sampler, state; kwargs...) +function step( + rng::Random.AbstractRNG, + model, + sampler::Nested, + state::NestedState; + kwargs...) ## Update bounds pointvol = exp(state.logvol) / sampler.nactive # check if ready for first update @@ -102,10 +142,10 @@ function step(rng, model, sampler, state; kwargs...) logvol = state.logvol - 1 / sampler.nactive ## prepare returns - sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead) - state = (it = it, ncall = ncall, us = state.us, vs = state.vs, logl = state.logl, logl_dead = logl_dead, - logz = logz, logzerr = logzerr, h = h, logvol = logvol, - since_update = since_update, has_bounds = has_bounds, active_bound = active_bound) + sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead) + #sample = NestedTransition(u_dead, v_dead, logwt, logl_dead) + state = NestedState(it, ncall, state.us, state.vs, state.logl, logl_dead, + logz, logzerr, h, logvol, since_update, has_bounds, active_bound) return sample, state end @@ -113,7 +153,7 @@ end function bundle_samples(samples, model::AbstractModel, sampler::Nested, - state, + state::NestedState, ::Type{Chains}; add_live=true, param_names=missing, @@ -143,7 +183,7 @@ end function bundle_samples(samples, model::AbstractModel, sampler::Nested, - state, + state::NestedState, ::Type{Array}; add_live=true, check_wsum=true, @@ -175,8 +215,7 @@ init_particles(rng, model, sampler) = # loop and fill arrays, checking validity of points # will retry 100 times before erroring -function init_particles(rng, T, ndims, nactive, model) - us = rand(rng, T, ndims, nactive) +function init_particles(rng::Random.AbstractRNG, us::Matrix, model::AbstractModel) vs_and_logl = mapslices( Base.Fix1(prior_transform_and_loglikelihood, model), us; dims=1 @@ -204,6 +243,25 @@ function init_particles(rng, T, ndims, nactive, model) return us, vs, logl end +function init_particles(rng::Random.AbstractRNG, us::Matrix, model::LogDensityModel) + logdensity = model.logdensity + dims, nactive = size(us) + logls = [logdensity(us[:, i]) for i in 1:nactive] + + ntries = 1 + while true + any(isfinite, logls) && break + rand!(rng, us) + logls = [logdensity(us[:, i]) for i in 1:nactive] + ntries += 1 + ntries > 100 && error("After 100 attempts, could not initialize any live points with finite loglikelihood. Please check your prior transform and loglikelihood methods.") + end + + # force -Inf to be a finite but small number to keep estimators from breaking + @. logls[logls == -Inf] = -1e300 + + return us, us, logls +end # add remaining live points to `samples` function add_live_points(samples, model, sampler, state)