
Conversation

@williamjameshandley

Summary

This PR introduces two key improvements to the nested sampling implementation:

  1. Convergence helper function - Adds a converged() utility function (discussed with @zwei-beiner) that properly handles the "full plateau" case, where all live points have identical likelihoods. This addresses an edge case in which our zero-weight check in the delete function does not stop the algorithm from continuing; it merely spreads the selection more evenly, owing to jax.random.choice's surprising behaviour with zero/NaN weights.

  2. Error tracking and information metrics - Extends NSState to track:

    • logZ_error: Error estimate on the log evidence
    • H: Information (negative entropy) in nats
    • i_eff: Effective number of iterations
    • n_eff: Effective sample size
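The extended state can be pictured as a plain NamedTuple. This is a minimal sketch using illustrative float fields; the actual blackjax NSState holds JAX arrays and further fields (the logX, logZ, and logZ_live entries below are taken from the review comments):

```python
from typing import NamedTuple

class NSState(NamedTuple):
    """Hypothetical sketch of the extended nested-sampling state."""
    logX: float        # current log-volume estimate
    logZ: float        # accumulated evidence estimate
    logZ_live: float   # evidence contribution of the current live points
    logZ_error: float  # error estimate on the log evidence
    H: float           # information (negative entropy) in nats
    i_eff: float       # effective number of iterations
    n_eff: float       # effective sample size
```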

Changes

blackjax/ns/base.py

  • Added new fields to NSState: logZ_error, H, i_eff, n_eff
  • Updated init() to initialize these fields
  • Modified update_ns_runtime_info() to compute error estimates and information metrics
  • Refactored calculations for clarity (extracted common terms A and B)
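For context, the textbook nested-sampling uncertainty on the log evidence (due to Skilling) is sqrt(H / n_live). The PR may well compute logZ_error differently, so treat this as an illustrative assumption rather than the implemented formula:

```python
import math

def logZ_error_estimate(H: float, n_live: int) -> float:
    """Textbook nested-sampling uncertainty on logZ: sqrt(H / n_live).
    Illustrative only; not necessarily the formula used in this PR."""
    return math.sqrt(H / n_live)

# e.g. with H = 2 nats of information and 500 live points,
# the estimate is about 0.063
```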

blackjax/ns/utils.py

  • Added converged() function that checks:
    • If all live points have identical likelihood (full plateau case)
    • If live evidence contribution is below precision criterion
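A sketch of what such a check could look like. The function name comes from the PR, but the signature, the tolerance eps, and the exact form of both criteria here are assumptions:

```python
import math

def converged(loglikelihoods, logZ, logZ_live, eps=1e-3):
    """Hypothetical termination check combining the two criteria above."""
    # Full plateau: every live point has the same likelihood.
    full_plateau = max(loglikelihoods) == min(loglikelihoods)
    # Precision criterion: the live points' evidence contribution is
    # negligible relative to the accumulated evidence.
    below_precision = logZ_live - logZ < math.log(eps)
    return full_plateau or below_precision
```

For example, converged([-1.2, -1.2, -1.2], logZ=0.0, logZ_live=-2.0) returns True via the plateau branch.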

docs/examples/nested_sampling.py

  • Updated to use the new converged() utility function

Testing

The changes maintain backward compatibility and all existing tests pass. The convergence helper provides a cleaner termination condition that handles edge cases more robustly.

🤖 Generated with Claude Code

Co-Authored-By: Claude [email protected]

- Add logZ_error, H (information), i_eff, and n_eff fields to NSState
- Update runtime info calculation to compute these new metrics
- Add converged() utility function for checking termination criteria
- Update example to use new convergence function

Collaborator

@yallup left a comment


Very useful addition. I have some queries about bloating the state with extra info, but if it needs to live in the state then I'm happy to take this on.

logX: Array # The current log-volume estimate
logZ: Array # The accumulated evidence estimate
logZ_live: Array # The current evidence estimate
logZ_error: Array # The current error estimate on logZ
Collaborator
Can any of this be calculated “offline”, i.e. in the utils rather than in the online function?


It could all be calculated offline, but there is a virtue in accumulation here, which is what we appreciated in the first place for logZ. (This is not true for logZ_live, which we do not accumulate and which could be calculated as part of converged.) My main reason for doing this is to be able to construct more advanced convergence criteria that, for example, incorporate an evidence error or a KL-based stopping criterion (which John Skilling prefers).

delta_logZ = logsumexp(log_delta_Z)
logZ = jnp.logaddexp(logZ, delta_logZ)
logZ = jnp.logaddexp(state.logZ, delta_logZ)
A = state.i_eff / state.n_eff + jnp.sum(1 / num_live)
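The accumulation being discussed can be illustrated without JAX: folding each increment into logZ with a pairwise log-add-exp, as in the update step quoted above, gives the same result as one offline logsumexp over all increments. A minimal sketch with toy values:

```python
import math

def logaddexp(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

increments = [-3.0, -2.5, -4.0, -5.0]  # toy log-evidence increments

# Online: fold in one increment per iteration, as the sampler does.
logZ = -math.inf
for delta_logZ in increments:
    logZ = logaddexp(logZ, delta_logZ)

# Offline: a single logsumexp over all increments afterwards.
m = max(increments)
logZ_offline = m + math.log(sum(math.exp(d - m) for d in increments))

# The two agree to floating-point precision.
```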
Collaborator
If this all has to be calculated online, would it be useful to put some of the calculation logic into the utils anyway, in case it is useful for various post hoc analysis bits?
