forked from blackjax-devs/blackjax
Add convergence helper and error tracking to nested sampling #41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
williamjameshandley wants to merge 1 commit into nested_sampling from ns-error-tracking-and-convergence
Changes from all commits
@@ -66,6 +66,14 @@ class NSState(NamedTuple):
         The accumulated log evidence estimate from the "dead" points .
     logZ_live
         The current estimate of the log evidence contribution from the live points.
+    logZ_error
+        The current estimate of the error on logZ.
+    H
+        The current estimate of the information (negative entropy) in nats.
+    i_eff
+        The effective number of iterations
+    n_eff
+        The effective sample size of the current set of live particles.
     inner_kernel_params
         A dictionary of parameters for the inner kernel.
     """
@@ -78,6 +86,10 @@ class NSState(NamedTuple):
     logX: Array # The current log-volume estimate
     logZ: Array # The accumulated evidence estimate
     logZ_live: Array # The current evidence estimate
+    logZ_error: Array # The current error estimate on logZ
+    H: Array # The current information estimate
+    i_eff: Array # The effective number of iterations
+    n_eff: Array # The effective sample size of the current particles
     inner_kernel_params: Dict # Parameters for the inner kernel
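Because the diagnostics now live on the state, downstream code can report them with no extra machinery. A small illustrative helper, not part of this PR, that only assumes the field names added above:

```python
def summarize(state) -> str:
    """Format the diagnostics carried on an NSState (logZ, logZ_error, H, i_eff, n_eff)."""
    return (
        f"log Z = {float(state.logZ):.3f} +/- {float(state.logZ_error):.3f}, "
        f"H = {float(state.H):.3f} nats, "
        f"i_eff = {float(state.i_eff):.1f}, n_eff = {float(state.n_eff):.1f}"
    )
```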
@@ -210,6 +222,10 @@ def init(
     loglikelihood_birth: Array = -jnp.nan,
     logX: Optional[Array] = 0.0,
     logZ: Optional[Array] = -jnp.inf,
+    logZ_error: Optional[Array] = 0.0,
+    H: Optional[Array] = 0.0,
+    i_eff: Optional[Array] = 0.0,
+    n_eff: Optional[Array] = 0.0,
 ) -> NSState:
     """Initializes the Nested Sampler state.
@@ -244,6 +260,11 @@ def init(
     logX = jnp.array(logX, dtype=dtype)
     logZ = jnp.array(logZ, dtype=dtype)
     logZ_live = logmeanexp(loglikelihood) + logX
+    logZ_error = jnp.array(logZ_error, dtype=dtype)
+    H = jnp.array(H, dtype=dtype)
+    i_eff = jnp.array(i_eff, dtype=dtype)
+    n_eff = jnp.array(n_eff, dtype=dtype)
+    n_eff = jnp.where(n_eff == 0.0, len(loglikelihood), n_eff)
     inner_kernel_params: Dict = {}
     return NSState(
         particles,
@@ -254,6 +275,10 @@ def init(
         logX,
         logZ,
         logZ_live,
+        logZ_error,
+        H,
+        i_eff,
+        n_eff,
         inner_kernel_params,
     )
@@ -350,8 +375,8 @@ def kernel(rng_key: PRNGKey, state: NSState) -> tuple[NSState, NSInfo]:
     pid = state.pid.at[target_update_idx].set(state.pid[start_idx])

     # Update the run-time information
-    logX, logZ, logZ_live = update_ns_runtime_info(
-        state.logX, state.logZ, loglikelihood, dead_loglikelihood
+    logX, logZ, logZ_live, logZ_error, H, i_eff, n_eff = update_ns_runtime_info(
+        state, loglikelihood, dead_loglikelihood
     )

     # Return updated state and info
@@ -364,6 +389,10 @@ def kernel(rng_key: PRNGKey, state: NSState) -> tuple[NSState, NSInfo]:
         logX,
         logZ,
         logZ_live,
+        logZ_error,
+        H,
+        i_eff,
+        n_eff,
         state.inner_kernel_params,
     )
     info = NSInfo(
@@ -428,20 +457,29 @@ def delete_fn(
 def update_ns_runtime_info(
-    logX: Array, logZ: Array, loglikelihood: Array, dead_loglikelihood: Array
-) -> tuple[Array, Array, Array]:
+    state: NSState, loglikelihood: Array, dead_loglikelihood: Array
+) -> tuple[Array, Array, Array, Array, Array, Array, Array]:
     num_particles = len(loglikelihood)
     num_deleted = len(dead_loglikelihood)
     num_live = jnp.arange(num_particles, num_particles - num_deleted, -1)
     delta_logX = -1 / num_live
-    logX = logX + jnp.cumsum(delta_logX)
-    log_delta_X = logX + jnp.log(1 - jnp.exp(delta_logX))
+    logX = state.logX + jnp.cumsum(delta_logX)
+    log_delta_X = logX + jnp.log1p(-jnp.exp(delta_logX))
     log_delta_Z = dead_loglikelihood + log_delta_X

     delta_logZ = logsumexp(log_delta_Z)
-    logZ = jnp.logaddexp(logZ, delta_logZ)
+    logZ = jnp.logaddexp(state.logZ, delta_logZ)
+    A = state.i_eff / state.n_eff + jnp.sum(1 / num_live)
+    B = state.i_eff / state.n_eff**2 + jnp.sum(1 / num_live**2)
+    i_eff = A**2 / B
+    n_eff = A / B
+    H = jnp.nan_to_num(jnp.exp(state.logZ - logZ) * (state.H + state.logZ), 0.0)
+    H += jnp.sum(jnp.exp(log_delta_Z - logZ) * dead_loglikelihood) - logZ
+    logZ_error = jnp.sqrt(H / n_eff)
     logZ_live = logmeanexp(loglikelihood) + logX[-1]
-    return logX[-1], logZ, logZ_live
+    return logX[-1], logZ, logZ_live, logZ_error, H, i_eff, n_eff

+# H ~ Σ 1/n ± sqrt(Σ 1/n^2)

 def logmeanexp(x: Array) -> Array:

Collaborator (inline comment on this hunk):
If this has to be all calculated online, would it be useful to put any of the calculation logic into the utils, in case they are useful for various post hoc analysis bits?
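One way to read the bookkeeping in `update_ns_runtime_info` (a sketch of standard Skilling-style nested-sampling error analysis mapped onto the diff, not text taken from the PR): each deletion with $n_k$ live points compresses the log prior volume by $1/n_k$ on average, with variance $1/n_k^2$, so the accumulated compression, which is roughly the information $H$ once the posterior bulk has been crossed, carries an uncertainty

$$
-\log X \;\approx\; \sum_k \frac{1}{n_k} \;\pm\; \sqrt{\sum_k \frac{1}{n_k^2}},
$$

which is what the `# H ~ Σ 1/n ± sqrt(Σ 1/n^2)` comment records. The pair $(i_\mathrm{eff}, n_\mathrm{eff})$ is the moment-matched constant-$n$ equivalent of the run so far:

$$
A = \sum_k \frac{1}{n_k}, \qquad B = \sum_k \frac{1}{n_k^2}, \qquad
i_\mathrm{eff} = \frac{A^2}{B}, \qquad n_\mathrm{eff} = \frac{A}{B}
\quad\Longrightarrow\quad
\frac{i_\mathrm{eff}}{n_\mathrm{eff}} = A, \qquad \frac{i_\mathrm{eff}}{n_\mathrm{eff}^2} = B,
$$

so only $i_\mathrm{eff}$ and $n_\mathrm{eff}$ need to be carried forward rather than the full history of live-point counts. The information itself is accumulated from

$$
H \;=\; \frac{1}{Z}\int \mathcal{L}\,\ln\mathcal{L}\,dX \;-\; \ln Z
  \;\approx\; \sum_i \frac{\Delta Z_i}{Z}\,\ln\mathcal{L}_i \;-\; \ln Z,
$$

with the $\exp(\log Z_\mathrm{old} - \log Z_\mathrm{new})$ factor rescaling the previously accumulated sum to the new evidence, and the reported error is the classic estimate $\Delta\log Z \approx \sqrt{H/n_\mathrm{eff}}$, i.e. `logZ_error = jnp.sqrt(H / n_eff)`.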
Review comment:
Can any of this be calculated “offline”, i.e. in the utils rather than in the online function?
Reply:
It could all be calculated offline, but there is a virtue in accumulation here that we appreciated in the first place for logZ (this is not true for logZ_live, which we do not accumulate, and which could be calculated as part of `converged`). My main reason for doing this is to be able to construct more advanced convergence criteria that e.g. incorporate an evidence error, or a KL-based stopping criterion (which John Skilling prefers).
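The convergence helper referenced in the PR title is not visible in the portion of the diff loaded here, but the new state fields make the kind of criterion discussed above easy to sketch. The following is a hypothetical illustration only: the function name `converged`, the `logZ_tol` and `use_error` arguments, and the particular combination of terms are assumptions, not the PR's implementation.

```python
import jax
import jax.numpy as jnp


def converged(state, logZ_tol: float = 1e-3, use_error: bool = False) -> jax.Array:
    """Hypothetical stopping rule built from the new NSState fields.

    `state` is an NSState carrying the logZ, logZ_live and logZ_error fields
    added in this PR.
    """
    # How much logZ could still move if the live points were consumed now.
    delta_logZ = jnp.logaddexp(state.logZ, state.logZ_live) - state.logZ
    done = delta_logZ < logZ_tol
    if use_error:
        # Optionally also require the remaining shift to be smaller than the
        # current statistical uncertainty on logZ.
        done = done & (delta_logZ < state.logZ_error)
    return done
```

A KL-based rule of the sort mentioned above could presumably be assembled from `state.H` and `state.n_eff` in the same way.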