diff --git a/examples/experimental/01_01_shallow_water_equation.ipynb b/examples/experimental/01_01_shallow_water_equation.ipynb index b8c42df..443bde6 100644 --- a/examples/experimental/01_01_shallow_water_equation.ipynb +++ b/examples/experimental/01_01_shallow_water_equation.ipynb @@ -70,7 +70,18 @@ "\n", "Main difference: **geometry**. PDEArena production data is generated on the sphere (SpeedyWeather/Julia), while this notebook uses a periodic Cartesian beta-plane box.\n", "\n", - "## Why intersting/useful\n", + "## Parameters\n", + "\n", + "The simulator supports optional parameters in `parameters_range` (in addition to required `amp`):\n", + "\n", + "- **`amp`** (required): initial-condition amplitude scale.\n", + "- **`H`**: mean layer depth (default 1.0 if omitted).\n", + "- **`drag`**: linear drag coefficient (default 2e-3).\n", + "- **`nu`**: Laplacian viscosity (default 5e-4).\n", + "\n", + "Use any subset; omitted parameters fall back to constructor defaults. Parameter order in `parameters_range` does not matter (parsing is by name).\n", + "\n", + "## Why interesting/useful\n", "\n", "- Produces rich vortex-dominated trajectories with physically interpretable channels.\n", "- Closer to weather-like dynamics than simple advection-diffusion benchmarks.\n", @@ -87,6 +98,12 @@ "from IPython.display import HTML\n", "\n", "from autosim.experimental.simulations import ShallowWater2D\n", + "from autosim.experimental.simulations.shallow_water import (\n", + " DEFAULT_AMP_RANGE,\n", + " DEFAULT_DRAG_RANGE,\n", + " DEFAULT_H_RANGE,\n", + " DEFAULT_NU_RANGE,\n", + ")\n", "from autosim.utils import plot_spatiotemporal_video\n" ] }, @@ -97,6 +114,7 @@ "outputs": [], "source": [ "# Interactive preview config: keep CFL fixed for stability.\n", + "# Use amp only (default), or add H, drag, nu for multi-parameter studies.\n", "sim = ShallowWater2D(\n", " return_timeseries=True,\n", " nx=64,\n", @@ -108,10 +126,15 @@ " cfl=0.12,\n", " log_level=\"warning\",\n", " dtype=torch.float32,\n", - " parameters_range={\"amp\": (0.05, 0.2)},\n", + " parameters_range={\n", + " \"amp\": DEFAULT_AMP_RANGE,\n", + " \"H\": DEFAULT_H_RANGE,\n", + " \"drag\": DEFAULT_DRAG_RANGE,\n", + " \"nu\": DEFAULT_NU_RANGE,\n", + " },\n", ")\n", "\n", - "batch = sim.forward_samples_spatiotemporal(n=1)\n", + "batch = sim.forward_samples_spatiotemporal(n=3)\n", "params = batch[\"constant_scalars\"]\n", "outputs = batch[\"data\"]\n" ] @@ -124,7 +147,10 @@ "source": [ "print(\"constant_scalars shape:\", params.shape)\n", "print(\"data shape:\", outputs.shape)\n", - "print(\"sample params (trajectory 0):\", {\"amp\": float(params[0, 0])})\n" + "sample_params = {\n", + " name: float(params[0, sim.get_parameter_idx(name)]) for name in sim.param_names\n", + "}\n", + "print(\"sample params (trajectory 0):\", sample_params)\n" ] }, { @@ -135,8 +161,9 @@ "source": [ "anim = plot_spatiotemporal_video(\n", " outputs[:, :100],\n", - " batch_idx=0,\n", + " batch_idx=2,\n", " channel_names=sim.output_names,\n", + " preserve_aspect=True,\n", ")\n", "\n", "HTML(anim.to_jshtml())" @@ -155,6 +182,12 @@ "metadata": {}, "outputs": [], "source": [ + "from autosim.experimental.simulations.shallow_water import (\n", + " DEFAULT_AMP_RANGE,\n", + " DEFAULT_DRAG_RANGE,\n", + " DEFAULT_H_RANGE,\n", + " DEFAULT_NU_RANGE,\n", + ")\n", "from autosim.utils import generate_output_data\n", "\n", "sim_train = ShallowWater2D(\n", @@ -167,7 +200,12 @@ " T=10.0,\n", " dt_save=1.0,\n", " cfl=0.12,\n", - " parameters_range={\"amp\": (0.05, 0.2)},\n", + " parameters_range={\n", + " \"amp\": DEFAULT_AMP_RANGE,\n", + " \"H\": DEFAULT_H_RANGE,\n", + " \"drag\": DEFAULT_DRAG_RANGE,\n", + " \"nu\": DEFAULT_NU_RANGE,\n", + " },\n", ")\n", "\n", "outputs_data = generate_output_data(sim_train)\n" @@ -195,7 +233,7 @@ ], "metadata": { "kernelspec": { - "display_name": "autosim (3.11.14)", + "display_name": ".venv (3.11.15)", "language": "python", "name": "python3" }, @@ -209,7 +247,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.14" + "version": "3.11.15" } }, "nbformat": 4, diff --git a/src/autosim/configs/simulator/shallow_water2d.yaml b/src/autosim/configs/simulator/shallow_water2d.yaml index 61014b6..96aa265 100644 --- a/src/autosim/configs/simulator/shallow_water2d.yaml +++ b/src/autosim/configs/simulator/shallow_water2d.yaml @@ -10,7 +10,7 @@ dt_save: 0.2 skip_nt: 50 cfl: 0.12 g: 9.81 -H: 1.0 +h_mean: 1.0 nu: 0.0005 drag: 0.002 parameters_range: diff --git a/src/autosim/configs/simulator/shallow_water2d_4.yaml b/src/autosim/configs/simulator/shallow_water2d_4.yaml new file mode 100644 index 0000000..f368f9d --- /dev/null +++ b/src/autosim/configs/simulator/shallow_water2d_4.yaml @@ -0,0 +1,20 @@ +_target_: autosim.experimental.simulations.ShallowWater2D +return_timeseries: true +log_level: warning +nx: 64 +ny: 64 +Lx: 64.0 +Ly: 64.0 +T: 74.0 +dt_save: 0.2 +skip_nt: 50 +cfl: 0.12 +g: 9.81 +h_mean: 1.0 +nu: 0.0005 +drag: 0.002 +parameters_range: + amp: [0.07, 0.14] + h_mean: [0.7, 1.5] + drag: [1e-3, 4e-3] + nu: [2e-4, 8e-4] diff --git a/src/autosim/experimental/simulations/shallow_water.py b/src/autosim/experimental/simulations/shallow_water.py index 40eab8f..b494df0 100644 --- a/src/autosim/experimental/simulations/shallow_water.py +++ b/src/autosim/experimental/simulations/shallow_water.py @@ -7,9 +7,51 @@ from autosim.simulations.base import SpatioTemporalSimulator from autosim.types import TensorLike +# Default param ranges when not overridden (amp required; h_mean, drag, nu optional). +DEFAULT_AMP_RANGE: tuple[float, float] = (0.05, 0.14) +DEFAULT_H_MEAN_RANGE: tuple[float, float] = (0.7, 1.5) +DEFAULT_DRAG_RANGE: tuple[float, float] = (1e-3, 4e-3) +DEFAULT_NU_RANGE: tuple[float, float] = (2e-4, 8e-4) + +# IC and solver tuning (used in simulate_swe_2d). +U_SCALE = 0.5 # streamfunction amplitude scale for random component +JET_AMP_FRAC = 0.8 # jet speed ~ amp * JET_AMP_FRAC +PERT_LAT_FRAC = 0.65 # wave-6 perturbation center y/Ly +PERT_WIDTH_FRAC = 0.10 # Gaussian width y/Ly +WAVE_ZONAL_MODE = 6 # zonal wavenumber for mid-lat perturbation +N_JET_MODES = 4 # Fourier modes per column for jet +N_HYPERVISC = 4 # hyperviscosity exponent +K_CUT_FACTOR = 6 # k_cut = k_min * min(nx,ny) // K_CUT_FACTOR +H_MIN_CLIP = 1e-4 +H_MAX_CLIP = 100.0 +UV_ABS_CLIP = 100.0 +SATURATION_THRESHOLD = 0.01 # stop if this fraction of grid hits clip bounds +EPS = 1e-10 # small constant for safe div/norms +MIN_WAVE_SPEED_CFL = 1e-8 # floor for CFL dt; keep conservative to avoid instability + class ShallowWater2D(SpatioTemporalSimulator): - """Full 2D shallow-water simulator with prognostic [h, u, v].""" + """Full 2D shallow-water simulator with prognostic [h, u, v]. + + Parameters + ---------- + parameters_range : dict, optional + Input parameter (min, max) ranges. Supported keys: + - ``amp`` (required): initial-condition amplitude scale. + - ``h_mean``: mean layer depth (scalar) around which spatial + variations are generated (default 1.0 if omitted). + - ``drag``: linear drag coefficient (default 2e-3). + - ``nu``: Laplacian viscosity (default 5e-4). + If None, uses ``{"amp": (0.05, 0.14)}`` only. + output_names, return_timeseries, log_level + Passed to base. Default outputs: ["h", "u", "v"]. + nx, ny, Lx, Ly, T, dt_save, skip_nt, cfl + Grid, domain, time and CFL settings. + g, h_mean, nu, drag + Physics constants (used when not in parameters_range). + dtype + torch.float32 or torch.float64. + """ def __init__( self, @@ -26,13 +68,13 @@ def __init__( skip_nt: int = 0, cfl: float = 0.12, g: float = 9.81, - H: float = 1.0, + h_mean: float = 1.0, nu: float = 5e-4, drag: float = 2e-3, dtype: torch.dtype = torch.float64, ) -> None: if parameters_range is None: - parameters_range = {"amp": (0.05, 0.14)} + parameters_range = {"amp": DEFAULT_AMP_RANGE} if output_names is None: output_names = ["h", "u", "v"] @@ -50,14 +92,38 @@ def __init__( self.skip_nt = skip_nt self.cfl = cfl self.g = g - self.H = H + self.h_mean = h_mean self.nu = nu self.drag = drag self.dtype = dtype def _forward(self, x: TensorLike) -> TensorLike: - assert x.shape[0] == 1, "Simulator._forward expects a single input" - amp = float(x[0, 0].item()) + if x.shape[0] != 1: + msg = "Simulator._forward expects a single input (batch size 1)" + raise ValueError(msg) + if x.shape[1] != self.in_dim: + msg = ( + f"Input dim {x.shape[1]} does not match " + f"parameters_range length {self.in_dim}" + ) + raise ValueError(msg) + # Parse by name so parameter order is irrelevant and optional params clear. + amp = float(x[0, self.get_parameter_idx("amp")].item()) + h_mean = ( + float(x[0, self.get_parameter_idx("h_mean")].item()) + if "h_mean" in self.param_names + else self.h_mean + ) + drag = ( + float(x[0, self.get_parameter_idx("drag")].item()) + if "drag" in self.param_names + else self.drag + ) + nu = ( + float(x[0, self.get_parameter_idx("nu")].item()) + if "nu" in self.param_names + else self.nu + ) y = simulate_swe_2d( amp=amp, @@ -71,9 +137,9 @@ def _forward(self, x: TensorLike) -> TensorLike: skip_nt=self.skip_nt, cfl=self.cfl, g=self.g, - H=self.H, - nu=self.nu, - drag=self.drag, + h_mean=h_mean, + nu=nu, + drag=drag, dtype=self.dtype, ) return y.flatten().unsqueeze(0) @@ -120,7 +186,7 @@ def simulate_swe_2d( # noqa: PLR0912, PLR0915 dt_save: float, cfl: float, g: float, - H: float, + h_mean: float, nu: float, drag: float, dtype: torch.dtype = torch.float64, @@ -142,7 +208,7 @@ def simulate_swe_2d( # noqa: PLR0912, PLR0915 dx = Lx / nx dy = Ly / ny - c = math.sqrt(g * H) + c = math.sqrt(g * h_mean) f0 = c / 8.0 beta = 0.5 * f0 / Ly f_grid = f0 + beta * (Y - 0.5 * Ly) @@ -157,10 +223,9 @@ def simulate_swe_2d( # noqa: PLR0912, PLR0915 # Hyperviscosity integrating-factor operator (same approach as barotropic solver). # Damps grid-scale modes in ~1 time unit; leaves large-scale vortices untouched. - n_hyp = 4 k_max = math.pi * max(nx / Lx, ny / Ly) - nu_h = 1.0 / k_max ** (2 * n_hyp) - hyp_op = -nu_h * K2**n_hyp + nu_h = 1.0 / k_max ** (2 * N_HYPERVISC) + hyp_op = -nu_h * K2**N_HYPERVISC def to_spec(field: torch.Tensor) -> torch.Tensor: return torch.fft.rfft2(field) @@ -182,12 +247,12 @@ def laplacian(field: torch.Tensor) -> torch.Tensor: # ------------------------------------------------------------------ # # Strategy: specify vorticity ζ (random large-scale + jet shear + # wave-6 perturbation), solve ∇²ψ = ζ spectrally, then derive - # u = -∂ψ/∂y, v = ∂ψ/∂x, h = H + (f0/g)·ψ + # u = -∂ψ/∂y, v = ∂ψ/∂x, h = h_mean + (f0/g)·ψ # This guarantees exact geostrophic balance at t=0 so no spurious # gravity-wave transients are excited. k_min = 2.0 * math.pi / max(Lx, Ly) - k_cut = k_min * min(nx, ny) // 6 + k_cut = k_min * min(nx, ny) // K_CUT_FACTOR # Component 1: random large-scale streamfunction (k^{-2} → E(k)~k^{-3}) rand_re = torch.randn(nx, ny // 2 + 1, dtype=dtype) @@ -200,39 +265,35 @@ def laplacian(field: torch.Tensor) -> torch.Tensor: ) psi_hat_rand[0, 0] = 0.0 psi_rand_phys = to_phys(psi_hat_rand) - U_scale = 0.5 - psi_norm = amp * U_scale * min(Lx, Ly) / (float(psi_rand_phys.std()) + 1e-10) + psi_norm = amp * U_SCALE * min(Lx, Ly) / (float(psi_rand_phys.std()) + EPS) psi_hat_rand = psi_hat_rand * psi_norm zeta_random = to_phys(-K2 * psi_hat_rand) # Component 2: per-column independent random zonal jet (PDEArena :random2 style) # Each longitude column i gets its own independent random Fourier coefficients # in y — matching PDEArena's truly per-column i.i.d. wind profiles. - n_modes = 4 - coeff = torch.randn(nx, n_modes, dtype=dtype) # [nx, n_modes], i.i.d. per column + coeff = torch.randn(nx, N_JET_MODES, dtype=dtype) # i.i.d. per column y_frac = Y / Ly # [nx, ny], values in [0, 1] u_jet_field = torch.stack( [ coeff[:, m].unsqueeze(1) * torch.sin((m + 1) * math.pi * y_frac) - for m in range(n_modes) + for m in range(N_JET_MODES) ], dim=0, ).sum(dim=0) # [nx, ny] - # Normalise so |u_jet| ~ amp * 0.8 regardless of random draw - jet_std = float(u_jet_field.std()) + 1e-10 - u_jet_field = u_jet_field * (amp * 0.8 / jet_std) + jet_std = float(u_jet_field.std()) + EPS + u_jet_field = u_jet_field * (amp * JET_AMP_FRAC / jet_std) zeta_jet = to_phys(-1j * Ky * to_spec(u_jet_field)) # Component 3: wave-6 Gaussian perturbation at mid-latitude - zeta_jet_scale = float(zeta_jet.std()) + 1e-10 - A_pert = max(amp * f0 * 0.8, zeta_jet_scale * 0.25) - m_wave = 6 - y_center = 0.65 * Ly - y_width = 0.10 * Ly + zeta_jet_scale = float(zeta_jet.std()) + EPS + A_pert = max(amp * f0 * JET_AMP_FRAC, zeta_jet_scale * 0.25) + y_center = PERT_LAT_FRAC * Ly + y_width = PERT_WIDTH_FRAC * Ly pert_phase = float(torch.rand(1)) * 2.0 * math.pi zeta_pert = ( A_pert - * torch.cos(m_wave * 2.0 * math.pi * X / Lx + pert_phase) + * torch.cos(WAVE_ZONAL_MODE * 2.0 * math.pi * X / Lx + pert_phase) * torch.exp(-0.5 * ((Y - y_center) / y_width) ** 2) ) @@ -247,12 +308,12 @@ def laplacian(field: torch.Tensor) -> torch.Tensor: u0 = to_phys(-1j * Ky * psi_h) # u = -∂ψ/∂y v0 = to_phys(1j * Kx * psi_h) # v = ∂ψ/∂x - h0 = (H + (f0 / g) * psi0).clamp(min=0.5 * H) # geostrophic balance + h0 = (h_mean + (f0 / g) * psi0).clamp(min=0.5 * h_mean) # geostrophic balance def rhs( h: torch.Tensor, u: torch.Tensor, v: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - h_safe = h.clamp(min=1e-4) + h_safe = h.clamp(min=H_MIN_CLIP) # Reuse spectra per field to avoid repeated FFTs in each RHS evaluation. u_h = to_spec(u) @@ -286,33 +347,48 @@ def rhs( dhdt = -div_hu return dhdt, dudt, dvdt - def output(h: torch.Tensor, u: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - h_out = torch.nan_to_num(h, nan=H, posinf=100.0, neginf=1e-4).clamp( - min=1e-4, max=100.0 + def rk4_step( + h: torch.Tensor, u: torch.Tensor, v: torch.Tensor, dt: float + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + k1_h, k1_u, k1_v = rhs(h, u, v) + k2_h, k2_u, k2_v = rhs( + h + 0.5 * dt * k1_h, u + 0.5 * dt * k1_u, v + 0.5 * dt * k1_v ) - u_out = torch.nan_to_num(u, nan=0.0, posinf=100.0, neginf=-100.0).clamp( - min=-100.0, max=100.0 + k3_h, k3_u, k3_v = rhs( + h + 0.5 * dt * k2_h, u + 0.5 * dt * k2_u, v + 0.5 * dt * k2_v ) - v_out = torch.nan_to_num(v, nan=0.0, posinf=100.0, neginf=-100.0).clamp( - min=-100.0, max=100.0 + k4_h, k4_u, k4_v = rhs(h + dt * k3_h, u + dt * k3_u, v + dt * k3_v) + return ( + h + (dt / 6.0) * (k1_h + 2.0 * k2_h + 2.0 * k3_h + k4_h), + u + (dt / 6.0) * (k1_u + 2.0 * k2_u + 2.0 * k3_u + k4_u), + v + (dt / 6.0) * (k1_v + 2.0 * k2_v + 2.0 * k3_v + k4_v), ) + + def output(h: torch.Tensor, u: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + h_out = torch.nan_to_num( + h, + nan=h_mean, + posinf=H_MAX_CLIP, + neginf=H_MIN_CLIP, + ).clamp(min=H_MIN_CLIP, max=H_MAX_CLIP) + u_out = torch.nan_to_num( + u, nan=0.0, posinf=UV_ABS_CLIP, neginf=-UV_ABS_CLIP + ).clamp(min=-UV_ABS_CLIP, max=UV_ABS_CLIP) + v_out = torch.nan_to_num( + v, nan=0.0, posinf=UV_ABS_CLIP, neginf=-UV_ABS_CLIP + ).clamp(min=-UV_ABS_CLIP, max=UV_ABS_CLIP) return torch.stack([h_out.float(), u_out.float(), v_out.float()], dim=-1) h = h0 u = u0 v = v0 - h_min_bound = 1e-4 - h_max_bound = 100.0 - uv_abs_bound = 100.0 - saturation_frac_threshold = 0.01 - def _saturation_fraction( h_curr: torch.Tensor, u_curr: torch.Tensor, v_curr: torch.Tensor ) -> float: - h_sat = ((h_curr <= h_min_bound) | (h_curr >= h_max_bound)).float().mean() - u_sat = (u_curr.abs() >= uv_abs_bound).float().mean() - v_sat = (v_curr.abs() >= uv_abs_bound).float().mean() + h_sat = ((h_curr <= H_MIN_CLIP) | (h_curr >= H_MAX_CLIP)).float().mean() + u_sat = (u_curr.abs() >= UV_ABS_CLIP).float().mean() + v_sat = (v_curr.abs() >= UV_ABS_CLIP).float().mean() return float(torch.maximum(torch.maximum(h_sat, u_sat), v_sat).item()) history: list[torch.Tensor] = [] @@ -330,7 +406,7 @@ def _saturation_fraction( ): failure_reason = "non-finite state encountered" break - if _saturation_fraction(h, u, v) >= saturation_frac_threshold: + if _saturation_fraction(h, u, v) >= SATURATION_THRESHOLD: failure_reason = "state saturated at clipping bounds" break @@ -342,10 +418,10 @@ def _saturation_fraction( if t >= T - 1e-10: break - c_now = torch.sqrt(g * h.clamp(min=1e-4)) + c_now = torch.sqrt(g * h.clamp(min=H_MIN_CLIP)) speed_x = (u.abs() + c_now).max().item() speed_y = (v.abs() + c_now).max().item() - max_speed = max(speed_x, speed_y, 1e-8) + max_speed = max(speed_x, speed_y, MIN_WAVE_SPEED_CFL) if not math.isfinite(max_speed): failure_reason = "non-finite wave speed" break @@ -353,43 +429,24 @@ def _saturation_fraction( dt = cfl * min(dx, dy) / max_speed dt = min(dt, next_save - t, T - t) dt = min(dt, 0.5 * dt_save) - dt = max(dt, 1e-10) - - k1_h, k1_u, k1_v = rhs(h, u, v) - k2_h, k2_u, k2_v = rhs( - h + 0.5 * dt * k1_h, - u + 0.5 * dt * k1_u, - v + 0.5 * dt * k1_v, - ) - k3_h, k3_u, k3_v = rhs( - h + 0.5 * dt * k2_h, - u + 0.5 * dt * k2_u, - v + 0.5 * dt * k2_v, - ) - k4_h, k4_u, k4_v = rhs( - h + dt * k3_h, - u + dt * k3_u, - v + dt * k3_v, - ) + dt = max(dt, EPS) - h = h + (dt / 6.0) * (k1_h + 2.0 * k2_h + 2.0 * k3_h + k4_h) - u = u + (dt / 6.0) * (k1_u + 2.0 * k2_u + 2.0 * k3_u + k4_u) - v = v + (dt / 6.0) * (k1_v + 2.0 * k2_v + 2.0 * k3_v + k4_v) + h, u, v = rk4_step(h, u, v, dt) # Apply hyperviscosity integrating factor to all fields (spectral filter). hyp_factor = torch.exp(hyp_op * dt) # real, shape [nx, ny//2+1] u = to_phys(to_spec(u) * hyp_factor) v = to_phys(to_spec(v) * hyp_factor) - h_mean = h.mean() - h_anom = h - h_mean + h_field_mean = h.mean() + h_anom = h - h_field_mean h_anom = to_phys(to_spec(h_anom) * hyp_factor) - h = (h_mean + h_anom).clamp(min=1e-4) + h = (h_field_mean + h_anom).clamp(min=H_MIN_CLIP) if ( torch.isfinite(h).all() and torch.isfinite(u).all() and torch.isfinite(v).all() ): - if _saturation_fraction(h, u, v) >= saturation_frac_threshold: + if _saturation_fraction(h, u, v) >= SATURATION_THRESHOLD: failure_reason = "state saturated at clipping bounds after step" break last_valid = output(h, u, v) diff --git a/tests/simulations/test_shallow_water.py b/tests/simulations/test_shallow_water.py index e9f19bb..26d8cb9 100644 --- a/tests/simulations/test_shallow_water.py +++ b/tests/simulations/test_shallow_water.py @@ -63,7 +63,7 @@ def test_full_swe_skip_nt_too_large_raises() -> None: dt_save=1.0, cfl=0.12, g=9.81, - H=1.0, + h_mean=1.0, nu=5e-4, drag=2e-3, skip_nt=2,