Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions exponax/stepper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
- Dispersion
- HyperDiffusion
- Wave
- KleinGordon
- Burgers
- KortewegDeVries
- KuramotoSivashinsky
Expand Down Expand Up @@ -61,6 +62,10 @@
The Wave stepper uses a handcrafted diagonalization in Fourier space specific to
the wave equation. It has no corresponding generic stepper.

The KleinGordon stepper extends the Wave stepper with a mass term, using the
Klein-Gordon dispersion relation ω(k) = √(c²|k|² + m²). Setting m=0 recovers
the wave equation.

In the reaction submodule you find specific steppers that are special cases of
the GeneralPolynomialStepper, e.g., the FisherKPPStepper.

Expand Down Expand Up @@ -91,6 +96,7 @@
NavierStokesVelocity,
NavierStokesVorticity,
)
from ._klein_gordon import KleinGordon
from ._wave import Wave

__all__ = [
Expand All @@ -100,6 +106,7 @@
"Dispersion",
"HyperDiffusion",
"Wave",
"KleinGordon",
"Burgers",
"KortewegDeVries",
"KuramotoSivashinsky",
Expand Down
178 changes: 178 additions & 0 deletions exponax/stepper/_klein_gordon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import jax.numpy as jnp
from jaxtyping import Array, Complex, Float

from .._base_stepper import BaseStepper
from .._spectral import build_scaled_wavenumbers
from ..nonlin_fun import ZeroNonlinearFun


class KleinGordon(BaseStepper):
mass: float
speed_of_sound: float
frequency: Float[Array, " 1 ... (N//2)+1"]

def __init__(
self,
num_spatial_dims: int,
domain_extent: float,
num_points: int,
dt: float,
*,
speed_of_sound: float = 1.0,
mass: float = 1.0,
):
"""
Timestepper for the d-dimensional (`d ∈ {1, 2, 3}`) Klein-Gordon
equation on periodic boundary conditions.

In 1d, the Klein-Gordon equation is given by

```
uₜₜ = c² uₓₓ - m² u
```

with `c ∈ ℝ` being the wave speed and `m ∈ ℝ` being the mass
parameter. This is the relativistic generalization of the wave
equation, fundamental to quantum field theory and lattice field
simulations.

In higher dimensions:

```
uₜₜ = c² Δu - m² u
```

**Dispersion relation:** ω(k) = √(c²|k|² + m²)

Unlike the wave equation (ω = c|k|), the Klein-Gordon equation has a
**mass gap** — no modes with ω < m exist. The group velocity
v_g = c²|k|/ω is always less than c (massive dispersion).

Internally, the same diagonalization approach as the
[`exponax.stepper.Wave`][] stepper is used, but with
the Klein-Gordon dispersion relation.

The second-order equation is rewritten as a first-order system:

```
hₜ = v
vₜ = c² Δh - m² h
```

In Fourier space, each wavenumber k oscillates at frequency
ω(k) = √(c²|k|² + m²). The system is diagonalized into
forward/backward traveling modes that each evolve as a pure
phase rotation.

**Arguments:**

- `num_spatial_dims`: The number of spatial dimensions `d`.
- `domain_extent`: The size of the domain `L`; in higher dimensions
the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`.
- `num_points`: The number of points `N` used to discretize the
domain. This **includes** the left boundary point and **excludes**
the right boundary point. In higher dimensions; the number of points
in each dimension is the same.
- `dt`: The timestep size `Δt` between two consecutive states.
- `speed_of_sound` (keyword-only): The wave speed `c`. Default: `1.0`.
- `mass` (keyword-only): The mass parameter `m`. Default: `1.0`.

**Notes:**

- The stepper is unconditionally stable, no matter the choice of
any argument because the equation is solved analytically in Fourier
space.
- Setting `mass = 0.0` recovers the standard wave equation.
- The factors `c Δt / L` and `m Δt` together affect the dynamics.
"""
self.speed_of_sound = speed_of_sound
self.mass = mass
wavenumber_norm = jnp.linalg.norm(
build_scaled_wavenumbers(
num_spatial_dims=num_spatial_dims,
domain_extent=domain_extent,
num_points=num_points,
),
axis=0,
keepdims=True,
)
# Klein-Gordon dispersion: ω(k) = sqrt(c²|k|² + m²)
self.frequency = jnp.sqrt(
speed_of_sound**2 * wavenumber_norm**2 + mass**2
)
super().__init__(
num_spatial_dims=num_spatial_dims,
domain_extent=domain_extent,
num_points=num_points,
dt=dt,
num_channels=2,
order=0,
)

def _forward_transform(
self, u_hat: Complex[Array, " 2 ... (N//2)+1"]
) -> Complex[Array, " 2 ... (N//2)+1"]:
"""Transform (h, v) into diagonalized Klein-Gordon wave modes."""
h_hat, v_hat = u_hat[0:1], u_hat[1:2]
# Scale height to match velocity units: w = iω h
omega_guard = jnp.where(self.frequency == 0, 1.0, self.frequency)
w_hat = 1j * omega_guard * h_hat

# Orthonormal rotation into wave modes
pos = (1 / jnp.sqrt(2)) * (w_hat + v_hat)
neg = (1 / jnp.sqrt(2)) * (w_hat - v_hat)
return jnp.concatenate([pos, neg], axis=0)

def _inverse_transform(
self, waves_hat: Complex[Array, " 2 ... (N//2)+1"]
) -> Complex[Array, " 2 ... (N//2)+1"]:
"""Transform diagonalized wave modes back into (h, v)."""
pos, neg = waves_hat[0:1], waves_hat[1:2]
# Inverse rotation
w_hat = (1 / jnp.sqrt(2)) * (pos + neg)
v_hat = (1 / jnp.sqrt(2)) * (pos - neg)

# Undo scaling to recover height
omega_guard = jnp.where(self.frequency == 0, 1.0, self.frequency)
h_hat = w_hat / (1j * omega_guard)
return jnp.concatenate([h_hat, v_hat], axis=0)

def _build_linear_operator(
self, derivative_operator: Complex[Array, " D ... (N//2)+1"]
) -> Complex[Array, " 2 ... (N//2)+1"]:
val = 1j * self.frequency
return jnp.concatenate(
(
val,
-val,
),
axis=0,
)

def _build_nonlinear_fun(
self, derivative_operator: Complex[Array, " D ... (N//2)+1"]
) -> ZeroNonlinearFun:
return ZeroNonlinearFun(self.num_spatial_dims, self.num_points)

def step_fourier(
self, u_hat: Complex[Array, " 2 ... (N//2)+1"]
) -> Complex[Array, " 2 ... (N//2)+1"]:
"""
Advance the state by one timestep in Fourier space.

Overrides the base method to wrap the ETDRK step with the
forward/inverse diagonalization transforms.
"""
waves_hat = self._forward_transform(u_hat)
waves_hat_next = super().step_fourier(waves_hat)
u_hat_next = self._inverse_transform(waves_hat_next)

# At k=0 with m>0, the system is still diagonalizable (ω(0) = m ≠ 0),
# so no special DC correction is needed. However, if m=0 we fall back
# to the wave equation DC behavior.
if self.mass == 0.0:
h_dc_idx = (0,) + (0,) * self.num_spatial_dims
v_dc_idx = (1,) + (0,) * self.num_spatial_dims
u_hat_next = u_hat_next.at[h_dc_idx].add(self.dt * u_hat[v_dc_idx])

return u_hat_next
1 change: 1 addition & 0 deletions tests/test_builtin_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def test_instantiate():
ex.stepper.Dispersion,
ex.stepper.HyperDiffusion,
ex.stepper.Wave,
ex.stepper.KleinGordon,
ex.stepper.Burgers,
ex.stepper.KuramotoSivashinsky,
ex.stepper.KuramotoSivashinskyConservative,
Expand Down
173 changes: 173 additions & 0 deletions tests/test_klein_gordon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import jax.numpy as jnp
import pytest

import exponax as ex
from exponax.stepper import KleinGordon, Wave

L = 2 * jnp.pi
PI = jnp.pi


# ===========================================================================
# Instantiation
# ===========================================================================


class TestKleinGordonInstantiation:
@pytest.mark.parametrize("num_spatial_dims", [1, 2, 3])
def test_instantiate(self, num_spatial_dims):
stepper = KleinGordon(num_spatial_dims, 10.0, 25, 0.1)
assert stepper.num_channels == 2
assert stepper.num_spatial_dims == num_spatial_dims

@pytest.mark.parametrize("num_spatial_dims", [1, 2, 3])
def test_output_shape(self, num_spatial_dims):
N = 16
stepper = KleinGordon(num_spatial_dims, L, N, 0.01)
u0 = jnp.zeros((2,) + (N,) * num_spatial_dims)
u1 = stepper(u0)
assert u1.shape == u0.shape
assert jnp.all(jnp.isfinite(u1))

def test_wrong_input_shape_raises(self):
stepper = KleinGordon(1, L, 32, 0.01)
with pytest.raises(ValueError, match="Expected shape"):
stepper(jnp.zeros((1, 32))) # needs 2 channels

def test_default_params(self):
stepper = KleinGordon(1, L, 32, 0.01)
assert stepper.speed_of_sound == 1.0
assert stepper.mass == 1.0

def test_custom_params(self):
stepper = KleinGordon(1, L, 32, 0.01, speed_of_sound=2.0, mass=3.0)
assert stepper.speed_of_sound == 2.0
assert stepper.mass == 3.0


# ===========================================================================
# mass=0 should recover Wave equation
# ===========================================================================


class TestKleinGordonRecoverWave:
"""When mass=0, KleinGordon must produce identical results to Wave."""

@pytest.mark.parametrize("num_spatial_dims", [1, 2])
def test_mass_zero_matches_wave(self, num_spatial_dims):
N, dt, c = 32, 0.01, 1.5
kg = KleinGordon(
num_spatial_dims, L, N, dt, speed_of_sound=c, mass=0.0
)
wave = Wave(num_spatial_dims, L, N, dt, speed_of_sound=c)

x = jnp.linspace(0, L, N, endpoint=False)
if num_spatial_dims == 1:
h0 = jnp.cos(2 * x)[None]
else:
h0 = jnp.cos(2 * x)[None, :, None] * jnp.ones((1, N, N))
v0 = jnp.zeros_like(h0)
u0 = jnp.concatenate([h0, v0], axis=0)

u_kg = u0
u_wave = u0
for _ in range(20):
u_kg = kg(u_kg)
u_wave = wave(u_wave)

assert u_kg == pytest.approx(u_wave, abs=1e-5)

def test_mass_zero_matches_wave_multi_step(self):
"""Longer evolution to catch accumulation drift."""
N, dt, c = 64, 0.005, 1.0
kg = KleinGordon(1, L, N, dt, speed_of_sound=c, mass=0.0)
wave = Wave(1, L, N, dt, speed_of_sound=c)

x = jnp.linspace(0, L, N, endpoint=False)
h0 = (jnp.cos(x) + 0.5 * jnp.cos(3 * x))[None]
v0 = jnp.zeros_like(h0)
u0 = jnp.concatenate([h0, v0], axis=0)

u_kg = u0
u_wave = u0
for _ in range(100):
u_kg = kg(u_kg)
u_wave = wave(u_wave)

assert u_kg == pytest.approx(u_wave, abs=1e-4)


# ===========================================================================
# Analytical correctness — 1D Klein-Gordon standing mode
# ===========================================================================


class TestKleinGordonAnalytical1D:
"""For h(x,0) = cos(k0 x), v(x,0) = 0:
ω = sqrt(c²k0² + m²)
h(x,t) = cos(k0 x) cos(ω t)
v(x,t) = -ω cos(k0 x) sin(ω t)
"""

def _make_stepper_and_ic(self, k0, c=1.0, m=1.0, N=64, dt=0.01):
stepper = KleinGordon(1, L, N, dt, speed_of_sound=c, mass=m)
x = jnp.linspace(0, L, N, endpoint=False)
h0 = jnp.cos(k0 * x)[None]
v0 = jnp.zeros_like(h0)
u0 = jnp.concatenate([h0, v0], axis=0)
omega = jnp.sqrt(c**2 * k0**2 + m**2)
return stepper, x, u0, float(omega)

@pytest.mark.parametrize("k0", [1, 2, 3, 5])
def test_single_mode(self, k0):
c, m, N, dt = 1.0, 2.0, 64, 0.01
stepper, x, u0, omega = self._make_stepper_and_ic(k0, c, m, N, dt)

n_steps = 10
u = u0
for _ in range(n_steps):
u = stepper(u)
t = n_steps * dt

h_exact = jnp.cos(k0 * x) * jnp.cos(omega * t)
v_exact = -omega * jnp.cos(k0 * x) * jnp.sin(omega * t)

assert u[0] == pytest.approx(h_exact, abs=1e-4)
assert u[1] == pytest.approx(v_exact, abs=1e-3)

def test_mass_gap(self):
"""With k0=0 (uniform mode), oscillation is at ω = m (the mass gap)."""
m, N, dt = 3.0, 32, 0.01
stepper = KleinGordon(1, L, N, dt, speed_of_sound=1.0, mass=m)

h0 = jnp.ones((1, N)) # k=0 mode
v0 = jnp.zeros_like(h0)
u0 = jnp.concatenate([h0, v0], axis=0)

n_steps = 20
u = u0
for _ in range(n_steps):
u = stepper(u)
t = n_steps * dt

h_exact = jnp.cos(m * t) * jnp.ones(N)
assert u[0] == pytest.approx(h_exact, abs=1e-4)

def test_energy_bounded(self):
"""Total energy should be conserved (bounded) over many steps."""
k0, c, m, N, dt = 3, 1.0, 2.0, 64, 0.005
stepper, x, u0, omega = self._make_stepper_and_ic(k0, c, m, N, dt)

def energy(u):
h, v = u[0], u[1]
# KE + gradient PE + mass PE
return jnp.sum(v**2 + c**2 * jnp.abs(jnp.fft.rfft(h))**2 + m**2 * h**2)

e0 = energy(u0)
u = u0
for _ in range(200):
u = stepper(u)
e_final = energy(u)

# Spectral solver should conserve energy to machine precision
assert e_final == pytest.approx(float(e0), rel=1e-3)
Loading