Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 66 additions & 1 deletion models/src/anemoi/models/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,69 @@ def predict_step(
...


__all__ = ["ModelInterface"]
@runtime_checkable
class DiffusionModelInterface(ModelInterface, Protocol):
    """Interface for models that support diffusion tasks.

    Structural (``Protocol``) interface: any model providing these methods
    satisfies it, and ``@runtime_checkable`` allows ``isinstance`` checks
    against it at runtime (method presence only, not signatures).
    """

    def get_diffusion_parameters(self) -> tuple[float, float, float]:
        """Return ``(sigma_max, sigma_min, sigma_data)`` for diffusion training."""
        ...

    def forward_with_preconditioning(
        self,
        x: dict[str, torch.Tensor],
        y_noised: dict[str, torch.Tensor],
        sigma: dict[str, torch.Tensor],
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Run the diffusion forward pass with model-specific preconditioning.

        All three dict arguments are keyed the same way (per dataset name):
        ``x`` holds conditioning inputs, ``y_noised`` the noised targets and
        ``sigma`` the noise levels. Extra keyword arguments are forwarded to
        the implementation.
        """
        ...

    def apply_imputer_inverse(self, dataset_name: str, x: torch.Tensor) -> torch.Tensor:
        """Map output-space tensors back through any inverse imputation logic.

        ``dataset_name`` selects which dataset's imputer configuration is
        applied to ``x``.
        """
        ...

    def apply_reference_state_truncation(
        self,
        x: dict[str, torch.Tensor],
        grid_shard_shapes,
        model_comm_group: Optional[ProcessGroup] = None,
    ) -> dict[str, torch.Tensor]:
        """Prepare reference states used by diffusion tasks.

        Parameters
        ----------
        x : dict[str, torch.Tensor]
            Input states, keyed by dataset name.
        grid_shard_shapes
            Per-rank grid shard shapes (untyped here; presumably matches the
            backbone's sharding layout — verify against implementations).
        model_comm_group : Optional[ProcessGroup]
            Model-parallel process group when the grid is sharded across ranks.
        """
        ...


@runtime_checkable
class DiffusionTendencyModelInterface(DiffusionModelInterface, Protocol):
    """Interface for diffusion models that predict tendencies.

    Extends :class:`DiffusionModelInterface` with per-step conversion
    between full states and tendency targets.
    """

    def get_tendency_processors(self, dataset_name: str) -> tuple[object, object]:
        """Return the pre/post tendency processors for one dataset.

        The pair is ``(tendency_pre_processor, tendency_post_processor)``,
        suitable for passing to :meth:`compute_tendency_step` and
        :meth:`add_tendency_to_state_step` respectively.
        """
        ...

    def compute_tendency_step(
        self,
        dataset_name: str,
        y_step: torch.Tensor,
        x_ref_step: torch.Tensor,
        tendency_pre_processor: object,
    ) -> torch.Tensor:
        """Convert one output step into a tendency target.

        ``y_step`` is a single target step and ``x_ref_step`` the matching
        reference state step; the exact tendency transform is defined by the
        implementation.
        """
        ...

    def add_tendency_to_state_step(
        self,
        dataset_name: str,
        x_ref_step: torch.Tensor,
        tendency_step: torch.Tensor,
        tendency_post_processor: object,
    ) -> torch.Tensor:
        """Reconstruct one state step from a tendency prediction.

        Inverse of :meth:`compute_tendency_step`: applies the predicted
        ``tendency_step`` to ``x_ref_step`` to recover a full state step.
        """
        ...


# Explicit public API of this module (consumed by `from ... import *` and docs).
__all__ = ["ModelInterface", "DiffusionModelInterface", "DiffusionTendencyModelInterface"]
4 changes: 4 additions & 0 deletions models/src/anemoi/models/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from .anemoi_model import AnemoiDiffusionModel
from .anemoi_model import AnemoiDiffusionTendencyModel
from .anemoi_model import AnemoiModel
from .base import BaseGraphModel
from .encoder_processor_decoder import AnemoiModelEncProcDec
Expand All @@ -22,6 +24,8 @@

__all__ = [
"AnemoiModel",
"AnemoiDiffusionModel",
"AnemoiDiffusionTendencyModel",
"BaseGraphModel",
"NaiveModel",
"AnemoiModelEncProcDec",
Expand Down
68 changes: 67 additions & 1 deletion models/src/anemoi/models/models/anemoi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
self,
*,
model_config: DotDict,
graph_data: dict[str, HeteroData],
graph_data: HeteroData,
statistics: dict,
data_indices: dict,
metadata: dict,
Expand Down Expand Up @@ -89,3 +89,69 @@ def predict_step(
y = self.forward(x, model_comm_group=model_comm_group, **kwargs)
y = self.post_process(y)
return y


class AnemoiDiffusionModel(AnemoiModel):
    """Anemoi wrapper for diffusion-capable backbones.

    Thin delegation layer: every method forwards to the corresponding
    attribute or hook on ``self.backbone``.
    """

    def get_diffusion_parameters(self) -> tuple[float, float, float]:
        """Return the backbone's ``(sigma_max, sigma_min, sigma_data)``."""
        backbone = self.backbone
        return (backbone.sigma_max, backbone.sigma_min, backbone.sigma_data)

    def forward_with_preconditioning(
        self,
        x: dict[str, torch.Tensor],
        y_noised: dict[str, torch.Tensor],
        sigma: dict[str, torch.Tensor],
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Run the backbone's preconditioned diffusion forward pass."""
        fwd = self.backbone.fwd_with_preconditioning
        return fwd(x, y_noised, sigma, **kwargs)

    def apply_imputer_inverse(self, dataset_name: str, x: torch.Tensor) -> torch.Tensor:
        """Undo imputation on ``x`` using this model's post-processors for one dataset."""
        post_processors = self.post_processors
        return self.backbone._apply_imputer_inverse(post_processors, dataset_name, x)

    def apply_reference_state_truncation(
        self,
        x: dict[str, torch.Tensor],
        grid_shard_shapes,
        model_comm_group: Optional[ProcessGroup] = None,
    ) -> dict[str, torch.Tensor]:
        """Prepare truncated reference states via the backbone."""
        truncate = self.backbone.apply_reference_state_truncation
        return truncate(x, grid_shard_shapes, model_comm_group)


class AnemoiDiffusionTendencyModel(AnemoiDiffusionModel):
    """Anemoi wrapper for diffusion backbones that predict tendencies."""

    def get_tendency_processors(self, dataset_name: str) -> tuple[object, object]:
        """Return the ``(pre, post)`` tendency processors for ``dataset_name``."""
        pre = self.pre_processors_tendencies[dataset_name]
        post = self.post_processors_tendencies[dataset_name]
        return pre, post

    def compute_tendency_step(
        self,
        dataset_name: str,
        y_step: torch.Tensor,
        x_ref_step: torch.Tensor,
        tendency_pre_processor: object,
    ) -> torch.Tensor:
        """Convert one output step into a tendency target via the backbone."""
        name = dataset_name
        # Backbone API is dict-per-dataset; wrap the single step accordingly
        # and unwrap the matching entry from the result.
        tendencies = self.backbone.compute_tendency(
            {name: y_step},
            {name: x_ref_step},
            {name: self.pre_processors[name]},
            {name: tendency_pre_processor},
            input_post_processor={name: self.post_processors[name]},
            skip_imputation=True,
        )
        return tendencies[name]

    def add_tendency_to_state_step(
        self,
        dataset_name: str,
        x_ref_step: torch.Tensor,
        tendency_step: torch.Tensor,
        tendency_post_processor: object,
    ) -> torch.Tensor:
        """Reconstruct one state step from a tendency prediction via the backbone."""
        name = dataset_name
        states = self.backbone.add_tendency_to_state(
            {name: x_ref_step},
            {name: tendency_step},
            {name: self.post_processors[name]},
            {name: tendency_post_processor},
            output_pre_processor={name: self.pre_processors[name]},
            skip_imputation=True,
        )
        return states[name]
15 changes: 13 additions & 2 deletions models/src/anemoi/models/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@
LOGGER = logging.getLogger(__name__)


def _get_backbone_config(model_config: DotDict) -> DotDict:
"""Return the backbone config."""
backbone_config = model_config.model.get("backbone")
if backbone_config is not None:
return DotDict(backbone_config)

msg = "Expected model config to define `model.backbone`."
raise KeyError(msg)


class BaseGraphModel(nn.Module):
"""Message passing graph neural network."""

Expand Down Expand Up @@ -63,12 +73,13 @@ def __init__(
self.dataset_names = list(data_indices.keys())

model_config = DotDict(model_config)
self._graph_name_hidden = model_config.model.model.hidden_nodes_name
backbone_config = _get_backbone_config(model_config)
self._graph_name_hidden = backbone_config.hidden_nodes_name

self.n_step_input = model_config.training.multistep_input
self.n_step_output = model_config.training.multistep_output
self.num_channels = model_config.model.num_channels
self.latent_skip = model_config.model.model.latent_skip
self.latent_skip = backbone_config.latent_skip

trainable_parameters = broadcast_config_keys(
model_config.model.trainable_parameters,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(

model_config_local = DotDict(model_config)

diffusion_config = model_config_local.model.nn.diffusion
diffusion_config = model_config_local.model.backbone.diffusion
self.noise_channels = diffusion_config.noise_channels
self.noise_cond_dim = diffusion_config.noise_cond_dim
self.sigma_data = diffusion_config.sigma_data
Expand Down
4 changes: 3 additions & 1 deletion models/src/anemoi/models/models/hierarchical_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from anemoi.models.layers.graph import NamedNodesAttributes
from anemoi.models.layers.graph_provider import create_graph_provider
from anemoi.models.models import AnemoiModelAutoEncoder
from anemoi.models.models.base import _get_backbone_config
from anemoi.utils.config import DotDict


Expand Down Expand Up @@ -53,7 +54,8 @@ def __init__(
self.statistics = statistics

model_config = DotDict(model_config)
self._graph_name_hidden = model_config.model.model.hidden_nodes_name
backbone_config = _get_backbone_config(model_config)
self._graph_name_hidden = backbone_config.hidden_nodes_name

self.n_step_input = model_config.training.multistep_input
self.n_step_output = model_config.training.multistep_output
Expand Down
6 changes: 4 additions & 2 deletions models/src/anemoi/models/models/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,11 @@ def forward(self, x: dict[str, Tensor], **_) -> dict[str, Tensor]:
out = {}
for name, x_ds in x.items():
bs, t, ens, grid, nv = x_ds.shape
x_flat = x_ds.reshape(bs * grid, t * nv)
x_flat = x_ds.permute(0, 2, 3, 1, 4).reshape(bs * ens * grid, t * nv)
y_flat = self.linear(x_flat)
out[name] = y_flat.to(x_ds.dtype).reshape(bs, self.n_step_output, ens, grid, self._n_output)
out[name] = (
y_flat.to(x_ds.dtype).reshape(bs, ens, grid, self.n_step_output, self._n_output).permute(0, 3, 1, 2, 4)
)
return out

def predict_step(
Expand Down
32 changes: 27 additions & 5 deletions models/tests/models/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# (C) Copyright 2024 Anemoi contributors.
# (C) Copyright 2026 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
Expand All @@ -7,10 +7,32 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
import torch
from omegaconf import OmegaConf

def test_models():
pass
from anemoi.models.models.base import _get_backbone_config
from anemoi.models.models.naive import NaiveModel


if __name__ == "__main__":
test_models()
def test_get_backbone_config_reads_backbone_layout() -> None:
    """The helper exposes the `model.backbone` section with attribute access."""
    layout = {"model": {"backbone": {"hidden_nodes_name": "hidden", "latent_skip": True}}}
    backbone = _get_backbone_config(OmegaConf.create(layout))

    assert backbone.hidden_nodes_name == "hidden"
    assert backbone.latent_skip is True


def test_get_backbone_config_requires_backbone() -> None:
    """A config missing `model.backbone` is rejected with a KeyError."""
    config_without_backbone = OmegaConf.create({"model": {}})

    with pytest.raises(KeyError, match="model.backbone"):
        _get_backbone_config(config_without_backbone)


def test_naive_model_preserves_ensemble_dimension() -> None:
    """Output keeps (batch, step, ens, grid) and only remaps the variable dim."""
    batch, steps_in, ens, grid, n_vars = 4, 2, 5, 7, 2
    model = NaiveModel(n_input=n_vars, n_output=3, n_step_input=steps_in, n_step_output=1)
    inputs = {"data": torch.randn(batch, steps_in, ens, grid, n_vars)}

    outputs = model(inputs)

    assert outputs["data"].shape == (batch, 1, ens, grid, 3)
Loading
Loading