diff --git a/flip/simulation/__init__.py b/flip/simulation/__init__.py index e69de29..afd32e1 100644 --- a/flip/simulation/__init__.py +++ b/flip/simulation/__init__.py @@ -0,0 +1,22 @@ +"""Init file of the flip.simulation package. + +This package provides tools for forward-model cosmological inference using +differentiable N-body simulations. The ``jaxpm``, ``jaxopt``, ``jax_cosmo``, +and ``diffrax`` packages are optional and can be installed with:: + + pip install jaxpm jaxopt jax_cosmo diffrax +""" + +from flip.utils import create_log + +log = create_log() + +try: + from . import fitter, generate, likelihood +except ImportError as e: + log.add( + f"Could not import flip.simulation modules ({e}). " + "Install the optional dependencies with: " + "pip install jaxpm jaxopt jax_cosmo diffrax", + level="warning", + ) diff --git a/flip/simulation/fitter.py b/flip/simulation/fitter.py new file mode 100644 index 0000000..ed71cfd --- /dev/null +++ b/flip/simulation/fitter.py @@ -0,0 +1,221 @@ +"""Minimization of the simulation likelihood using jaxopt. + +This module provides a :class:`SimulationFitter` that minimizes the negative +log-likelihood of a JaxPM forward simulation given observed peculiar velocity +data using gradient-based optimizers from the ``jaxopt`` library. + +All optimizers support automatic differentiation through the JAX computational +graph, enabling gradient-based optimization over cosmological parameters. + +Examples: + >>> from flip.simulation.fitter import SimulationFitter + >>> fitter = SimulationFitter( + ... likelihood=lik, + ... initial_params={"omega_m": 0.3, "sigma8": 0.8}, + ... solver="LBFGS", + ... maxiter=200, + ... ) + >>> best_params = fitter.run() + >>> print(best_params) +""" + +from flip.utils import create_log + +log = create_log() + +try: + import jax.numpy as jnp + import jaxopt + + jaxopt_installed = True + +except ImportError: + jaxopt_installed = False + log.add( + "Install jaxopt to use the SimulationFitter", + level="warning", + ) + +#: Mapping from solver name string to the corresponding jaxopt class. +if jaxopt_installed: + _AVAILABLE_SOLVERS = { + "LBFGS": jaxopt.LBFGS, + "LBFGSB": jaxopt.LBFGSB, + "BFGS": jaxopt.BFGS, + "GradientDescent": jaxopt.GradientDescent, + } +else: + _AVAILABLE_SOLVERS = {} + + +class SimulationFitter: + """Minimize the simulation likelihood over cosmological parameters. + + Uses gradient-based optimization from ``jaxopt`` to find the cosmological + parameter values that maximize the likelihood of the observed velocity + field under the JaxPM forward simulation. + + Parameters are represented internally as a flat JAX array during + optimization and converted to/from a parameter dictionary for the + likelihood interface. + + Args: + likelihood (callable): Callable that accepts a parameter dict and + returns a scalar loss value (negative log-likelihood). Typically + an instance of + :class:`~flip.simulation.likelihood.VelocityFieldLikelihood`. + initial_params (dict): Initial cosmological parameter values, e.g. + ``{"omega_m": 0.3, "sigma8": 0.8}``. All values must be + Python/NumPy floats (not JAX arrays) at construction time. + bounds (tuple[dict, dict] | None): Optional box constraints as + ``(lower_bounds_dict, upper_bounds_dict)`` where each dict has + the same keys as ``initial_params``. Only applied when + ``solver="LBFGSB"``. Default ``None`` (unconstrained). + solver (str): Name of the jaxopt solver to use. One of ``"LBFGS"`` + (default), ``"LBFGSB"``, ``"BFGS"``, or + ``"GradientDescent"``. + maxiter (int): Maximum number of optimizer iterations. Default 100. + solver_kwargs (dict | None): Additional keyword arguments forwarded + to the jaxopt solver constructor (e.g. ``tol``, ``stepsize``). + + Raises: + ValueError: If ``solver`` is not one of the supported names. + + Examples: + >>> fitter = SimulationFitter( + ... likelihood=lik, + ... initial_params={"omega_m": 0.3, "sigma8": 0.8}, + ... bounds=({"omega_m": 0.1, "sigma8": 0.3}, + ... {"omega_m": 0.9, "sigma8": 1.5}), + ... solver="LBFGSB", + ... maxiter=200, + ... ) + >>> best_params = fitter.run() + """ + + def __init__( + self, + likelihood, + initial_params, + bounds=None, + solver="LBFGS", + maxiter=100, + solver_kwargs=None, + ): + if not jaxopt_installed: + raise ImportError( + "'SimulationFitter' requires jaxopt. " + "Install it with: pip install jaxopt" + ) + if solver not in _AVAILABLE_SOLVERS: + raise ValueError( + f"Solver '{solver}' is not supported. " + f"Choose one of: {list(_AVAILABLE_SOLVERS.keys())}" + ) + self.likelihood = likelihood + self.initial_params = initial_params + self.bounds = bounds + self.solver_name = solver + self.maxiter = maxiter + self.solver_kwargs = solver_kwargs or {} + self._result = None + + # Keep an ordered list of parameter names for array conversion + self._param_names = list(initial_params.keys()) + + # ------------------------------------------------------------------ + # Internal helpers for dict <-> array conversion + # ------------------------------------------------------------------ + + def _to_array(self, params_dict): + """Convert parameter dict to a flat JAX array. + + Args: + params_dict (dict): Parameter name -> value mapping. + + Returns: + jnp.ndarray: 1-D array of shape ``(n_params,)``. + """ + return jnp.array([params_dict[k] for k in self._param_names]) + + def _to_dict(self, params_array): + """Convert a flat array back to a parameter dict. + + Args: + params_array (jnp.ndarray): 1-D array of shape ``(n_params,)``. + + Returns: + dict: Parameter name -> scalar value mapping. + """ + return {k: params_array[i] for i, k in enumerate(self._param_names)} + + def _objective(self, params_array): + """Wrap the likelihood so it accepts a flat array. + + Args: + params_array (jnp.ndarray): 1-D parameter array. + + Returns: + float: Negative log-likelihood. + """ + return self.likelihood(self._to_dict(params_array)) + + # ------------------------------------------------------------------ + # Public interface + # ------------------------------------------------------------------ + + def run(self): + """Run the optimization and return the best-fit parameters. + + After completion, the raw jaxopt result is available via + :attr:`result`. + + Returns: + dict: Best-fit cosmological parameter values, with the same keys + as ``initial_params``. + """ + initial_array = self._to_array(self.initial_params) + + solver_cls = _AVAILABLE_SOLVERS[self.solver_name] + kwargs = { + "fun": self._objective, + "maxiter": self.maxiter, + **self.solver_kwargs, + } + + solver = solver_cls(**kwargs) + + if self.bounds is not None and self.solver_name == "LBFGSB": + lower = jnp.array( + [self.bounds[0].get(k, -jnp.inf) for k in self._param_names] + ) + upper = jnp.array( + [self.bounds[1].get(k, jnp.inf) for k in self._param_names] + ) + result = solver.run(initial_array, bounds=(lower, upper)) + else: + result = solver.run(initial_array) + self._result = result + + try: + n_iter = result.state.iter_num + final_value = result.state.value + log.add( + f"SimulationFitter ({self.solver_name}) finished after " + f"{n_iter} iterations. " + f"Final loss: {final_value:.6g}" + ) + except (AttributeError, KeyError): + log.add(f"SimulationFitter ({self.solver_name}) optimization complete.") + + return self._to_dict(result.params) + + @property + def result(self): + """Raw jaxopt result from the last call to :meth:`run`. + + Returns: + jaxopt.OptStep | None: Result object, or ``None`` if :meth:`run` + has not been called yet. + """ + return self._result diff --git a/flip/simulation/generate.py b/flip/simulation/generate.py new file mode 100644 index 0000000..6e2373e --- /dev/null +++ b/flip/simulation/generate.py @@ -0,0 +1,640 @@ +"""Differentiable simulation of density and velocity fields using JaxPM. + +This module provides JAX-differentiable functions to generate large-scale +structure simulations via two methods: + +* **N-body** (default): Full particle-mesh N-body integration using the + ``diffrax`` ODE solver. Particles are displaced from lattice positions using + first-order LPT at an early scale factor and then evolved forward in time by + solving the equations of motion with the ``diffrax`` library. This is the + primary method and the one exposed by :func:`generate_density_and_velocity`. + +* **LPT**: First-order Lagrangian Perturbation Theory (Zel'dovich + approximation) only. Faster but less accurate than the full N-body + integration. Available via :func:`generate_density_and_velocity_lpt` and + useful for testing or rapid parameter scans. + +All operations are implemented in JAX to support automatic differentiation with +respect to cosmological parameters. + +Notes: + The ``jaxpm``, ``jax_cosmo``, and ``diffrax`` packages must be installed + to use this module:: + + pip install jaxpm jax_cosmo diffrax + +Examples: + N-body simulation (default): + + >>> import jax + >>> import jax.numpy as jnp + >>> import jax_cosmo as jc + >>> from flip.simulation import generate + >>> cosmo = generate.get_cosmology(omega_m=0.3, sigma8=0.8) + >>> seed = jax.random.PRNGKey(0) + >>> density, velocity = generate.generate_density_and_velocity( + ... cosmo, mesh_shape=(32, 32, 32), box_size=[256., 256., 256.], + ... seed=seed, + ... ) + + LPT-only simulation (for testing): + + >>> density, velocity = generate.generate_density_and_velocity_lpt( + ... cosmo, mesh_shape=(32, 32, 32), box_size=[256., 256., 256.], + ... seed=seed, + ... ) +""" + +from flip.utils import create_log + +log = create_log() + +try: + import jax + import jax.numpy as jnp + import jax_cosmo as jc + from jaxpm.distributed import fft3d, ifft3d, normal_field + from jaxpm.growth import growth_factor, growth_rate + from jaxpm.kernels import fftk + from jaxpm.painting import cic_paint_dx + from jaxpm.pm import make_diffrax_ode, pm_forces + + jaxpm_installed = True + +except ImportError: + jaxpm_installed = False + log.add( + "Install jaxpm, jax_cosmo and diffrax to use the simulation generate module", + level="warning", + ) + +#: Conversion factor: 1 Mpc/h * H_0 = 100 km/s. +#: The h factors cancel because H_0 = 100h km/s/Mpc and 1 Mpc/h = (1/h) Mpc. +_H0_UNIT = 100.0 # km/s / (Mpc/h) + +#: Default initial scale factor for LPT kick in N-body simulations. +_A_INITIAL = 0.1 + +#: Default number of k samples for power spectrum interpolation table. +_N_K_INTERP = 128 + + +def _require_jaxpm(fn_name): + """Raise ImportError with a helpful message when jaxpm/jax_cosmo are absent.""" + if not jaxpm_installed: + raise ImportError( + f"'{fn_name}' requires jaxpm, jax_cosmo and diffrax. " + "Install them with: pip install jaxpm jax_cosmo diffrax" + ) + + +def get_cosmology( + omega_m, + sigma8, + h=0.6774, + omega_b=0.0486, + n_s=0.9667, + w0=-1.0, + wa=0.0, + omega_k=0.0, +): + """Create a ``jax_cosmo.Cosmology`` object from standard parameters. + + Args: + omega_m (float): Total matter density parameter Omega_m. + sigma8 (float): RMS matter density fluctuation on 8 Mpc/h scales. + h (float): Dimensionless Hubble parameter H_0 / (100 km/s/Mpc). + Default 0.6774. + omega_b (float): Baryon density parameter Omega_b. Default 0.0486. + n_s (float): Spectral index of the primordial power spectrum. + Default 0.9667. + w0 (float): Dark energy equation of state at a=1. Default -1.0. + wa (float): Dark energy equation of state evolution parameter. + Default 0.0. + omega_k (float): Curvature density parameter. Default 0.0. + + Returns: + jax_cosmo.Cosmology: Cosmology instance compatible with JaxPM. + """ + _require_jaxpm("get_cosmology") + return jc.Cosmology( + h=h, + Omega_b=omega_b, + Omega_c=omega_m - omega_b, + w0=w0, + wa=wa, + n_s=n_s, + sigma8=sigma8, + Omega_k=omega_k, + ) + + +def _differentiable_linear_field(mesh_shape, box_size, pk_fn, seed): + """Generate a linear density field with a JAX-differentiable power spectrum. + + This replaces ``jaxpm.pm.linear_field`` with a numerically safe version + that avoids a NaN gradient at the DC mode (k=0). The DC mode of the + power spectrum is identically zero in any reasonable cosmology, but the + automatic derivative of ``jax_cosmo`` power spectra at k=0 is NaN. + This function handles this by replacing k=0 with a dummy value before + calling ``pk_fn``, and then zeroing the DC mode with a multiplicative + mask (which has zero gradient). + + Args: + mesh_shape (tuple[int, int, int]): Grid dimensions. + box_size (jnp.ndarray): Box size in Mpc/h, shape ``(3,)``. + pk_fn (callable): Power spectrum function P(k). + seed (jax.random.PRNGKey): Random seed. + + Returns: + jnp.ndarray: Real-space linear density field, shape ``mesh_shape``. + """ + # Draw Gaussian random Fourier coefficients + field = normal_field(seed=seed, shape=mesh_shape) + field = fft3d(field) + + # Wavenumber magnitude in units of 1/Mpc/h + kvec = fftk(field) + kmesh = sum( + (kk / box_size[i] * mesh_shape[i]) ** 2 for i, kk in enumerate(kvec) + ) ** 0.5 + + # Replace k=0 (DC mode) with 1.0 to avoid NaN in pk_fn at k=0. + # The DC contribution will be zeroed by dc_mask below. + kmesh_safe = jnp.where(kmesh > 0, kmesh, jnp.ones_like(kmesh)) + + # Dimensionless power spectrum amplitude on the mesh + volume = jnp.prod(jnp.array(mesh_shape)) / jnp.prod(box_size) + pkmesh = pk_fn(kmesh_safe) * volume + + # Multiplicative DC mask: 1 for k>0, 0 for k=0. + # Using multiplication instead of jnp.where ensures the gradient at k=0 + # is 0 rather than NaN (both branches of jnp.where are evaluated in JAX). + dc_mask = (kmesh > 0).astype(jnp.float64) + + field = field * jnp.sqrt(pkmesh) * dc_mask + return ifft3d(field) + + +def _run_lpt(cosmo, initial_conditions, a): + """Run first-order Lagrangian Perturbation Theory (1LPT). + + This is a manual 1LPT implementation that returns particle displacements + and momenta without computing the force derivative ``dGfa``, which has a + known caching incompatibility in jaxpm when used with JAX-traced + cosmology objects. + + Args: + cosmo (jax_cosmo.Cosmology): Cosmological parameters. + initial_conditions (jnp.ndarray): Linear density field on the mesh, + shape ``mesh_shape``. + a (float): Scale factor at which to evaluate LPT. + + Returns: + tuple: + - dx (jnp.ndarray): Particle displacement from lattice positions + in mesh cell units, shape ``(*mesh_shape, 3)``. + - p (jnp.ndarray): Particle momentum in internal units + ``[cells * H_0]``, shape ``(*mesh_shape, 3)``. + """ + mesh_shape = initial_conditions.shape + a_arr = jnp.atleast_1d(a) + a_scalar = a_arr[0] + + E = jnp.sqrt(jc.background.Esqr(cosmo, a_arr))[0] + + # Start particles at lattice positions (zero displacement) + particles = jnp.zeros((*mesh_shape, 3)) + + # Compute gravitational force from the linear density field + delta_k = fft3d(initial_conditions) + initial_force = pm_forces(particles, delta=delta_k, paint_absolute_pos=False) + + # 1LPT displacement: dx = D1(a) * Psi (Zel'dovich approximation) + D1 = growth_factor(cosmo, a_arr)[0] + f1 = growth_rate(cosmo, a_arr)[0] + + dx = D1 * initial_force + # Momentum: p = a^2 * H(a) * f1 * dx (in internal units) + p = a_scalar**2 * f1 * E * dx + + return dx, p + + +def generate_density_and_velocity_lpt( + cosmo, + mesh_shape, + box_size, + seed, + a=1.0, +): + """Generate differentiable density and velocity fields using 1LPT. + + Runs a first-order Lagrangian Perturbation Theory simulation (Zel'dovich + approximation) using the JaxPM package and returns the density contrast + and peculiar velocity fields on a regular 3D Cartesian mesh. All + operations are JAX-differentiable with respect to ``cosmo``. + + This function is provided primarily for **testing** and rapid parameter + scans. For production use, prefer + :func:`generate_density_and_velocity_nbody` which evolves the simulation + with a full particle-mesh ODE solver. + + Args: + cosmo (jax_cosmo.Cosmology): Cosmological parameters. Create with + :func:`get_cosmology`. + mesh_shape (tuple[int, int, int]): Number of mesh cells along each + axis, e.g. ``(64, 64, 64)``. + box_size (array-like): Box dimensions in Mpc/h along each axis, + e.g. ``[256., 256., 256.]``. + seed (jax.random.PRNGKey): Random seed for the Gaussian initial + conditions. + a (float): Final scale factor at which to evaluate the fields. + Default 1.0 (z=0). + + Returns: + tuple: + - density_field (jnp.ndarray): Density contrast + :math:`\\delta(x) = \\rho/\\bar{\\rho} - 1` on the mesh, shape + ``mesh_shape``. The mean value is approximately zero. + - velocity_field (jnp.ndarray): Peculiar velocity field in km/s + on the mesh, shape ``(*mesh_shape, 3)`` with components + ``(vx, vy, vz)`` in Cartesian coordinates. + """ + _require_jaxpm("generate_density_and_velocity_lpt") + box_size = jnp.array(box_size) + + def linear_pk_fn(k): + return jc.power.linear_matter_power(cosmo, k, a=a) + + # Generate Gaussian random initial conditions from the linear power spectrum + initial_conditions = _differentiable_linear_field(mesh_shape, box_size, linear_pk_fn, seed) + + # Run 1LPT to get particle displacements and momenta + dx, p = _run_lpt(cosmo, initial_conditions, a) + + # Paint displaced particles to obtain density contrast field delta(x) = rho/rho_bar - 1 + density_field = cic_paint_dx(dx) - 1 + + # Convert momentum to velocity field in km/s + # p = a^2 * f * E * dx => v_dimensionless = p / (a^2 * E) = f * dx [cells] + # v_km_s = v_dimensionless * cell_size [Mpc/h] * H_0 [km/s / (Mpc/h)] + a_arr = jnp.atleast_1d(a) + E = jnp.sqrt(jc.background.Esqr(cosmo, a_arr))[0] + cell_size = box_size / jnp.array(mesh_shape, dtype=jnp.float64) + velocity_field = p / (a_arr[0] ** 2 * E) * cell_size * _H0_UNIT + + return density_field, velocity_field + + +def generate_density_and_velocity_nbody( + cosmo, + mesh_shape, + box_size, + seed, + a=1.0, + a_initial=_A_INITIAL, + ode_rtol=1e-4, + ode_atol=1e-4, + ode_dt0=0.01, + ode_max_steps=4096, +): + """Generate differentiable density and velocity fields using a full N-body simulation. + + Runs a particle-mesh N-body simulation via the following pipeline: + + 1. Generate Gaussian random initial conditions from the linear matter power + spectrum at ``a=1`` using :func:`_differentiable_linear_field`. + 2. Apply first-order LPT to displace particles to ``a_initial`` (early + time, default 0.1) using :func:`_run_lpt`. + 3. Evolve the particle positions and momenta from ``a_initial`` to ``a`` + by integrating the particle-mesh equations of motion with the + ``diffrax`` adaptive ODE solver (``Tsit5`` + ``PIDController``). + 4. Paint displaced particles onto the mesh to obtain the density contrast + field delta(x) using CIC mass assignment. + 5. Compute the peculiar velocity field from the final momenta. + + All operations are JAX-differentiable with respect to ``cosmo``. + + Args: + cosmo (jax_cosmo.Cosmology): Cosmological parameters. Create with + :func:`get_cosmology`. + mesh_shape (tuple[int, int, int]): Number of mesh cells along each + axis, e.g. ``(64, 64, 64)``. + box_size (array-like): Box dimensions in Mpc/h along each axis, + e.g. ``[256., 256., 256.]``. + seed (jax.random.PRNGKey): Random seed for the Gaussian initial + conditions. + a (float): Final scale factor at which to evaluate the fields. + Default 1.0 (z=0). + a_initial (float): Scale factor at which the LPT initial conditions + are set. The ODE is integrated from ``a_initial`` to ``a``. + Default 0.1 (z≈9). + ode_rtol (float): Relative tolerance for the adaptive ODE integrator. + Default 1e-4. + ode_atol (float): Absolute tolerance for the adaptive ODE integrator. + Default 1e-4. + ode_dt0 (float): Initial step size for the ODE integrator. Default + 0.01. + ode_max_steps (int): Maximum number of ODE integration steps. + Default 4096. + + Returns: + tuple: + - density_field (jnp.ndarray): Density contrast + :math:`\\delta(x) = \\rho/\\bar{\\rho} - 1` on the mesh, shape + ``mesh_shape``. The mean value is approximately zero. + - velocity_field (jnp.ndarray): Peculiar velocity field in km/s + on the mesh, shape ``(*mesh_shape, 3)`` with components + ``(vx, vy, vz)`` in Cartesian coordinates. + + Raises: + ImportError: If ``jaxpm``, ``jax_cosmo`` or ``diffrax`` are not installed. + """ + _require_jaxpm("generate_density_and_velocity_nbody") + try: + from diffrax import ODETerm, PIDController, SaveAt, Tsit5, diffeqsolve + except ImportError as exc: + raise ImportError( + "The 'diffrax' package is required for N-body simulations. " + "Install it with: pip install diffrax" + ) from exc + + box_size = jnp.array(box_size) + + # Build a tabulated, JAX-differentiable power spectrum interpolant. + # Using jnp.interp instead of calling jc.power.linear_matter_power + # directly on every k-mesh evaluation is faster and avoids potential + # NaN gradients at k=0 when the direct evaluation is nested inside JIT. + k_table = jnp.logspace(-4, 1, _N_K_INTERP) + pk_table = jc.power.linear_matter_power(cosmo, k_table, a=1.0) + + def linear_pk_fn(k): + return jnp.interp(k.reshape([-1]), k_table, pk_table).reshape(k.shape) + + # 1. Gaussian random initial conditions (present-day amplitude) + initial_conditions = _differentiable_linear_field(mesh_shape, box_size, linear_pk_fn, seed) + + # 2. 1LPT kick to the initial scale factor + dx, p = _run_lpt(cosmo, initial_conditions, a_initial) + + # 3. N-body ODE integration from a_initial to a + nbody_ode = make_diffrax_ode(mesh_shape, paint_absolute_pos=False) + ode_term = ODETerm(nbody_ode) + y0 = jnp.stack([dx, p], axis=0) + + res = diffeqsolve( + ode_term, + Tsit5(), + t0=a_initial, + t1=a, + dt0=ode_dt0, + y0=y0, + args=cosmo, + saveat=SaveAt(ts=jnp.array([a])), + stepsize_controller=PIDController(rtol=ode_rtol, atol=ode_atol), + max_steps=ode_max_steps, + ) + + # res.ys has shape (1, 2, *mesh_shape, 3): one snapshot, two fields (dx, p) + dx_final = res.ys[0, 0] + p_final = res.ys[0, 1] + + # 4. Paint particles to get density contrast delta(x) = rho/rho_bar - 1 + density_field = cic_paint_dx(dx_final) - 1 + + # 5. Convert momentum to velocity field in km/s + # p = a^2 * E(a) * f(a) * dx => v = p / (a^2 * E(a)) [cells] + # v_km_s = v [cells] * cell_size [Mpc/h] * H_0 [km/s / (Mpc/h)] + a_arr = jnp.atleast_1d(a) + E = jnp.sqrt(jc.background.Esqr(cosmo, a_arr))[0] + cell_size = box_size / jnp.array(mesh_shape, dtype=jnp.float64) + velocity_field = p_final / (a_arr[0] ** 2 * E) * cell_size * _H0_UNIT + + return density_field, velocity_field + + +def generate_density_and_velocity( + cosmo, + mesh_shape, + box_size, + seed, + a=1.0, + method="nbody", + **kwargs, +): + """Generate differentiable density and velocity fields. + + This is the primary entry point for simulation. By default it runs a + full particle-mesh N-body simulation via :func:`generate_density_and_velocity_nbody`. + Pass ``method='lpt'`` to use the faster but less accurate Zel'dovich + approximation (:func:`generate_density_and_velocity_lpt`), which is + useful for testing. + + Args: + cosmo (jax_cosmo.Cosmology): Cosmological parameters. Create with + :func:`get_cosmology`. + mesh_shape (tuple[int, int, int]): Number of mesh cells along each + axis, e.g. ``(64, 64, 64)``. + box_size (array-like): Box dimensions in Mpc/h along each axis, + e.g. ``[256., 256., 256.]``. + seed (jax.random.PRNGKey): Random seed for the Gaussian initial + conditions. + a (float): Final scale factor at which to evaluate the fields. + Default 1.0 (z=0). + method (str): Simulation method. One of ``"nbody"`` (default) or + ``"lpt"``. + **kwargs: Extra keyword arguments forwarded to the chosen simulation + function. For ``method='nbody'`` these include ``a_initial``, + ``ode_rtol``, ``ode_atol``, ``ode_dt0``, and ``ode_max_steps``. + + Returns: + tuple: + - density_field (jnp.ndarray): Density contrast + :math:`\\delta(x) = \\rho/\\bar{\\rho} - 1` on the mesh, shape + ``mesh_shape``. The mean value is approximately zero. + - velocity_field (jnp.ndarray): Peculiar velocity field in km/s, + shape ``(*mesh_shape, 3)``. + + Raises: + ValueError: If ``method`` is not ``"nbody"`` or ``"lpt"``. + """ + if method == "nbody": + return generate_density_and_velocity_nbody( + cosmo, mesh_shape, box_size, seed, a=a, **kwargs + ) + elif method == "lpt": + return generate_density_and_velocity_lpt( + cosmo, mesh_shape, box_size, seed, a=a, **kwargs + ) + else: + raise ValueError( + f"Unknown simulation method '{method}'. Choose 'nbody' or 'lpt'." + ) + + +def interpolate_velocity_to_positions(velocity_field, positions, box_size, mesh_shape): + """Interpolate velocity field at arbitrary Cartesian positions using CIC. + + Performs trilinear (CIC) interpolation of a 3D velocity field at the + provided positions. Periodic boundary conditions are applied. + + Args: + velocity_field (jnp.ndarray): Velocity field in km/s, shape + ``(*mesh_shape, 3)`` as returned by + :func:`generate_density_and_velocity`. + positions (jnp.ndarray): Galaxy Cartesian positions in Mpc/h, + shape ``(N, 3)``. Coordinates should be within ``[0, box_size)`` + along each axis. + box_size (array-like): Box dimensions in Mpc/h, shape ``(3,)``. + mesh_shape (array-like): Number of mesh cells per axis, shape ``(3,)``. + + Returns: + jnp.ndarray: Velocity vector in km/s at each position, shape + ``(N, 3)``. + """ + _require_jaxpm("interpolate_velocity_to_positions") + box_size = jnp.array(box_size) + mesh_shape_arr = jnp.array(mesh_shape, dtype=jnp.float64) + + # Convert Cartesian Mpc/h positions to mesh cell units [0, Ni) + pos_mesh = positions / box_size * mesh_shape_arr + + # Read each velocity component using CIC interpolation + velocities = jnp.stack( + [_cic_read(velocity_field[..., i], pos_mesh) for i in range(3)], + axis=-1, + ) + return velocities + + +def _cic_read(grid_mesh, positions): + """Read a 3D scalar field at arbitrary positions using CIC interpolation. + + This is a JAX-differentiable trilinear interpolation compatible with + arbitrary batch sizes (unlike jaxpm's ``cic_read`` which requires + positions to match the grid shape). + + Args: + grid_mesh (jnp.ndarray): 3D scalar field, shape ``(Nx, Ny, Nz)``. + positions (jnp.ndarray): Positions in mesh cell units ``[0, Ni)``, + shape ``(N, 3)``. + + Returns: + jnp.ndarray: Interpolated field values at each position, shape + ``(N,)``. + """ + # Add neighbour-offset dimension: positions (N, 1, 3) + pos = jnp.expand_dims(positions, -2) + + # 8 CIC neighbour offsets, shape (1, 8, 3) + offsets = jnp.array( + [ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 1, 0], + [1, 0, 1], + [0, 1, 1], + [1, 1, 1], + ], + dtype=jnp.float64, + ) + offsets = offsets[jnp.newaxis, ...] # (1, 8, 3) + + floor_pos = jnp.floor(pos) # (N, 1, 3) + neighbours = floor_pos + offsets # (N, 8, 3) + + # CIC kernel weights + kernel = 1.0 - jnp.abs(pos - neighbours) # (N, 8, 3) + kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] # (N, 8) + + # Periodic wrap of neighbour indices + grid_shape = jnp.array(grid_mesh.shape) + idx = jnp.mod(neighbours.astype(jnp.int32), grid_shape) # (N, 8, 3) + + # Gather and weight + values = grid_mesh[idx[..., 0], idx[..., 1], idx[..., 2]] # (N, 8) + return (values * kernel).sum(axis=-1) # (N,) + + +def _safe_normalize(positions): + """Compute unit vectors along the line-of-sight, handling zero-radius positions. + + Returns the unit vector for each position. When the Euclidean norm of a + position is zero (observer at the origin coincides with the galaxy), the + unit vector is set to zero to avoid division by zero. + + Args: + positions (jnp.ndarray): Cartesian positions in Mpc/h, shape ``(N, 3)``. + + Returns: + jnp.ndarray: Unit vectors along line-of-sight, shape ``(N, 3)``. + """ + r = jnp.linalg.norm(positions, axis=-1, keepdims=True) + return positions / jnp.where(r > 0, r, jnp.ones_like(r)) + + +def compute_los_velocity(velocities, positions): + """Project 3D peculiar velocities onto the line-of-sight direction. + + Computes the radial (line-of-sight) component of the peculiar velocity + for each galaxy. The observer is assumed to be at the Cartesian origin. + + Args: + velocities (jnp.ndarray): 3D peculiar velocities in km/s, shape + ``(N, 3)``. + positions (jnp.ndarray): Galaxy Cartesian positions in Mpc/h, shape + ``(N, 3)``. + + Returns: + jnp.ndarray: Line-of-sight peculiar velocity in km/s, shape ``(N,)``. + """ + _require_jaxpm("compute_los_velocity") + los_unit = _safe_normalize(positions) + return jnp.sum(velocities * los_unit, axis=-1) + + +def compute_fsigma8(cosmo, a=1.0): + """Compute the linear growth parameter f*sigma_8 from a cosmology. + + Evaluates :math:`f\\sigma_8 = f(a) \\cdot \\sigma_8`, where :math:`f` is + the logarithmic growth rate :math:`f = d\\ln D / d\\ln a` and + :math:`\\sigma_8` is the RMS matter fluctuation amplitude at + 8 Mpc/h. + + Args: + cosmo (jax_cosmo.Cosmology): Cosmological parameters. Create with + :func:`get_cosmology`. + a (float): Scale factor at which to evaluate the growth rate. + Default 1.0 (z=0). + + Returns: + float: :math:`f \\cdot \\sigma_8` dimensionless growth parameter. + """ + _require_jaxpm("compute_fsigma8") + a_arr = jnp.atleast_1d(a) + f = jc.background.growth_rate(cosmo, a_arr)[0] + return f * cosmo.sigma8 + + +def radec_to_cartesian(ra, dec, r_com): + """Convert spherical sky coordinates to Cartesian positions. + + Args: + ra (array-like): Right ascension in degrees. + dec (array-like): Declination in degrees. + r_com (array-like): Comoving distance in Mpc/h. + + Returns: + jnp.ndarray: Cartesian positions in Mpc/h, shape ``(N, 3)``. + """ + _require_jaxpm("radec_to_cartesian") + ra_rad = jnp.deg2rad(jnp.asarray(ra)) + dec_rad = jnp.deg2rad(jnp.asarray(dec)) + r = jnp.asarray(r_com) + x = r * jnp.cos(dec_rad) * jnp.cos(ra_rad) + y = r * jnp.cos(dec_rad) * jnp.sin(ra_rad) + z = r * jnp.sin(dec_rad) + return jnp.stack([x, y, z], axis=-1) diff --git a/flip/simulation/likelihood.py b/flip/simulation/likelihood.py new file mode 100644 index 0000000..8c23fa6 --- /dev/null +++ b/flip/simulation/likelihood.py @@ -0,0 +1,238 @@ +"""Gaussian likelihood for forward-model velocity field comparison. + +This module provides a differentiable Gaussian likelihood that compares a +simulated peculiar velocity field (generated with :mod:`flip.simulation.generate`) +to observed peculiar velocity measurements from a :class:`flip.data_vector.DataVector`. +All operations are implemented in JAX. + +Examples: + >>> import jax + >>> import jax.numpy as jnp + >>> import jax_cosmo as jc + >>> from flip.simulation import generate, likelihood + >>> # Build likelihood from a DataVector and galaxy positions + >>> lik = likelihood.VelocityFieldLikelihood( + ... data_vector=my_velocity_data_vector, + ... positions_cartesian=galaxy_xyz, + ... mesh_shape=(64, 64, 64), + ... box_size=[512., 512., 512.], + ... seed=jax.random.PRNGKey(0), + ... ) + >>> neg_log_lik = lik({"omega_m": 0.3, "sigma8": 0.8}) +""" + +from flip.simulation import generate +from flip.utils import create_log + +log = create_log() + +try: + import jax.numpy as jnp + import jax.scipy as jsc + + jax_installed = True + +except ImportError: + jax_installed = False + log.add( + "Install jax to use the simulation likelihood module", + level="warning", + ) + + +def log_likelihood_gaussian(simulated_velocity, observed_velocity, observed_variance): + """Compute the Gaussian log-likelihood between simulated and observed velocities. + + Evaluates the log-likelihood under independent Gaussian measurement errors + (diagonal noise covariance). When ``observed_variance`` is a 2-D matrix, + the full covariance is used via Cholesky factorisation. + + The diagonal form evaluates: + + .. math:: + + \\log\\mathcal{L} = -\\frac{1}{2}\\sum_i + \\left[\\frac{(v^{\\rm obs}_i - v^{\\rm sim}_i)^2}{\\sigma_i^2} + + \\log(2\\pi\\sigma_i^2)\\right] + + Args: + simulated_velocity (jnp.ndarray): Simulated line-of-sight velocities + in km/s, shape ``(N,)``. + observed_velocity (jnp.ndarray): Observed peculiar velocities in km/s, + shape ``(N,)``. + observed_variance (jnp.ndarray): Measurement (co)variance. Either a + 1-D array of shape ``(N,)`` for independent errors, or a 2-D + positive-definite matrix of shape ``(N, N)`` for correlated errors. + + Returns: + float: Log-likelihood value. + + Raises: + ImportError: If JAX is not installed. + """ + if not jax_installed: + raise ImportError( + "'log_likelihood_gaussian' requires jax. " + "Install it with: pip install jax" + ) + residual = observed_velocity - simulated_velocity + if observed_variance.ndim == 2: + # Full covariance: use Cholesky factorisation for numerical stability + chol = jsc.linalg.cho_factor(observed_variance) + logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(chol[0]))) + chi2 = jnp.dot(residual, jsc.linalg.cho_solve(chol, residual)) + n = residual.size + return -0.5 * (chi2 + logdet + n * jnp.log(2.0 * jnp.pi)) + else: + # Diagonal covariance: independent Gaussian errors + logdet = jnp.sum(jnp.log(2.0 * jnp.pi * observed_variance)) + chi2 = jnp.sum(residual**2 / observed_variance) + return -0.5 * (chi2 + logdet) + + +class VelocityFieldLikelihood: + """Gaussian likelihood comparing a JaxPM simulation to observed velocities. + + Given cosmological parameters, this callable runs a forward simulation + (N-body ODE by default, or 1LPT when ``method='lpt'``), interpolates the + resulting velocity field at observed galaxy positions, projects onto the + line of sight, and returns the Gaussian log-likelihood of the observations. + + The likelihood is fully JAX-differentiable with respect to the input + cosmological parameters. + + Args: + data_vector (flip.data_vector.DataVector): Velocity data vector that + provides observed velocities and their measurement variances via + its ``give_data_and_variance()`` method. + positions_cartesian (array-like): Galaxy Cartesian positions in Mpc/h, + shape ``(N, 3)``. These should be defined in the same frame as + the simulation box (origin at box corner, box extends to + ``box_size``). + mesh_shape (tuple[int, int, int]): Number of simulation mesh cells per + axis, e.g. ``(64, 64, 64)``. + box_size (array-like): Simulation box dimensions in Mpc/h. + seed (jax.random.PRNGKey): Random seed for the initial conditions. + a (float): Scale factor at which to evaluate the fields. Default 1.0. + method (str): Simulation method forwarded to + :func:`~flip.simulation.generate.generate_density_and_velocity`. + Either ``"nbody"`` (default, full N-body ODE integration) or + ``"lpt"`` (faster Zel'dovich approximation, for testing). + fixed_cosmo_params (dict | None): Cosmological parameters that are + held fixed during optimization. These are merged with the + ``cosmo_params`` dict at each likelihood call, with + ``cosmo_params`` taking precedence. Useful for fixing ``omega_m`` + while fitting only ``sigma8``, for example. Default ``None``. + parameter_values_dict (dict | None): Additional parameters consumed + by the data vector (e.g. ``{"M_0": -19.3}`` for + ``VelFromHDres``). If ``None``, an empty dict is used. + **simulation_kwargs: Extra keyword arguments forwarded to the + simulation function (e.g. ``ode_rtol``, ``ode_atol`` for + ``method='nbody'``). + + Examples: + >>> lik = VelocityFieldLikelihood( + ... data_vector=vel_vec, + ... positions_cartesian=xyz, + ... mesh_shape=(32, 32, 32), + ... box_size=[256., 256., 256.], + ... seed=jax.random.PRNGKey(1), + ... ) + >>> neg_log_lik = lik({"omega_m": 0.3, "sigma8": 0.8}) + """ + + def __init__( + self, + data_vector, + positions_cartesian, + mesh_shape, + box_size, + seed, + a=1.0, + method="nbody", + fixed_cosmo_params=None, + parameter_values_dict=None, + **simulation_kwargs, + ): + self.data_vector = data_vector + if not jax_installed: + raise ImportError( + "'VelocityFieldLikelihood' requires jax. " + "Install it with: pip install jax" + ) + self.positions_cartesian = jnp.array(positions_cartesian) + self.mesh_shape = mesh_shape + self.box_size = jnp.array(box_size) + self.seed = seed + self.a = a + self.method = method + self.fixed_cosmo_params = fixed_cosmo_params or {} + self.simulation_kwargs = simulation_kwargs + self.parameter_values_dict = parameter_values_dict or {} + + # Pre-compute observed velocities and their measurement (co)variances + # from the data vector. This is done once at construction time. + observed_velocity, observed_variance = self.data_vector.give_data_and_variance( + self.parameter_values_dict + ) + self.observed_velocity = jnp.array(observed_velocity) + self.observed_variance = jnp.array(observed_variance) + + log.add( + f"VelocityFieldLikelihood: {len(self.observed_velocity)} " + f"velocity observations, mesh {mesh_shape}, " + f"box {list(box_size)} Mpc/h, method='{method}'." + ) + + def __call__(self, cosmo_params): + """Evaluate the negative log-likelihood for a set of cosmological parameters. + + Runs the full JAX-differentiable forward model: + + 1. Build cosmology from ``cosmo_params`` merged with + ``fixed_cosmo_params``. + 2. Generate density and velocity fields via the configured simulation + method (N-body ODE or LPT). + 3. Interpolate velocity field at observed galaxy positions. + 4. Project onto line of sight. + 5. Compute Gaussian log-likelihood and return its negation. + + Args: + cosmo_params (dict): Cosmological parameters to optimize, accepted + by :func:`~flip.simulation.generate.get_cosmology`. These are + merged with ``fixed_cosmo_params`` (``cosmo_params`` takes + precedence), so only the free parameters need to be provided. + + Returns: + float: Negative log-likelihood (suitable for minimization). + """ + # Merge fixed parameters with the free ones (free params take precedence) + full_params = {**self.fixed_cosmo_params, **cosmo_params} + cosmo = generate.get_cosmology(**full_params) + + _, velocity_field = generate.generate_density_and_velocity( + cosmo, + self.mesh_shape, + self.box_size, + self.seed, + a=self.a, + method=self.method, + **self.simulation_kwargs, + ) + + velocities_3d = generate.interpolate_velocity_to_positions( + velocity_field, + self.positions_cartesian, + self.box_size, + self.mesh_shape, + ) + + simulated_los = generate.compute_los_velocity( + velocities_3d, self.positions_cartesian + ) + + return -log_likelihood_gaussian( + simulated_los, + self.observed_velocity, + self.observed_variance, + ) diff --git a/notebook/fit_simulation_velocity.ipynb b/notebook/fit_simulation_velocity.ipynb new file mode 100644 index 0000000..687e9df --- /dev/null +++ b/notebook/fit_simulation_velocity.ipynb @@ -0,0 +1,182 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "md-intro", + "metadata": {}, + "source": "# Velocity fit of fσ₈ using the simulation forward model\n\nThis notebook demonstrates how to use the `flip.simulation` forward-model pipeline to\nfit the growth parameter **fσ₈** from a peculiar velocity catalogue.\n\nThe forward model uses a differentiable **particle-mesh simulation** (LPT or full N-body\nvia `JaxPM` + `diffrax`) and a gradient-based optimizer (`jaxopt`) to maximise the\nlikelihood of the observed velocities with respect to the cosmological parameters.\n\n**Outline:**\n1. Install dependencies\n2. Set up the simulation box and generate a synthetic mock catalogue\n3. Compute the true fσ₈ from the input cosmology\n4. Build the `VelocityFieldLikelihood` and scan the likelihood over σ8\n5. Fit σ8 (and fσ₈) with `SimulationFitter` using gradient descent\n6. Repeat with the full N-body ODE solver\n\n> **Note:** This notebook uses a small mesh (16³) to run in minutes on CPU.\n> For a science-quality analysis use at least a 128³ mesh." + }, + { + "cell_type": "code", + "execution_count": null, + "id": "code-install", + "metadata": {}, + "outputs": [], + "source": "%%capture\n!pip install git+https://github.com/corentinravoux/flip \"jaxpm>=0.1\" \"diffrax>=0.5\" jax_cosmo jaxopt" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "code-imports", + "metadata": {}, + "outputs": [], + "source": "import jax\nimport jax.numpy as jnp\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nfrom pathlib import Path\nfrom flip import data_vector, __flip_dir_path__\nfrom flip.simulation import generate, likelihood\nfrom flip.simulation.fitter import SimulationFitter\n\n# Enable 64-bit precision for accurate simulation\njax.config.update(\"jax_enable_x64\", True)\n\nflip_base = Path(__flip_dir_path__)\ndata_path = flip_base / \"data\"\nplt.style.use(data_path / \"style.mplstyle\")\n\nprint(\"JAX devices:\", jax.devices())" + }, + { + "cell_type": "markdown", + "id": "md-setup-md", + "metadata": {}, + "source": "## 1. Simulation box setup\n\nChoose a box large enough to contain the galaxy sample.\nWe use a 16³ mesh for speed; a production analysis would use 128³ or larger." + }, + { + "cell_type": "code", + "execution_count": null, + "id": "code-box-setup", + "metadata": {}, + "outputs": [], + "source": "# -----------------------------------------------------------------------\n# Simulation parameters\n# -----------------------------------------------------------------------\nMESH_SHAPE = (16, 16, 16) # grid cells per axis (increase for science)\nBOX_SIZE = [128., 128., 128.] # Mpc/h per axis\nSEED = jax.random.PRNGKey(0)\n\n# -----------------------------------------------------------------------\n# True (fiducial) cosmology used to create the mock catalogue\n# -----------------------------------------------------------------------\nTRUE_OMEGA_M = 0.3\nTRUE_SIGMA8 = 0.8\n\ncosmo_true = generate.get_cosmology(omega_m=TRUE_OMEGA_M, sigma8=TRUE_SIGMA8)\nprint(\"True cosmology:\")\nprint(f\" Omega_m = {TRUE_OMEGA_M}\")\nprint(f\" sigma_8 = {float(cosmo_true.sigma8):.4f}\")" + }, + { + "cell_type": "markdown", + "id": "md-mock-md", + "metadata": {}, + "source": "## 2. Generate a synthetic mock velocity catalogue\n\nWe simulate the velocity field with LPT and paint mock galaxies at random positions.\nEach galaxy receives the true simulated velocity plus Gaussian measurement noise." + }, + { + "cell_type": "code", + "execution_count": null, + "id": "code-mock-gen", + "metadata": {}, + "outputs": [], + "source": "# -----------------------------------------------------------------------\n# True velocity field (LPT for mock generation)\n# -----------------------------------------------------------------------\n_, vel_field_true = generate.generate_density_and_velocity_lpt(\n cosmo_true, MESH_SHAPE, BOX_SIZE, SEED\n)\n\n# -----------------------------------------------------------------------\n# Random galaxy positions inside the box [Mpc/h]\n# -----------------------------------------------------------------------\nN_GALAXIES = 150\nVELOCITY_ERROR = 200.0 # km/s per galaxy\n\nrng = np.random.RandomState(42)\npositions = rng.uniform(5.0, float(BOX_SIZE[0]) - 5.0, (N_GALAXIES, 3))\n\n# Interpolate true velocity at galaxy positions and project onto LOS\nvel_3d_true = generate.interpolate_velocity_to_positions(\n vel_field_true, jnp.array(positions), jnp.array(BOX_SIZE), MESH_SHAPE\n)\nlos_vel_true = generate.compute_los_velocity(vel_3d_true, jnp.array(positions))\n\n# Add Gaussian measurement noise\nnoise = rng.normal(0.0, VELOCITY_ERROR, N_GALAXIES)\nobserved_velocities = np.array(los_vel_true) + noise\n\nprint(f\"Mock catalogue: {N_GALAXIES} galaxies\")\nprint(f\"True velocity range: [{float(los_vel_true.min()):.1f}, {float(los_vel_true.max()):.1f}] km/s\")" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "code-visualise", + "metadata": {}, + "outputs": [], + "source": "fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n\nsc = axes[0].scatter(\n positions[:, 0], positions[:, 1],\n c=np.array(los_vel_true), vmin=-400, vmax=400, cmap=\"RdBu_r\"\n)\naxes[0].set_xlabel(\"X [Mpc/h]\")\naxes[0].set_ylabel(\"Y [Mpc/h]\")\naxes[0].set_title(\"True line-of-sight velocity\")\nplt.colorbar(sc, ax=axes[0], label=\"v_los [km/s]\")\n\nsc2 = axes[1].scatter(\n positions[:, 0], positions[:, 1],\n c=observed_velocities, vmin=-600, vmax=600, cmap=\"RdBu_r\"\n)\naxes[1].set_xlabel(\"X [Mpc/h]\")\naxes[1].set_ylabel(\"Y [Mpc/h]\")\naxes[1].set_title(\"Observed velocity (true + noise)\")\nplt.colorbar(sc2, ax=axes[1], label=\"v_obs [km/s]\")\n\nplt.tight_layout()\nplt.show()" + }, + { + "cell_type": "markdown", + "id": "md-fs8-md", + "metadata": {}, + "source": "## 3. True fσ₈\n\nWe compute the true fσ₈ from the input cosmology as a reference for the fit." + }, + { + "cell_type": "code", + "execution_count": null, + "id": "code-fs8", + "metadata": {}, + "outputs": [], + "source": "fsigma8_true = float(generate.compute_fsigma8(cosmo_true, a=1.0))\nprint(f\"True f*sigma_8 = {fsigma8_true:.4f}\")\nprint(f\" (f = {fsigma8_true / TRUE_SIGMA8:.4f}, sigma_8 = {TRUE_SIGMA8})\")" + }, + { + "cell_type": "markdown", + "id": "md-lik-md", + "metadata": {}, + "source": "## 4. Build the DataVector and VelocityFieldLikelihood\n\nWe wrap the mock observations in a `DirectVel` data vector (following the flip naming\nconventions) and pass it to `VelocityFieldLikelihood`. We fix `omega_m` at its true\nvalue and keep `sigma8` as the single free parameter." + }, + { + "cell_type": "code", + "execution_count": null, + "id": "code-lik-build", + "metadata": {}, + "outputs": [], + "source": "# Build the flip DataVector\nvel_data = {\n \"velocity\": observed_velocities,\n \"velocity_error\": np.full(N_GALAXIES, VELOCITY_ERROR),\n}\nDataVel = data_vector.DirectVel(vel_data)\n\n# Build the simulation likelihood — fix omega_m, keep sigma8 free\nlik_lpt = likelihood.VelocityFieldLikelihood(\n data_vector=DataVel,\n positions_cartesian=positions,\n mesh_shape=MESH_SHAPE,\n box_size=BOX_SIZE,\n seed=SEED,\n method=\"lpt\",\n fixed_cosmo_params={\"omega_m\": TRUE_OMEGA_M},\n)\n\nval_true = lik_lpt({\"sigma8\": TRUE_SIGMA8})\nprint(f\"Neg-log-likelihood at true sigma_8 = {TRUE_SIGMA8}: {float(val_true):.3f}\")" + }, + { + "cell_type": "markdown", + "id": "md-scan-md", + "metadata": {}, + "source": "## 5. Likelihood scan over σ8\n\nWe evaluate the likelihood at a coarse grid of σ8 values to visualise the shape of\nthe posterior before running the gradient-based optimizer." + }, + { + "cell_type": "code", + "execution_count": null, + "id": "code-scan", + "metadata": {}, + "outputs": [], + "source": "sigma8_grid = np.linspace(0.4, 1.2, 10)\nneg_log_lik_grid = [float(lik_lpt({\"sigma8\": s8})) for s8 in sigma8_grid]\n\n# Normalise to get Delta log L\nlog_lik_grid = -np.array(neg_log_lik_grid)\nlog_lik_grid -= log_lik_grid.max()\n\nplt.figure(figsize=(7, 4))\nplt.plot(sigma8_grid, log_lik_grid, \"o-\", label=\"log L (normalised)\")\nplt.axvline(TRUE_SIGMA8, color=\"r\", ls=\"--\", label=f\"True σ₈ = {TRUE_SIGMA8}\")\nplt.xlabel(r\"$\\sigma_8$\")\nplt.ylabel(r\"$\\Delta \\log \\mathcal{L}$\")\nplt.title(\"Likelihood scan over σ8 (LPT forward model)\")\nplt.legend()\nplt.tight_layout()\nplt.show()\n\nbest_scan = sigma8_grid[np.argmax(log_lik_grid)]\nprint(f\"Best σ8 from coarse scan: {best_scan:.3f} (true: {TRUE_SIGMA8})\")" + }, + { + "cell_type": "markdown", + "id": "md-fit-md", + "metadata": {}, + "source": "## 6. Gradient-based fit with SimulationFitter (LPT)\n\nWe use `SimulationFitter` to find the maximum-likelihood σ8 via gradient descent.\nThe gradient flows through the entire JAX simulation graph (auto-differentiation)." + }, + { + "cell_type": "code", + "execution_count": null, + "id": "code-fit-lpt", + "metadata": {}, + "outputs": [], + "source": "fitter_lpt = SimulationFitter(\n likelihood=lik_lpt,\n initial_params={\"sigma8\": 0.7},\n bounds=({\"sigma8\": 0.2}, {\"sigma8\": 1.6}),\n solver=\"LBFGSB\",\n maxiter=30,\n)\n\nbest_lpt = fitter_lpt.run()\nbest_s8_lpt = float(best_lpt[\"sigma8\"])\n\ncosmo_lpt = generate.get_cosmology(omega_m=TRUE_OMEGA_M, sigma8=best_s8_lpt)\nfs8_lpt = float(generate.compute_fsigma8(cosmo_lpt, a=1.0))\n\nprint(\"\\n=== Best-fit results (LPT forward model) ===\")\nprint(f\" Best-fit σ8 = {best_s8_lpt:.4f} (true: {TRUE_SIGMA8:.4f})\")\nprint(f\" Best-fit f*σ8 = {fs8_lpt:.4f} (true: {fsigma8_true:.4f})\")\nprint(f\" Relative error = {abs(best_s8_lpt - TRUE_SIGMA8) / TRUE_SIGMA8 * 100:.1f}%\")" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "code-plot-fit", + "metadata": {}, + "outputs": [], + "source": "plt.figure(figsize=(7, 4))\nplt.plot(sigma8_grid, log_lik_grid, \"o-\", label=\"log L scan (LPT)\")\nplt.axvline(TRUE_SIGMA8, color=\"r\", ls=\"--\", label=f\"True σ₈ = {TRUE_SIGMA8}\")\nplt.axvline(best_s8_lpt, color=\"g\", ls=\"-.\",\n label=f\"Best-fit σ₈ = {best_s8_lpt:.3f}\")\nplt.xlabel(r\"$\\sigma_8$\")\nplt.ylabel(r\"$\\Delta \\log \\mathcal{L}$\")\nplt.title(\"σ8 recovery – LPT forward model\")\nplt.legend()\nplt.tight_layout()\nplt.show()\n\nprint(f\"True fσ₈ = {fsigma8_true:.4f}\")\nprint(f\"Fit fσ₈ = {fs8_lpt:.4f}\")" + }, + { + "cell_type": "markdown", + "id": "md-nbody-md", + "metadata": {}, + "source": "## 7. Full N-body simulation\n\nReplacing `method='lpt'` with `method='nbody'` evolves the simulation with the full\nparticle-mesh N-body ODE integrator (`diffrax` backend). This is more accurate at\nthe cost of longer run time.\n\n> **Tip:** Tighten `ode_rtol` / `ode_atol` to `1e-5` for a production run.\n> On GPU this runs ~10× faster than on CPU." + }, + { + "cell_type": "code", + "execution_count": null, + "id": "code-nbody-lik", + "metadata": {}, + "outputs": [], + "source": "# Build N-body likelihood with the same data and positions\nlik_nbody = likelihood.VelocityFieldLikelihood(\n data_vector=DataVel,\n positions_cartesian=positions,\n mesh_shape=MESH_SHAPE,\n box_size=BOX_SIZE,\n seed=SEED,\n method=\"nbody\",\n fixed_cosmo_params={\"omega_m\": TRUE_OMEGA_M},\n ode_rtol=1e-3, # loose tolerances for the demo\n ode_atol=1e-3,\n)\n\nval_nbody_true = float(lik_nbody({\"sigma8\": TRUE_SIGMA8}))\nprint(f\"N-body neg-log-lik at true σ8 = {val_nbody_true:.3f}\")\n\n# Gradient at the true cosmology\ngrad_nbody = float(jax.grad(lambda s8: lik_nbody({\"sigma8\": s8}))(TRUE_SIGMA8))\nprint(f\"N-body gradient d(-log L)/d(σ8) at true σ8 = {grad_nbody:.4f}\")" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "code-nbody-fit", + "metadata": {}, + "outputs": [], + "source": "fitter_nbody = SimulationFitter(\n likelihood=lik_nbody,\n initial_params={\"sigma8\": 0.7},\n bounds=({\"sigma8\": 0.2}, {\"sigma8\": 1.6}),\n solver=\"LBFGSB\",\n maxiter=20,\n)\n\nbest_nbody = fitter_nbody.run()\nbest_s8_nbody = float(best_nbody[\"sigma8\"])\n\ncosmo_nbody = generate.get_cosmology(omega_m=TRUE_OMEGA_M, sigma8=best_s8_nbody)\nfs8_nbody = float(generate.compute_fsigma8(cosmo_nbody, a=1.0))\n\nprint(\"\\n=== Best-fit results (N-body forward model) ===\")\nprint(f\" Best-fit σ8 = {best_s8_nbody:.4f} (true: {TRUE_SIGMA8:.4f})\")\nprint(f\" Best-fit f*σ8 = {fs8_nbody:.4f} (true: {fsigma8_true:.4f})\")\nprint(f\" Relative error = {abs(best_s8_nbody - TRUE_SIGMA8) / TRUE_SIGMA8 * 100:.1f}%\")" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "code-compare", + "metadata": {}, + "outputs": [], + "source": "methods = [\"LPT\", \"N-body\"]\nfs8_values = [fs8_lpt, fs8_nbody]\n\nplt.figure(figsize=(5, 4))\nplt.bar(methods, fs8_values, width=0.4, color=[\"steelblue\", \"coral\"])\nplt.axhline(fsigma8_true, color=\"r\", ls=\"--\", label=f\"True fσ₈ = {fsigma8_true:.4f}\")\nplt.ylabel(r\"$f\\sigma_8$\")\nplt.title(r\"Best-fit $f\\sigma_8$ by method\")\nplt.legend()\nplt.tight_layout()\nplt.show()\n\nprint(f\"\\nSummary:\")\nprint(f\" True f*σ8 = {fsigma8_true:.4f}\")\nprint(f\" LPT f*σ8 = {fs8_lpt:.4f}\")\nprint(f\" Nbody f*σ8 = {fs8_nbody:.4f}\")" + }, + { + "cell_type": "markdown", + "id": "md-summary-md", + "metadata": {}, + "source": "## 8. Summary\n\n| Method | Best-fit σ8 | Best-fit fσ₈ | True fσ₈ |\n|--------|------------|--------------|----------|\n| LPT | (see above) | (see above) | (see above) |\n| N-body | (see above) | (see above) | (see above) |\n\n*(Values are filled in after running the cells above.)*\n\n**Tips for a science run:**\n* Increase `MESH_SHAPE` to `(128, 128, 128)` or larger.\n* Use `method='nbody'` with `ode_rtol=1e-5`, `ode_atol=1e-5`.\n* Fix `omega_m` from CMB/BAO or jointly fit `omega_m` and `sigma8`.\n* Increase `maxiter` in `SimulationFitter` until convergence.\n* Use a real galaxy peculiar velocity catalogue." + } + ], + "metadata": { + "kernelspec": { + "display_name": "corentin", + "language": "python", + "name": "corentin" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f4fcb8d..972fb74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,15 @@ docs = [ "sphinx-autoapi", ] +simulation = [ + "jax", + "jaxlib", + "jaxpm", + "jaxopt", + "jax_cosmo", + "diffrax", +] + [tool.setuptools.dynamic] version = {attr = "flip.__version__"} diff --git a/test/test_simulation.py b/test/test_simulation.py new file mode 100644 index 0000000..2732821 --- /dev/null +++ b/test/test_simulation.py @@ -0,0 +1,339 @@ +"""Tests for the flip.simulation package. + +Covers generate.py (LPT and N-body field generation), likelihood.py +(VelocityFieldLikelihood), and fitter.py (SimulationFitter). + +All tests use a small mesh (8^3) and LPT mode to keep execution fast. +The N-body (diffrax) pipeline is exercised in a smoke-test that checks shapes +and finiteness without running a full optimisation. +""" + +import numpy as np +import pytest + +jax = pytest.importorskip("jax", reason="jax is required for simulation tests") +jnp = pytest.importorskip("jax.numpy", reason="jax is required for simulation tests") +pytest.importorskip("jaxpm", reason="jaxpm is required for simulation tests") +pytest.importorskip("jax_cosmo", reason="jax_cosmo is required for simulation tests") + +from flip import data_vector +from flip.simulation import generate, likelihood +from flip.simulation.fitter import SimulationFitter + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + +jax.config.update("jax_enable_x64", True) + +_MESH_SHAPE = (8, 8, 8) +_BOX_SIZE = [64.0, 64.0, 64.0] +_SEED = jax.random.PRNGKey(0) +_TRUE_OMEGA_M = 0.3 +_TRUE_SIGMA8 = 0.8 + + +def _make_mock_data_vector(n_galaxies=20, seed_np=42): + """Return (positions, DirectVel data vector) for a small synthetic catalogue. + + The observed velocities are drawn from a N(0, 300) km/s distribution with + a flat error of 200 km/s to make the likelihood well-conditioned. + """ + rng = np.random.RandomState(seed_np) + positions = rng.uniform(5.0, 59.0, (n_galaxies, 3)) + velocities = rng.normal(0.0, 300.0, n_galaxies) + velocity_errors = np.full(n_galaxies, 200.0) + vel_data = {"velocity": velocities, "velocity_error": velocity_errors} + dv = data_vector.DirectVel(vel_data) + return positions, dv + + +# --------------------------------------------------------------------------- +# generate.py tests +# --------------------------------------------------------------------------- + + +class TestGetCosmology: + def test_returns_cosmology_object(self): + cosmo = generate.get_cosmology(omega_m=_TRUE_OMEGA_M, sigma8=_TRUE_SIGMA8) + assert hasattr(cosmo, "sigma8") + assert float(cosmo.sigma8) == pytest.approx(_TRUE_SIGMA8) + + def test_omega_c_derived_from_omega_m(self): + cosmo = generate.get_cosmology(omega_m=0.3, sigma8=0.8, omega_b=0.05) + assert float(cosmo.Omega_c) == pytest.approx(0.3 - 0.05, abs=1e-6) + + +class TestComputeFsigma8: + def test_fsigma8_finite_and_positive(self): + cosmo = generate.get_cosmology(omega_m=_TRUE_OMEGA_M, sigma8=_TRUE_SIGMA8) + fs8 = generate.compute_fsigma8(cosmo, a=1.0) + assert jnp.isfinite(fs8) + assert float(fs8) > 0.0 + + def test_fsigma8_scales_with_sigma8(self): + cosmo_hi = generate.get_cosmology(omega_m=0.3, sigma8=1.0) + cosmo_lo = generate.get_cosmology(omega_m=0.3, sigma8=0.5) + fs8_hi = float(generate.compute_fsigma8(cosmo_hi)) + fs8_lo = float(generate.compute_fsigma8(cosmo_lo)) + assert fs8_hi > fs8_lo + + def test_fsigma8_differentiable(self): + def obj(sigma8): + cosmo = generate.get_cosmology(omega_m=0.3, sigma8=sigma8) + return generate.compute_fsigma8(cosmo) + + grad = jax.grad(obj)(0.8) + assert jnp.isfinite(grad) + + +class TestGenerateLpt: + def test_output_shapes(self): + cosmo = generate.get_cosmology(omega_m=_TRUE_OMEGA_M, sigma8=_TRUE_SIGMA8) + density, velocity = generate.generate_density_and_velocity_lpt( + cosmo, _MESH_SHAPE, _BOX_SIZE, _SEED + ) + assert density.shape == _MESH_SHAPE + assert velocity.shape == (*_MESH_SHAPE, 3) + + def test_fields_finite(self): + cosmo = generate.get_cosmology(omega_m=_TRUE_OMEGA_M, sigma8=_TRUE_SIGMA8) + density, velocity = generate.generate_density_and_velocity_lpt( + cosmo, _MESH_SHAPE, _BOX_SIZE, _SEED + ) + assert jnp.all(jnp.isfinite(density)) + assert jnp.all(jnp.isfinite(velocity)) + + def test_density_mean_near_zero(self): + cosmo = generate.get_cosmology(omega_m=_TRUE_OMEGA_M, sigma8=_TRUE_SIGMA8) + density, _ = generate.generate_density_and_velocity_lpt( + cosmo, _MESH_SHAPE, _BOX_SIZE, _SEED + ) + assert float(density.mean()) == pytest.approx(0.0, abs=0.1) + + def test_gradient_through_lpt_wrt_sigma8(self): + def obj(sigma8): + cosmo = generate.get_cosmology(omega_m=0.3, sigma8=sigma8) + _, vel = generate.generate_density_and_velocity_lpt( + cosmo, _MESH_SHAPE, _BOX_SIZE, _SEED + ) + return (vel**2).sum() + + grad = jax.grad(obj)(0.8) + assert jnp.isfinite(grad) + + +class TestGenerateNbody: + def test_output_shapes(self): + cosmo = generate.get_cosmology(omega_m=_TRUE_OMEGA_M, sigma8=_TRUE_SIGMA8) + density, velocity = generate.generate_density_and_velocity_nbody( + cosmo, _MESH_SHAPE, _BOX_SIZE, _SEED, + ode_rtol=1e-3, ode_atol=1e-3, + ) + assert density.shape == _MESH_SHAPE + assert velocity.shape == (*_MESH_SHAPE, 3) + + def test_fields_finite(self): + cosmo = generate.get_cosmology(omega_m=_TRUE_OMEGA_M, sigma8=_TRUE_SIGMA8) + density, velocity = generate.generate_density_and_velocity_nbody( + cosmo, _MESH_SHAPE, _BOX_SIZE, _SEED, + ode_rtol=1e-3, ode_atol=1e-3, + ) + assert jnp.all(jnp.isfinite(density)) + assert jnp.all(jnp.isfinite(velocity)) + + +class TestGenerateDispatch: + def test_lpt_method_matches_lpt_function(self): + cosmo = generate.get_cosmology(omega_m=_TRUE_OMEGA_M, sigma8=_TRUE_SIGMA8) + d1, v1 = generate.generate_density_and_velocity_lpt( + cosmo, _MESH_SHAPE, _BOX_SIZE, _SEED + ) + d2, v2 = generate.generate_density_and_velocity( + cosmo, _MESH_SHAPE, _BOX_SIZE, _SEED, method="lpt" + ) + np.testing.assert_array_equal(np.array(d1), np.array(d2)) + np.testing.assert_array_equal(np.array(v1), np.array(v2)) + + def test_invalid_method_raises(self): + cosmo = generate.get_cosmology(omega_m=0.3, sigma8=0.8) + with pytest.raises(ValueError, match="Unknown simulation method"): + generate.generate_density_and_velocity( + cosmo, _MESH_SHAPE, _BOX_SIZE, _SEED, method="invalid" + ) + + +class TestInterpolateAndLosVelocity: + def test_interpolate_output_shape(self): + cosmo = generate.get_cosmology(omega_m=_TRUE_OMEGA_M, sigma8=_TRUE_SIGMA8) + _, vel_field = generate.generate_density_and_velocity_lpt( + cosmo, _MESH_SHAPE, _BOX_SIZE, _SEED + ) + rng = np.random.RandomState(1) + positions = jnp.array(rng.uniform(5.0, 59.0, (30, 3))) + vel_at_pos = generate.interpolate_velocity_to_positions( + vel_field, positions, _BOX_SIZE, _MESH_SHAPE + ) + assert vel_at_pos.shape == (30, 3) + assert jnp.all(jnp.isfinite(vel_at_pos)) + + def test_los_velocity_output_shape(self): + rng = np.random.RandomState(2) + velocities = jnp.array(rng.normal(0.0, 200.0, (20, 3))) + positions = jnp.array(rng.uniform(5.0, 59.0, (20, 3))) + los_vel = generate.compute_los_velocity(velocities, positions) + assert los_vel.shape == (20,) + assert jnp.all(jnp.isfinite(los_vel)) + + +# --------------------------------------------------------------------------- +# likelihood.py tests +# --------------------------------------------------------------------------- + + +class TestVelocityFieldLikelihood: + def _build_lik(self, n_gal=15, method="lpt"): + positions, dv = _make_mock_data_vector(n_galaxies=n_gal) + lik = likelihood.VelocityFieldLikelihood( + data_vector=dv, + positions_cartesian=positions, + mesh_shape=_MESH_SHAPE, + box_size=_BOX_SIZE, + seed=_SEED, + method=method, + fixed_cosmo_params={"omega_m": _TRUE_OMEGA_M}, + ) + return lik + + def test_returns_finite_scalar(self): + lik = self._build_lik() + val = lik({"sigma8": _TRUE_SIGMA8}) + assert jnp.isfinite(val) + + def test_returns_positive_neg_log_lik(self): + lik = self._build_lik() + val = lik({"sigma8": _TRUE_SIGMA8}) + assert np.isscalar(float(val)) and jnp.isfinite(val) + + def test_gradient_wrt_sigma8_finite(self): + lik = self._build_lik() + grad = jax.grad(lambda s8: lik({"sigma8": s8}))(_TRUE_SIGMA8) + assert jnp.isfinite(grad) + + def test_fixed_cosmo_params_merged(self): + """fixed_cosmo_params must be used when cosmo_params omits omega_m.""" + positions, dv = _make_mock_data_vector(n_galaxies=10) + lik = likelihood.VelocityFieldLikelihood( + data_vector=dv, + positions_cartesian=positions, + mesh_shape=_MESH_SHAPE, + box_size=_BOX_SIZE, + seed=_SEED, + method="lpt", + fixed_cosmo_params={"omega_m": 0.3}, + ) + # Should not raise even though omega_m is not in cosmo_params + val = lik({"sigma8": 0.8}) + assert jnp.isfinite(val) + + def test_full_cosmo_params_without_fixed(self): + """Passing all params directly also works (no fixed_cosmo_params).""" + positions, dv = _make_mock_data_vector(n_galaxies=10) + lik = likelihood.VelocityFieldLikelihood( + data_vector=dv, + positions_cartesian=positions, + mesh_shape=_MESH_SHAPE, + box_size=_BOX_SIZE, + seed=_SEED, + method="lpt", + ) + val = lik({"omega_m": 0.3, "sigma8": 0.8}) + assert jnp.isfinite(val) + + def test_nbody_method_runs(self): + lik = self._build_lik(n_gal=10, method="nbody") + val = lik({"sigma8": _TRUE_SIGMA8}) + assert jnp.isfinite(val) + + +# --------------------------------------------------------------------------- +# fitter.py tests +# --------------------------------------------------------------------------- + + +class TestSimulationFitter: + def _build_lik_and_fitter(self, solver="LBFGSB", maxiter=3): + positions, dv = _make_mock_data_vector(n_galaxies=20) + lik = likelihood.VelocityFieldLikelihood( + data_vector=dv, + positions_cartesian=positions, + mesh_shape=_MESH_SHAPE, + box_size=_BOX_SIZE, + seed=_SEED, + method="lpt", + fixed_cosmo_params={"omega_m": _TRUE_OMEGA_M}, + ) + fitter = SimulationFitter( + likelihood=lik, + initial_params={"sigma8": 0.8}, + bounds=({"sigma8": 0.3}, {"sigma8": 1.5}), + solver=solver, + maxiter=maxiter, + ) + return lik, fitter + + def test_run_returns_dict(self): + _, fitter = self._build_lik_and_fitter() + result = fitter.run() + assert isinstance(result, dict) + assert "sigma8" in result + + def test_best_sigma8_in_bounds(self): + _, fitter = self._build_lik_and_fitter(maxiter=5) + result = fitter.run() + assert 0.3 <= float(result["sigma8"]) <= 1.5 + + def test_result_attribute_set_after_run(self): + _, fitter = self._build_lik_and_fitter() + assert fitter.result is None + fitter.run() + assert fitter.result is not None + + def test_invalid_solver_raises(self): + positions, dv = _make_mock_data_vector(n_galaxies=5) + lik = likelihood.VelocityFieldLikelihood( + data_vector=dv, + positions_cartesian=positions, + mesh_shape=_MESH_SHAPE, + box_size=_BOX_SIZE, + seed=_SEED, + method="lpt", + fixed_cosmo_params={"omega_m": _TRUE_OMEGA_M}, + ) + with pytest.raises(ValueError, match="Solver"): + SimulationFitter( + likelihood=lik, + initial_params={"sigma8": 0.8}, + solver="NOT_A_SOLVER", + ) + + def test_lbfgs_unconstrained_solver(self): + positions, dv = _make_mock_data_vector(n_galaxies=10) + lik = likelihood.VelocityFieldLikelihood( + data_vector=dv, + positions_cartesian=positions, + mesh_shape=_MESH_SHAPE, + box_size=_BOX_SIZE, + seed=_SEED, + method="lpt", + fixed_cosmo_params={"omega_m": _TRUE_OMEGA_M}, + ) + fitter = SimulationFitter( + likelihood=lik, + initial_params={"sigma8": 0.8}, + solver="LBFGS", + maxiter=3, + ) + result = fitter.run() + assert "sigma8" in result + assert jnp.isfinite(result["sigma8"])