Skip to content

Simulation test agent#103

Open
corentinravoux wants to merge 12 commits intomainfrom
simulation_test_agent
Open

Simulation test agent#103
corentinravoux wants to merge 12 commits intomainfrom
simulation_test_agent

Conversation

@corentinravoux
Copy link
Copy Markdown
Owner

Start of the work on the forward modeling

Copilot AI and others added 12 commits March 11, 2026 15:01
…r.py

Co-authored-by: corentinravoux <52483673+corentinravoux@users.noreply.github.com>
…ception handling

Co-authored-by: corentinravoux <52483673+corentinravoux@users.noreply.github.com>
…sting

Co-authored-by: corentinravoux <52483673+corentinravoux@users.noreply.github.com>
Co-authored-by: corentinravoux <52483673+corentinravoux@users.noreply.github.com>
…gma8 and fixed_cosmo_params

Co-authored-by: corentinravoux <52483673+corentinravoux@users.noreply.github.com>
…diffrax, jaxopt)

Co-authored-by: corentinravoux <52483673+corentinravoux@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…est.importorskip for simulation tests

Co-authored-by: corentinravoux <52483673+corentinravoux@users.noreply.github.com>
…ing-scheme

Add fsigma8 velocity fit notebook, simulation tests, and make simulation deps optional
Copilot AI review requested due to automatic review settings March 13, 2026 16:23
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new flip.simulation subpackage that enables differentiable forward-modeling of peculiar-velocity data using JaxPM (LPT or N-body via diffrax), plus a small optimizer wrapper to fit cosmological parameters.

Changes:

  • Introduces JAX-based simulation field generation (generate.py) and a Gaussian velocity-field likelihood (likelihood.py).
  • Adds a SimulationFitter wrapper around jaxopt solvers for gradient-based parameter fits.
  • Adds a new simulation optional-dependency extra, a demonstration notebook, and a dedicated test suite (skipped when JAX stack isn’t installed).

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
flip/simulation/generate.py Implements LPT/N-body field generation, interpolation helpers, and fσ8 utilities.
flip/simulation/likelihood.py Adds a JAX Gaussian likelihood comparing simulated vs observed LOS velocities.
flip/simulation/fitter.py Adds a small jaxopt-based optimizer wrapper for the simulation likelihood.
flip/simulation/__init__.py Exposes the simulation submodules with optional-dependency import handling.
pyproject.toml Adds a simulation extra for JAX/JaxPM/JaxOpt/Diffrax dependencies.
test/test_simulation.py Adds unit/smoke tests for generate/likelihood/fitter covering LPT and N-body paths.
notebook/fit_simulation_velocity.ipynb Adds an end-to-end example notebook for fitting σ8 / fσ₈.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +599 to +619
def compute_fsigma8(cosmo, a=1.0):
"""Compute the linear growth parameter f*sigma_8 from a cosmology.

Evaluates :math:`f\\sigma_8 = f(a) \\cdot \\sigma_8`, where :math:`f` is
the logarithmic growth rate :math:`f = d\\ln D / d\\ln a` and
:math:`\\sigma_8` is the RMS matter fluctuation amplitude at
8 Mpc/h.

Args:
cosmo (jax_cosmo.Cosmology): Cosmological parameters. Create with
:func:`get_cosmology`.
a (float): Scale factor at which to evaluate the growth rate.
Default 1.0 (z=0).

Returns:
float: :math:`f \\cdot \\sigma_8` dimensionless growth parameter.
"""
_require_jaxpm("compute_fsigma8")
a_arr = jnp.atleast_1d(a)
f = jc.background.growth_rate(cosmo, a_arr)[0]
return f * cosmo.sigma8
Comment on lines +622 to +639
def radec_to_cartesian(ra, dec, r_com):
"""Convert spherical sky coordinates to Cartesian positions.

Args:
ra (array-like): Right ascension in degrees.
dec (array-like): Declination in degrees.
r_com (array-like): Comoving distance in Mpc/h.

Returns:
jnp.ndarray: Cartesian positions in Mpc/h, shape ``(N, 3)``.
"""
_require_jaxpm("radec_to_cartesian")
ra_rad = jnp.deg2rad(jnp.asarray(ra))
dec_rad = jnp.deg2rad(jnp.asarray(dec))
r = jnp.asarray(r_com)
x = r * jnp.cos(dec_rad) * jnp.cos(ra_rad)
y = r * jnp.cos(dec_rad) * jnp.sin(ra_rad)
z = r * jnp.sin(dec_rad)
Comment on lines +15 to +22
from . import fitter, generate, likelihood
except ImportError as e:
log.add(
f"Could not import flip.simulation modules ({e}). "
"Install the optional dependencies with: "
"pip install jaxpm jaxopt jax_cosmo diffrax",
level="warning",
)
Comment on lines +81 to +85
chol = jsc.linalg.cho_factor(observed_variance)
logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(chol[0])))
chi2 = jnp.dot(residual, jsc.linalg.cho_solve(chol, residual))
n = residual.size
return -0.5 * (chi2 + logdet + n * jnp.log(2.0 * jnp.pi))
Comment on lines +43 to +50
simulation = [
"jax",
"jaxlib",
"jaxpm",
"jaxopt",
"jax_cosmo",
"diffrax",
]
"id": "code-install",
"metadata": {},
"outputs": [],
"source": "%%capture\n!pip install git+https://github.com/corentinravoux/flip \"jaxpm>=0.1\" \"diffrax>=0.5\" jax_cosmo jaxopt"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants