Open
Conversation
…r.py Co-authored-by: corentinravoux <52483673+corentinravoux@users.noreply.github.com>
…ception handling Co-authored-by: corentinravoux <52483673+corentinravoux@users.noreply.github.com>
…sting Co-authored-by: corentinravoux <52483673+corentinravoux@users.noreply.github.com>
Co-authored-by: corentinravoux <52483673+corentinravoux@users.noreply.github.com>
…gma8 and fixed_cosmo_params Co-authored-by: corentinravoux <52483673+corentinravoux@users.noreply.github.com>
…diffrax, jaxopt) Co-authored-by: corentinravoux <52483673+corentinravoux@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…est.importorskip for simulation tests Co-authored-by: corentinravoux <52483673+corentinravoux@users.noreply.github.com>
…ing-scheme Add fsigma8 velocity fit notebook, simulation tests, and make simulation deps optional
There was a problem hiding this comment.
Pull request overview
Adds a new flip.simulation subpackage that enables differentiable forward-modeling of peculiar-velocity data using JaxPM (LPT or N-body via diffrax), plus a small optimizer wrapper to fit cosmological parameters.
Changes:
- Introduces JAX-based simulation field generation (
generate.py) and a Gaussian velocity-field likelihood (likelihood.py). - Adds a
SimulationFitterwrapper aroundjaxoptsolvers for gradient-based parameter fits. - Adds a new
simulationoptional-dependency extra, a demonstration notebook, and a dedicated test suite (skipped when JAX stack isn’t installed).
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
flip/simulation/generate.py |
Implements LPT/N-body field generation, interpolation helpers, and fσ8 utilities. |
flip/simulation/likelihood.py |
Adds a JAX Gaussian likelihood comparing simulated vs observed LOS velocities. |
flip/simulation/fitter.py |
Adds a small jaxopt-based optimizer wrapper for the simulation likelihood. |
flip/simulation/__init__.py |
Exposes the simulation submodules with optional-dependency import handling. |
pyproject.toml |
Adds a simulation extra for JAX/JaxPM/JaxOpt/Diffrax dependencies. |
test/test_simulation.py |
Adds unit/smoke tests for generate/likelihood/fitter covering LPT and N-body paths. |
notebook/fit_simulation_velocity.ipynb |
Adds an end-to-end example notebook for fitting σ8 / fσ₈. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+599
to
+619
| def compute_fsigma8(cosmo, a=1.0): | ||
| """Compute the linear growth parameter f*sigma_8 from a cosmology. | ||
|
|
||
| Evaluates :math:`f\\sigma_8 = f(a) \\cdot \\sigma_8`, where :math:`f` is | ||
| the logarithmic growth rate :math:`f = d\\ln D / d\\ln a` and | ||
| :math:`\\sigma_8` is the RMS matter fluctuation amplitude at | ||
| 8 Mpc/h. | ||
|
|
||
| Args: | ||
| cosmo (jax_cosmo.Cosmology): Cosmological parameters. Create with | ||
| :func:`get_cosmology`. | ||
| a (float): Scale factor at which to evaluate the growth rate. | ||
| Default 1.0 (z=0). | ||
|
|
||
| Returns: | ||
| float: :math:`f \\cdot \\sigma_8` dimensionless growth parameter. | ||
| """ | ||
| _require_jaxpm("compute_fsigma8") | ||
| a_arr = jnp.atleast_1d(a) | ||
| f = jc.background.growth_rate(cosmo, a_arr)[0] | ||
| return f * cosmo.sigma8 |
Comment on lines
+622
to
+639
| def radec_to_cartesian(ra, dec, r_com): | ||
| """Convert spherical sky coordinates to Cartesian positions. | ||
|
|
||
| Args: | ||
| ra (array-like): Right ascension in degrees. | ||
| dec (array-like): Declination in degrees. | ||
| r_com (array-like): Comoving distance in Mpc/h. | ||
|
|
||
| Returns: | ||
| jnp.ndarray: Cartesian positions in Mpc/h, shape ``(N, 3)``. | ||
| """ | ||
| _require_jaxpm("radec_to_cartesian") | ||
| ra_rad = jnp.deg2rad(jnp.asarray(ra)) | ||
| dec_rad = jnp.deg2rad(jnp.asarray(dec)) | ||
| r = jnp.asarray(r_com) | ||
| x = r * jnp.cos(dec_rad) * jnp.cos(ra_rad) | ||
| y = r * jnp.cos(dec_rad) * jnp.sin(ra_rad) | ||
| z = r * jnp.sin(dec_rad) |
Comment on lines
+15
to
+22
| from . import fitter, generate, likelihood | ||
| except ImportError as e: | ||
| log.add( | ||
| f"Could not import flip.simulation modules ({e}). " | ||
| "Install the optional dependencies with: " | ||
| "pip install jaxpm jaxopt jax_cosmo diffrax", | ||
| level="warning", | ||
| ) |
Comment on lines
+81
to
+85
| chol = jsc.linalg.cho_factor(observed_variance) | ||
| logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(chol[0]))) | ||
| chi2 = jnp.dot(residual, jsc.linalg.cho_solve(chol, residual)) | ||
| n = residual.size | ||
| return -0.5 * (chi2 + logdet + n * jnp.log(2.0 * jnp.pi)) |
Comment on lines
+43
to
+50
| simulation = [ | ||
| "jax", | ||
| "jaxlib", | ||
| "jaxpm", | ||
| "jaxopt", | ||
| "jax_cosmo", | ||
| "diffrax", | ||
| ] |
| "id": "code-install", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": "%%capture\n!pip install git+https://github.com/corentinravoux/flip \"jaxpm>=0.1\" \"diffrax>=0.5\" jax_cosmo jaxopt" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Start of the work on the forward modeling