From f5297529d8068f08e182837bf3a3aade55a4a9c5 Mon Sep 17 00:00:00 2001 From: jaimerz Date: Fri, 15 Sep 2023 13:49:35 +0100 Subject: [PATCH 1/9] NestedState NestedTransition --- src/step.jl | 64 +++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 45 insertions(+), 19 deletions(-) diff --git a/src/step.jl b/src/step.jl index 8041bb17..856e3197 100644 --- a/src/step.jl +++ b/src/step.jl @@ -1,5 +1,30 @@ -function step(rng, model, sampler::Nested; kwargs...) +#TODO it would be cool to make these parametric +struct NestedState + it::Int + ncall::Int + us + vs + logl + logl_dead + logz + logzerr + h + logvol + since_update::Int + has_bounds::Bool + active_bound::Bool +end + +struct NestedTransition + θ + u + v + logwt + logl +end + +function step(rng::Random.AbstractRNG, model, sampler::Nested; kwargs...) # Initialize particles # us are in unit space, vs are in prior space us, vs, logl = init_particles(rng, model, sampler) @@ -31,15 +56,15 @@ function step(rng, model, sampler::Nested; kwargs...) logzerr = sqrt(h / sampler.nactive) logvol -= 1 / sampler.nactive - sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead) - state = (it = 1, ncall = ncall, us = us, vs = vs, logl = logl, logl_dead = logl_dead, - logz = logz, logzerr = logzerr, h = h, logvol = logvol, - since_update = since_update, has_bounds = false, active_bound = nothing) + sample = NestedTransition(u, u_dead, v_dead, logwt, logl_dead) + state = NestedState(1, ncall, us, vs, logl, logl_dead, + logz, logzerr, h, logvol, + since_update, false, nothing) return sample, state end -function step(rng, model, sampler, state; kwargs...) +function step(rng::Random.AbstractRNG, model, sampler::Nested, state::NestedState; kwargs...) ## Update bounds pointvol = exp(state.logvol) / sampler.nactive # check if ready for first update @@ -102,23 +127,24 @@ function step(rng, model, sampler, state; kwargs...) logvol = state.logvol - 1 / sampler.nactive ## prepare returns - sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead) - state = (it = it, ncall = ncall, us = state.us, vs = state.vs, logl = state.logl, logl_dead = logl_dead, - logz = logz, logzerr = logzerr, h = h, logvol = logvol, - since_update = since_update, has_bounds = has_bounds, active_bound = active_bound) + sample = NestedTransition(u_dead, v_dead, logwt, logl_dead) + state = NestedState(it, ncall, sate.us, state.vs, state.logl, logl_dead, + logz, logzerr, h, logvol, + since_update, has_bounds, active_bound) return sample, state end -function bundle_samples(samples, - model::AbstractModel, - sampler::Nested, - state, - ::Type{Chains}; - add_live=true, - param_names=missing, - check_wsum=true, - kwargs...) +function bundle_samples( + samples, + model::AbstractModel, + sampler::Nested, + state::NestedState, + ::Type{Chains}; + add_live=true, + param_names=missing, + check_wsum=true, + kwargs...) if add_live samples, state = add_live_points(samples, model, sampler, state) From 6380e7e27827871dc6b5d22d1c31301ff411b1e7 Mon Sep 17 00:00:00 2001 From: jaimerz Date: Fri, 15 Sep 2023 16:28:19 +0100 Subject: [PATCH 2/9] init_particles refactor --- src/NestedSamplers.jl | 1 + src/step.jl | 43 ++++++++++++++++++++++++++++++++----------- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/src/NestedSamplers.jl b/src/NestedSamplers.jl index 1c631fa8..829619ca 100644 --- a/src/NestedSamplers.jl +++ b/src/NestedSamplers.jl @@ -6,6 +6,7 @@ using Random: AbstractRNG, GLOBAL_RNG using AbstractMCMC using AbstractMCMC: AbstractSampler, + LogDensityModel, AbstractModel, samples, save!! diff --git a/src/step.jl b/src/step.jl index 856e3197..4c30612a 100644 --- a/src/step.jl +++ b/src/step.jl @@ -1,4 +1,3 @@ - #TODO it would be cool to make these parametric struct NestedState it::Int @@ -24,10 +23,20 @@ struct NestedTransition logl end -function step(rng::Random.AbstractRNG, model, sampler::Nested; kwargs...) +function step( + rng::Random.AbstractRNG, + model, + sampler::Nested; + init_params = nothing, + kwargs...) # Initialize particles # us are in unit space, vs are in prior space - us, vs, logl = init_particles(rng, model, sampler) + if init_params == nothing + us = rand(rng, T, sampler.ndims, sampler.nactive) + else + us = init_params + end + vs, logl = init_particles(rng, us, model) # Find least likely point logl_dead, idx_dead = findmin(logl) @@ -57,14 +66,20 @@ function step(rng::Random.AbstractRNG, model, sampler::Nested; kwargs...) logvol -= 1 / sampler.nactive sample = NestedTransition(u, u_dead, v_dead, logwt, logl_dead) - state = NestedState(1, ncall, us, vs, logl, logl_dead, + state = NestedState( + 1, ncall, us, vs, logl, logl_dead, logz, logzerr, h, logvol, since_update, false, nothing) return sample, state end -function step(rng::Random.AbstractRNG, model, sampler::Nested, state::NestedState; kwargs...) +function step( + rng::Random.AbstractRNG, + model, + sampler::Nested, + state::NestedState; + kwargs...) ## Update bounds pointvol = exp(state.logvol) / sampler.nactive # check if ready for first update @@ -128,7 +143,8 @@ function step(rng::Random.AbstractRNG, model, sampler::Nested, state::NestedStat ## prepare returns sample = NestedTransition(u_dead, v_dead, logwt, logl_dead) - state = NestedState(it, ncall, sate.us, state.vs, state.logl, logl_dead, + state = NestedState( + it, ncall, sate.us, state.vs, state.logl, logl_dead, logz, logzerr, h, logvol, since_update, has_bounds, active_bound) @@ -192,17 +208,23 @@ function bundle_samples(samples, end ## Helpers +init_particles(rng::Random.AbstractRNG, T::Type, ndims::Int, nactive::Int, model) = + init_particles(rng, rand(rng, T, ndims, nactive), model) -init_particles(rng, ndims, nactive, model) = +init_particles(rng::Random.AbstractRNG, ndims::Int, nactive::Int, model) = init_particles(rng, Float64, ndims, nactive, model) -init_particles(rng, model, sampler) = +init_particles(rng::Random.AbstractRNG, model, sampler::Nested) = init_particles(rng, sampler.ndims, sampler.nactive, model) # loop and fill arrays, checking validity of points # will retry 100 times before erroring -function init_particles(rng, T, ndims, nactive, model) - us = rand(rng, T, ndims, nactive) +function init_particles( + rng::Random.AbstractRNG, + us::Matrix, + model + ) + vs_and_logl = mapslices( Base.Fix1(prior_transform_and_loglikelihood, model), us; dims=1 @@ -230,7 +252,6 @@ function init_particles(rng, T, ndims, nactive, model) return us, vs, logl end - # add remaining live points to `samples` function add_live_points(samples, model, sampler, state) logvol = -state.it / sampler.nactive - log(sampler.nactive) From 2f0a88bf1f0e6180c2c4311f9ed9c56e374366de Mon Sep 17 00:00:00 2001 From: jaimerz Date: Fri, 22 Sep 2023 15:28:04 +0100 Subject: [PATCH 3/9] init_particles multidispatch --- src/step.jl | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/src/step.jl b/src/step.jl index 4c30612a..7c6450aa 100644 --- a/src/step.jl +++ b/src/step.jl @@ -32,7 +32,7 @@ function step( # Initialize particles # us are in unit space, vs are in prior space if init_params == nothing - us = rand(rng, T, sampler.ndims, sampler.nactive) + us = rand(rng, Float64, sampler.ndims, sampler.nactive) else us = init_params end @@ -222,7 +222,40 @@ init_particles(rng::Random.AbstractRNG, model, sampler::Nested) = function init_particles( rng::Random.AbstractRNG, us::Matrix, - model + model::AbstractModel + ) + + vs_and_logl = mapslices( + Base.Fix1(prior_transform_and_loglikelihood, model), us; + dims=1 + ) + vs = mapreduce(first, hcat, vs_and_logl) + logl = dropdims(map(last, vs_and_logl), dims=1) + + ntries = 1 + while true + any(isfinite, logl) && break + rand!(rng, us) + vs_and_logl .= mapslices( + Base.Fix1(prior_transform_and_loglikelihood, model), us; + dims=1 + ) + vs .= mapreduce(first, hcat, vs_and_logl) + map!(last, logl, vs_and_logl) + ntries += 1 + ntries > 100 && error("After 100 attempts, could not initialize any live points with finite loglikelihood. Please check your prior transform and loglikelihood methods.") + end + + # force -Inf to be a finite but small number to keep estimators from breaking + @. logl[logl == -Inf] = -1e300 + + return us, vs, logl +end + +function init_particles( + rng::Random.AbstractRNG, + us::Matrix, + model::LogDensityModel ) vs_and_logl = mapslices( From e0dc149e636179a0399cddf932cb2fd113dbb2ed Mon Sep 17 00:00:00 2001 From: jaimerz Date: Fri, 22 Sep 2023 15:58:24 +0100 Subject: [PATCH 4/9] tests nearly working --- src/step.jl | 45 +++++++++++++-------------------------------- 1 file changed, 13 insertions(+), 32 deletions(-) diff --git a/src/step.jl b/src/step.jl index 7c6450aa..6bf34d74 100644 --- a/src/step.jl +++ b/src/step.jl @@ -12,7 +12,7 @@ struct NestedState logvol since_update::Int has_bounds::Bool - active_bound::Bool + active_bound end struct NestedTransition @@ -31,12 +31,7 @@ function step( kwargs...) # Initialize particles # us are in unit space, vs are in prior space - if init_params == nothing - us = rand(rng, Float64, sampler.ndims, sampler.nactive) - else - us = init_params - end - vs, logl = init_particles(rng, us, model) + us, vs, logl = init_particles(rng, model, sampler) # Find least likely point logl_dead, idx_dead = findmin(logl) @@ -66,10 +61,8 @@ function step( logvol -= 1 / sampler.nactive sample = NestedTransition(u, u_dead, v_dead, logwt, logl_dead) - state = NestedState( - 1, ncall, us, vs, logl, logl_dead, - logz, logzerr, h, logvol, - since_update, false, nothing) + state = NestedState(1, ncall, us, vs, logl, logl_dead, + logz, logzerr, h, logvol, since_update, false, nothing) return sample, state end @@ -142,11 +135,9 @@ function step( logvol = state.logvol - 1 / sampler.nactive ## prepare returns - sample = NestedTransition(u_dead, v_dead, logwt, logl_dead) - state = NestedState( - it, ncall, sate.us, state.vs, state.logl, logl_dead, - logz, logzerr, h, logvol, - since_update, has_bounds, active_bound) + sample = NestedTransition(u, u_dead, v_dead, logwt, logl_dead) + state = NestedState(it, ncall, state.us, state.vs, state.logl, logl_dead, + logz, logzerr, h, logvol, since_update, has_bounds, active_bound) return sample, state end @@ -208,23 +199,17 @@ function bundle_samples(samples, end ## Helpers -init_particles(rng::Random.AbstractRNG, T::Type, ndims::Int, nactive::Int, model) = - init_particles(rng, rand(rng, T, ndims, nactive), model) -init_particles(rng::Random.AbstractRNG, ndims::Int, nactive::Int, model) = +init_particles(rng, ndims, nactive, model) = init_particles(rng, Float64, ndims, nactive, model) -init_particles(rng::Random.AbstractRNG, model, sampler::Nested) = +init_particles(rng, model, sampler) = init_particles(rng, sampler.ndims, sampler.nactive, model) # loop and fill arrays, checking validity of points # will retry 100 times before erroring -function init_particles( - rng::Random.AbstractRNG, - us::Matrix, - model::AbstractModel - ) - +function init_particles(rng::Random.AbstractRNG, T::Type, ndims::Int, nactive::Int, model::AbstractModel) + us = rand(rng, T, ndims, nactive) vs_and_logl = mapslices( Base.Fix1(prior_transform_and_loglikelihood, model), us; dims=1 @@ -252,12 +237,8 @@ function init_particles( return us, vs, logl end -function init_particles( - rng::Random.AbstractRNG, - us::Matrix, - model::LogDensityModel - ) - +function init_particles(rng::Random.AbstractRNG, T::Type, ndims::Int, nactive::Int, model::AbstractModel) + us = rand(rng, T, ndims, nactive) vs_and_logl = mapslices( Base.Fix1(prior_transform_and_loglikelihood, model), us; dims=1 From 1fcc099e28181235d0d41195e7c548a2a2246d50 Mon Sep 17 00:00:00 2001 From: jaimerz Date: Fri, 22 Sep 2023 16:19:02 +0100 Subject: [PATCH 5/9] bundle sample --- src/step.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/step.jl b/src/step.jl index 6bf34d74..a5e890c0 100644 --- a/src/step.jl +++ b/src/step.jl @@ -143,7 +143,6 @@ function step( end function bundle_samples( - samples, model::AbstractModel, sampler::Nested, state::NestedState, From abeb8744b86a0d724d8d16dc09b573d9334d98c9 Mon Sep 17 00:00:00 2001 From: jaimerz Date: Fri, 22 Sep 2023 16:21:43 +0100 Subject: [PATCH 6/9] bundle sample --- src/step.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/step.jl b/src/step.jl index a5e890c0..b3ea8fd7 100644 --- a/src/step.jl +++ b/src/step.jl @@ -142,15 +142,15 @@ function step( return sample, state end -function bundle_samples( - model::AbstractModel, - sampler::Nested, - state::NestedState, - ::Type{Chains}; - add_live=true, - param_names=missing, - check_wsum=true, - kwargs...) +function bundle_samples(samples, + model::AbstractModel, + sampler::Nested, + state, + ::Type{Chains}; + add_live=true, + param_names=missing, + check_wsum=true, + kwargs...) if add_live samples, state = add_live_points(samples, model, sampler, state) From b031b57599ef4666ebab88cfb9de8e88c12d9ebb Mon Sep 17 00:00:00 2001 From: jaimerz Date: Sun, 24 Sep 2023 11:02:06 +0100 Subject: [PATCH 7/9] no theta --- src/step.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/step.jl b/src/step.jl index b3ea8fd7..f54c4b8b 100644 --- a/src/step.jl +++ b/src/step.jl @@ -16,7 +16,7 @@ struct NestedState end struct NestedTransition - θ + #θ u v logwt @@ -60,7 +60,7 @@ function step( logzerr = sqrt(h / sampler.nactive) logvol -= 1 / sampler.nactive - sample = NestedTransition(u, u_dead, v_dead, logwt, logl_dead) + sample = NestedTransition(u_dead, v_dead, logwt, logl_dead) state = NestedState(1, ncall, us, vs, logl, logl_dead, logz, logzerr, h, logvol, since_update, false, nothing) @@ -135,7 +135,7 @@ function step( logvol = state.logvol - 1 / sampler.nactive ## prepare returns - sample = NestedTransition(u, u_dead, v_dead, logwt, logl_dead) + sample = NestedTransition(u_dead, v_dead, logwt, logl_dead) state = NestedState(it, ncall, state.us, state.vs, state.logl, logl_dead, logz, logzerr, h, logvol, since_update, has_bounds, active_bound) @@ -145,7 +145,7 @@ end function bundle_samples(samples, model::AbstractModel, sampler::Nested, - state, + state::NestedState, ::Type{Chains}; add_live=true, param_names=missing, @@ -175,7 +175,7 @@ end function bundle_samples(samples, model::AbstractModel, sampler::Nested, - state, + state::NestedState, ::Type{Array}; add_live=true, check_wsum=true, @@ -236,7 +236,7 @@ function init_particles(rng::Random.AbstractRNG, T::Type, ndims::Int, nactive::I return us, vs, logl end -function init_particles(rng::Random.AbstractRNG, T::Type, ndims::Int, nactive::Int, model::AbstractModel) +function init_particles(rng::Random.AbstractRNG, T::Type, ndims::Int, nactive::Int, model::LogDensityModel) us = rand(rng, T, ndims, nactive) vs_and_logl = mapslices( Base.Fix1(prior_transform_and_loglikelihood, model), us; From 338e6485ce8a03f9ef9fa1dd5945ae8074d21eac Mon Sep 17 00:00:00 2001 From: jaimerz Date: Sun, 24 Sep 2023 11:32:46 +0100 Subject: [PATCH 8/9] init params --- src/staticsampler.jl | 6 +++++- src/step.jl | 17 ++++++++++------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/staticsampler.jl b/src/staticsampler.jl index 1c7be396..f9e2384e 100644 --- a/src/staticsampler.jl +++ b/src/staticsampler.jl @@ -1,6 +1,7 @@ # Sampler and model implementations struct Nested{B, P <: AbstractProposal} <: AbstractSampler + type::Type ndims::Int nactive::Int bounds::B @@ -53,6 +54,7 @@ function Nested(ndims, enlarge = 1.25, min_ncall=2nactive, min_eff=0.10, + type = Float64, kwargs...) nactive < 2ndims && @warn "Using fewer than 2ndim ($(2ndims)) active points is discouraged" @@ -72,7 +74,9 @@ function Nested(ndims, update_interval_frac = get(kwargs, :update_interval, default_update_interval(proposal, ndims)) update_interval = round(Int, update_interval_frac * nactive) - return Nested(ndims, + return Nested( + type, + ndims, nactive, bounds, enlarge, diff --git a/src/step.jl b/src/step.jl index f54c4b8b..3f20d962 100644 --- a/src/step.jl +++ b/src/step.jl @@ -31,7 +31,12 @@ function step( kwargs...) # Initialize particles # us are in unit space, vs are in prior space - us, vs, logl = init_particles(rng, model, sampler) + if init_params != nothing && size(init_params) == (sampler.ndims,sampler.nactive) + us = init_params + else + us = rand(rng, sampler.type, sampler.ndims, sampler.nactive) + end + us, vs, logl = init_particles(rng, us, model) # Find least likely point logl_dead, idx_dead = findmin(logl) @@ -60,7 +65,7 @@ function step( logzerr = sqrt(h / sampler.nactive) logvol -= 1 / sampler.nactive - sample = NestedTransition(u_dead, v_dead, logwt, logl_dead) + sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead) #NestedTransition(u_dead, v_dead, logwt, logl_dead) state = NestedState(1, ncall, us, vs, logl, logl_dead, logz, logzerr, h, logvol, since_update, false, nothing) @@ -135,7 +140,7 @@ function step( logvol = state.logvol - 1 / sampler.nactive ## prepare returns - sample = NestedTransition(u_dead, v_dead, logwt, logl_dead) + sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead) #NestedTransition(u_dead, v_dead, logwt, logl_dead) state = NestedState(it, ncall, state.us, state.vs, state.logl, logl_dead, logz, logzerr, h, logvol, since_update, has_bounds, active_bound) @@ -207,8 +212,7 @@ init_particles(rng, model, sampler) = # loop and fill arrays, checking validity of points # will retry 100 times before erroring -function init_particles(rng::Random.AbstractRNG, T::Type, ndims::Int, nactive::Int, model::AbstractModel) - us = rand(rng, T, ndims, nactive) +function init_particles(rng::Random.AbstractRNG, us::Matrix, model::AbstractModel) vs_and_logl = mapslices( Base.Fix1(prior_transform_and_loglikelihood, model), us; dims=1 @@ -236,8 +240,7 @@ function init_particles(rng::Random.AbstractRNG, T::Type, ndims::Int, nactive::I return us, vs, logl end -function init_particles(rng::Random.AbstractRNG, T::Type, ndims::Int, nactive::Int, model::LogDensityModel) - us = rand(rng, T, ndims, nactive) +function init_particles(rng::Random.AbstractRNG, us::Matrix, model::LogDensityModel) vs_and_logl = mapslices( Base.Fix1(prior_transform_and_loglikelihood, model), us; dims=1 From 3a55cc5b3add28492cefc5d497ff2231f7ec837d Mon Sep 17 00:00:00 2001 From: jaimerz Date: Sun, 24 Sep 2023 11:53:24 +0100 Subject: [PATCH 9/9] init particles for logdensitymodel --- src/step.jl | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/src/step.jl b/src/step.jl index 3f20d962..e0d4c8d7 100644 --- a/src/step.jl +++ b/src/step.jl @@ -34,6 +34,7 @@ function step( if init_params != nothing && size(init_params) == (sampler.ndims,sampler.nactive) us = init_params else + @warn "Init params have wrong dimensions, sampling new ones!" us = rand(rng, sampler.type, sampler.ndims, sampler.nactive) end us, vs, logl = init_particles(rng, us, model) @@ -65,7 +66,8 @@ function step( logzerr = sqrt(h / sampler.nactive) logvol -= 1 / sampler.nactive - sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead) #NestedTransition(u_dead, v_dead, logwt, logl_dead) + sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead) + #sample = NestedTransition(u_dead, v_dead, logwt, logl_dead) state = NestedState(1, ncall, us, vs, logl, logl_dead, logz, logzerr, h, logvol, since_update, false, nothing) @@ -140,7 +142,8 @@ function step( logvol = state.logvol - 1 / sampler.nactive ## prepare returns - sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead) #NestedTransition(u_dead, v_dead, logwt, logl_dead) + sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead) + #sample = NestedTransition(u_dead, v_dead, logwt, logl_dead) state = NestedState(it, ncall, state.us, state.vs, state.logl, logl_dead, logz, logzerr, h, logvol, since_update, has_bounds, active_bound) @@ -241,31 +244,23 @@ function init_particles(rng::Random.AbstractRNG, us::Matrix, model::AbstractMode end function init_particles(rng::Random.AbstractRNG, us::Matrix, model::LogDensityModel) - vs_and_logl = mapslices( - Base.Fix1(prior_transform_and_loglikelihood, model), us; - dims=1 - ) - vs = mapreduce(first, hcat, vs_and_logl) - logl = dropdims(map(last, vs_and_logl), dims=1) + logdensity = model.logdensity + dims, nactive = size(us) + logls = [logdensity(us[:, i]) for i in 1:nactive] ntries = 1 while true - any(isfinite, logl) && break + any(isfinite, logls) && break rand!(rng, us) - vs_and_logl .= mapslices( - Base.Fix1(prior_transform_and_loglikelihood, model), us; - dims=1 - ) - vs .= mapreduce(first, hcat, vs_and_logl) - map!(last, logl, vs_and_logl) + logls = [logdensity(us[:, i]) for i in 1:nactive] ntries += 1 ntries > 100 && error("After 100 attempts, could not initialize any live points with finite loglikelihood. Please check your prior transform and loglikelihood methods.") end # force -Inf to be a finite but small number to keep estimators from breaking - @. logl[logl == -Inf] = -1e300 + @. logls[logls == -Inf] = -1e300 - return us, vs, logl + return us, us, logls end # add remaining live points to `samples`