diff --git a/models/src/anemoi/models/interface/__init__.py b/models/src/anemoi/models/interface/__init__.py index 516f8b248f..fcf0fe8032 100644 --- a/models/src/anemoi/models/interface/__init__.py +++ b/models/src/anemoi/models/interface/__init__.py @@ -46,4 +46,69 @@ def predict_step( ... -__all__ = ["ModelInterface"] +@runtime_checkable +class DiffusionModelInterface(ModelInterface, Protocol): + """Interface for models that support diffusion tasks.""" + + def get_diffusion_parameters(self) -> tuple[float, float, float]: + """Return ``(sigma_max, sigma_min, sigma_data)`` for diffusion training.""" + ... + + def forward_with_preconditioning( + self, + x: dict[str, torch.Tensor], + y_noised: dict[str, torch.Tensor], + sigma: dict[str, torch.Tensor], + **kwargs, + ) -> dict[str, torch.Tensor]: + """Run the diffusion forward pass with model-specific preconditioning.""" + ... + + def apply_imputer_inverse(self, dataset_name: str, x: torch.Tensor) -> torch.Tensor: + """Map output-space tensors back through any inverse imputation logic.""" + ... + + def apply_reference_state_truncation( + self, + x: dict[str, torch.Tensor], + grid_shard_shapes, + model_comm_group: Optional[ProcessGroup] = None, + ) -> dict[str, torch.Tensor]: + """Prepare reference states used by diffusion tasks.""" + ... + + +@runtime_checkable +class DiffusionTendencyModelInterface(DiffusionModelInterface, Protocol): + """Interface for diffusion models that predict tendencies.""" + + def get_tendency_processors(self, dataset_name: str) -> tuple[object, object]: + """Return the pre/post tendency processors for one dataset.""" + ... + + def compute_tendency_step( + self, + dataset_name: str, + y_step: torch.Tensor, + x_ref_step: torch.Tensor, + tendency_pre_processor: object, + ) -> torch.Tensor: + """Convert one output step into a tendency target.""" + ... + + def add_tendency_to_state_step( + self, + dataset_name: str, + x_ref_step: torch.Tensor, + tendency_step: torch.Tensor, + tendency_post_processor: object, + ) -> torch.Tensor: + """Reconstruct one state step from a tendency prediction.""" + ... + + +__all__ = [ + "ModelInterface", + "DiffusionModelInterface", + "DiffusionTendencyModelInterface", +] diff --git a/models/src/anemoi/models/models/__init__.py b/models/src/anemoi/models/models/__init__.py index ce0f0fc2a7..8e0b3df79a 100644 --- a/models/src/anemoi/models/models/__init__.py +++ b/models/src/anemoi/models/models/__init__.py @@ -7,6 +7,8 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
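A minimal sketch of how the new runtime-checkable protocols above are consumed, mirroring the isinstance checks this diff adds in the diffusion tasks; `require_diffusion_model` is a hypothetical helper, not part of the change.

```python
from anemoi.models.interface import DiffusionModelInterface


def require_diffusion_model(model: object) -> None:
    # @runtime_checkable protocols verify method *presence* only, not
    # signatures or return types, so this remains a structural check.
    if not isinstance(model, DiffusionModelInterface):
        msg = "Expected a model implementing DiffusionModelInterface."
        raise TypeError(msg)
```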
+from .anemoi_model import AnemoiDiffusionModel +from .anemoi_model import AnemoiDiffusionTendencyModel from .anemoi_model import AnemoiModel from .base import BaseGraphModel from .encoder_processor_decoder import AnemoiModelEncProcDec @@ -22,6 +24,8 @@ __all__ = [ "AnemoiModel", + "AnemoiDiffusionModel", + "AnemoiDiffusionTendencyModel", "BaseGraphModel", "NaiveModel", "AnemoiModelEncProcDec", diff --git a/models/src/anemoi/models/models/anemoi_model.py b/models/src/anemoi/models/models/anemoi_model.py index 32f4b8289f..54410c1527 100644 --- a/models/src/anemoi/models/models/anemoi_model.py +++ b/models/src/anemoi/models/models/anemoi_model.py @@ -25,7 +25,7 @@ def __init__( self, *, model_config: DotDict, - graph_data: dict[str, HeteroData], + graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict, @@ -89,3 +89,69 @@ def predict_step( y = self.forward(x, model_comm_group=model_comm_group, **kwargs) y = self.post_process(y) return y + + +class AnemoiDiffusionModel(AnemoiModel): + """Anemoi wrapper for diffusion-capable backbones.""" + + def get_diffusion_parameters(self) -> tuple[float, float, float]: + return self.backbone.sigma_max, self.backbone.sigma_min, self.backbone.sigma_data + + def forward_with_preconditioning( + self, + x: dict[str, torch.Tensor], + y_noised: dict[str, torch.Tensor], + sigma: dict[str, torch.Tensor], + **kwargs, + ) -> dict[str, torch.Tensor]: + return self.backbone.fwd_with_preconditioning(x, y_noised, sigma, **kwargs) + + def apply_imputer_inverse(self, dataset_name: str, x: torch.Tensor) -> torch.Tensor: + return self.backbone._apply_imputer_inverse(self.post_processors, dataset_name, x) + + def apply_reference_state_truncation( + self, + x: dict[str, torch.Tensor], + grid_shard_shapes, + model_comm_group: Optional[ProcessGroup] = None, + ) -> dict[str, torch.Tensor]: + return self.backbone.apply_reference_state_truncation(x, grid_shard_shapes, model_comm_group) + + +class AnemoiDiffusionTendencyModel(AnemoiDiffusionModel): + """Anemoi wrapper for diffusion backbones that predict tendencies.""" + + def get_tendency_processors(self, dataset_name: str) -> tuple[object, object]: + return self.pre_processors_tendencies[dataset_name], self.post_processors_tendencies[dataset_name] + + def compute_tendency_step( + self, + dataset_name: str, + y_step: torch.Tensor, + x_ref_step: torch.Tensor, + tendency_pre_processor: object, + ) -> torch.Tensor: + return self.backbone.compute_tendency( + {dataset_name: y_step}, + {dataset_name: x_ref_step}, + {dataset_name: self.pre_processors[dataset_name]}, + {dataset_name: tendency_pre_processor}, + input_post_processor={dataset_name: self.post_processors[dataset_name]}, + skip_imputation=True, + )[dataset_name] + + def add_tendency_to_state_step( + self, + dataset_name: str, + x_ref_step: torch.Tensor, + tendency_step: torch.Tensor, + tendency_post_processor: object, + ) -> torch.Tensor: + return self.backbone.add_tendency_to_state( + {dataset_name: x_ref_step}, + {dataset_name: tendency_step}, + {dataset_name: self.post_processors[dataset_name]}, + {dataset_name: tendency_post_processor}, + output_pre_processor={dataset_name: self.pre_processors[dataset_name]}, + skip_imputation=True, + )[dataset_name] diff --git a/models/src/anemoi/models/models/base.py b/models/src/anemoi/models/models/base.py index 3de4d9430c..bd72edc0c9 100644 --- a/models/src/anemoi/models/models/base.py +++ b/models/src/anemoi/models/models/base.py @@ -31,6 +31,16 @@ LOGGER = logging.getLogger(__name__) +def 
_get_backbone_config(model_config: DotDict) -> DotDict: + """Return the backbone config.""" + backbone_config = model_config.model.get("backbone") + if backbone_config is not None: + return DotDict(backbone_config) + + msg = "Expected model config to define `model.backbone`." + raise KeyError(msg) + + class BaseGraphModel(nn.Module): """Message passing graph neural network.""" @@ -63,12 +73,13 @@ def __init__( self.dataset_names = list(data_indices.keys()) model_config = DotDict(model_config) - self._graph_name_hidden = model_config.model.model.hidden_nodes_name + backbone_config = _get_backbone_config(model_config) + self._graph_name_hidden = backbone_config.hidden_nodes_name self.n_step_input = model_config.training.multistep_input self.n_step_output = model_config.training.multistep_output self.num_channels = model_config.model.num_channels - self.latent_skip = model_config.model.model.latent_skip + self.latent_skip = backbone_config.latent_skip trainable_parameters = broadcast_config_keys( model_config.model.trainable_parameters, diff --git a/models/src/anemoi/models/models/diffusion_encoder_processor_decoder.py b/models/src/anemoi/models/models/diffusion_encoder_processor_decoder.py index a1629abaf4..58404bea47 100644 --- a/models/src/anemoi/models/models/diffusion_encoder_processor_decoder.py +++ b/models/src/anemoi/models/models/diffusion_encoder_processor_decoder.py @@ -49,7 +49,7 @@ def __init__( model_config_local = DotDict(model_config) - diffusion_config = model_config_local.model.nn.diffusion + diffusion_config = model_config_local.model.backbone.diffusion self.noise_channels = diffusion_config.noise_channels self.noise_cond_dim = diffusion_config.noise_cond_dim self.sigma_data = diffusion_config.sigma_data diff --git a/models/src/anemoi/models/models/hierarchical_autoencoder.py b/models/src/anemoi/models/models/hierarchical_autoencoder.py index 30b5130e9c..25af839a09 100644 --- a/models/src/anemoi/models/models/hierarchical_autoencoder.py +++ b/models/src/anemoi/models/models/hierarchical_autoencoder.py @@ -21,6 +21,7 @@ from anemoi.models.layers.graph import NamedNodesAttributes from anemoi.models.layers.graph_provider import create_graph_provider from anemoi.models.models import AnemoiModelAutoEncoder +from anemoi.models.models.base import _get_backbone_config from anemoi.utils.config import DotDict @@ -53,7 +54,8 @@ def __init__( self.statistics = statistics model_config = DotDict(model_config) - self._graph_name_hidden = model_config.model.model.hidden_nodes_name + backbone_config = _get_backbone_config(model_config) + self._graph_name_hidden = backbone_config.hidden_nodes_name self.n_step_input = model_config.training.multistep_input self.n_step_output = model_config.training.multistep_output diff --git a/models/src/anemoi/models/models/naive.py b/models/src/anemoi/models/models/naive.py index ee0a66d7f9..14307a30e3 100644 --- a/models/src/anemoi/models/models/naive.py +++ b/models/src/anemoi/models/models/naive.py @@ -56,9 +56,11 @@ def forward(self, x: dict[str, Tensor], **_) -> dict[str, Tensor]: out = {} for name, x_ds in x.items(): bs, t, ens, grid, nv = x_ds.shape - x_flat = x_ds.reshape(bs * grid, t * nv) + x_flat = x_ds.permute(0, 2, 3, 1, 4).reshape(bs * ens * grid, t * nv) y_flat = self.linear(x_flat) - out[name] = y_flat.to(x_ds.dtype).reshape(bs, self.n_step_output, ens, grid, self._n_output) + out[name] = ( + y_flat.to(x_ds.dtype).reshape(bs, ens, grid, self.n_step_output, self._n_output).permute(0, 3, 1, 2, 4) + ) return out def predict_step( diff --git 
a/models/tests/models/test_models.py b/models/tests/models/test_models.py index f005f3ff68..1dbadcc98b 100644 --- a/models/tests/models/test_models.py +++ b/models/tests/models/test_models.py @@ -1,4 +1,4 @@ -# (C) Copyright 2024 Anemoi contributors. +# (C) Copyright 2026 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. @@ -7,10 +7,32 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import pytest +import torch +from omegaconf import OmegaConf -def test_models(): - pass +from anemoi.models.models.base import _get_backbone_config +from anemoi.models.models.naive import NaiveModel -if __name__ == "__main__": - test_models() +def test_get_backbone_config_reads_backbone_layout() -> None: + config = OmegaConf.create({"model": {"backbone": {"hidden_nodes_name": "hidden", "latent_skip": True}}}) + + assert _get_backbone_config(config).hidden_nodes_name == "hidden" + assert _get_backbone_config(config).latent_skip is True + + +def test_get_backbone_config_requires_backbone() -> None: + config = OmegaConf.create({"model": {}}) + + with pytest.raises(KeyError, match="model.backbone"): + _get_backbone_config(config) + + +def test_naive_model_preserves_ensemble_dimension() -> None: + model = NaiveModel(n_input=2, n_output=3, n_step_input=2, n_step_output=1) + x = {"data": torch.randn(4, 2, 5, 7, 2)} + + y = model(x) + + assert y["data"].shape == (4, 1, 5, 7, 3) diff --git a/training/src/anemoi/training/builder.py b/training/src/anemoi/training/builder.py index b49e3f3e27..e88c1b43a2 100644 --- a/training/src/anemoi/training/builder.py +++ b/training/src/anemoi/training/builder.py @@ -7,30 +7,25 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
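A worked shape check for the NaiveModel fix above, using the dimensions from the new test; this is illustrative only and not part of the change.

```python
import torch

bs, t, ens, grid, nv = 4, 2, 5, 7, 2
x = torch.randn(bs, t, ens, grid, nv)

# The old code reshaped to (bs * grid, t * nv) = (28, 4), which cannot hold
# the 560 elements of x once ens > 1 and raises a RuntimeError.
# The fix folds ensemble and grid into the batch dimension first, keeping
# (time, variable) together as the feature axis seen by the linear layer.
x_flat = x.permute(0, 2, 3, 1, 4).reshape(bs * ens * grid, t * nv)
assert x_flat.shape == (140, 4)
```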
-"""Factory function for building ModelInterface via Hydra instantiate.""" +"""Factory functions for building ModelInterface instances via Hydra.""" from __future__ import annotations -import datetime -import logging -import uuid as _uuid_module -from pathlib import Path +from typing import TYPE_CHECKING import torch from hydra.utils import instantiate -from omegaconf import DictConfig from omegaconf import OmegaConf -from anemoi.models.interface import ModelInterface -from anemoi.models.models import AnemoiModel from anemoi.models.preprocessing import Processors from anemoi.models.preprocessing import StepwiseProcessors from anemoi.models.utils.config import get_multiple_datasets_config -from anemoi.training.data.datamodule import AnemoiDatasetsDataModule -from anemoi.training.utils.jsonify import map_config_to_primitives -from anemoi.utils.provenance import gather_provenance_info -LOGGER = logging.getLogger(__name__) +if TYPE_CHECKING: + from anemoi.models.interface import ModelInterface + from anemoi.training.config_bundle import ModelConfigBundle + from anemoi.training.runtime import ModelRuntimeArtifacts + from anemoi.utils.config import DotDict # --------------------------------------------------------------------------- @@ -116,7 +111,7 @@ def _build_processors( data_config[dataset_name].processors, statistics[dataset_name], data_indices[dataset_name], - statistics_tendencies[dataset_name] if statistics_tendencies is not None else None, + (statistics_tendencies[dataset_name] if statistics_tendencies is not None else None), n_step_output, ) pre_processors[dataset_name] = pre @@ -124,7 +119,12 @@ def _build_processors( if pre_tend is not None: pre_processors_tendencies[dataset_name] = pre_tend post_processors_tendencies[dataset_name] = post_tend - return pre_processors, post_processors, pre_processors_tendencies, post_processors_tendencies + return ( + pre_processors, + post_processors, + pre_processors_tendencies, + post_processors_tendencies, + ) # --------------------------------------------------------------------------- @@ -134,72 +134,42 @@ def _build_processors( def build_anemoi_model( *, - backbone: DictConfig, - training_config: DictConfig, - data_config: DictConfig, - dataloader_config: DictConfig, - graph_config: DictConfig, - system_config: DictConfig, - **model_arch_kwargs, + config_bundle: ModelConfigBundle, + runtime_artifacts: ModelRuntimeArtifacts, + **_kwargs, ) -> ModelInterface: - """Build and return a fully constructed ModelInterface. + """Create an Anemoi model. - Called by Hydra instantiate(config.model) from train.py. All inputs come - from OmegaConf interpolations in the model yaml — no kwargs from train.py. + train.py calls this through Hydra. The trainer prepares the extra data + first and passes it in here, together with the parts of the config the + model needs. 
""" - - def _to_container(v): - if isinstance(v, DictConfig): - return OmegaConf.to_container(v, resolve=True) - return v - - full_config_dict = { - "training": _to_container(training_config), - "data": _to_container(data_config), - "dataloader": _to_container(dataloader_config), - "graph": _to_container(graph_config), - "system": _to_container(system_config), - "model": { - "backbone": _to_container(backbone), - **{k: _to_container(v) for k, v in model_arch_kwargs.items()}, - }, - } - config = OmegaConf.create(full_config_dict) - - # Build datamodule to obtain statistics, data_indices, supporting_arrays - datamodule = AnemoiDatasetsDataModule(config) + config = config_bundle.to_dictconfig() # Build processors - pre_processors, post_processors, pre_processors_tendencies, post_processors_tendencies = _build_processors( + ( + pre_processors, + post_processors, + pre_processors_tendencies, + post_processors_tendencies, + ) = _build_processors( config=config, - statistics=datamodule.statistics, - data_indices=datamodule.data_indices, - statistics_tendencies=datamodule.statistics_tendencies, + statistics=runtime_artifacts.statistics, + data_indices=runtime_artifacts.data_indices, + statistics_tendencies=runtime_artifacts.statistics_tendencies, ) - # Build graph (load from file or create) - graph_data = _build_graph(config) + wrapper_config = config.model.get("wrapper") or OmegaConf.create({"_target_": "anemoi.models.models.AnemoiModel"}) - # Combine supporting arrays with output-mask arrays - from anemoi.training.utils.supporting_arrays import build_combined_supporting_arrays - - supporting_arrays = build_combined_supporting_arrays( - config=config, - graph_data=graph_data, - supporting_arrays=datamodule.supporting_arrays, - ) - - # Build metadata - metadata = _build_metadata(config, datamodule) - - return AnemoiModel( + return instantiate( + wrapper_config, model_config=config, - graph_data=graph_data, - statistics=datamodule.statistics, - statistics_tendencies=datamodule.statistics_tendencies, - data_indices=datamodule.data_indices, - metadata=metadata, - supporting_arrays=supporting_arrays, + graph_data=runtime_artifacts.graph_data, + statistics=runtime_artifacts.statistics, + statistics_tendencies=runtime_artifacts.statistics_tendencies, + data_indices=runtime_artifacts.data_indices, + metadata=runtime_artifacts.metadata, + supporting_arrays=runtime_artifacts.supporting_arrays, pre_processors=pre_processors, post_processors=post_processors, pre_processors_tendencies=pre_processors_tendencies, @@ -207,89 +177,11 @@ def _to_container(v): ) -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _build_graph(config: DotDict) -> dict: - """Load or create graph data.""" - graphs = {} - dataset_configs = get_multiple_datasets_config(config.dataloader.training) - for dataset_name, dataset_config in dataset_configs.items(): - graph_path = getattr(config.system.input, "graph", None) - if graph_path and not getattr(config.graph, "overwrite", False): - graph_filename = Path(graph_path) - if graph_filename.name.endswith(".pt"): - graph_name = graph_filename.name.replace(".pt", f"_{dataset_name}.pt") - graph_filename = graph_filename.parent / graph_name - if graph_filename.exists(): - from anemoi.graphs.utils import get_distributed_device - - LOGGER.info("Loading graph data from %s", graph_filename) - graphs[dataset_name] = torch.load( - graph_filename, - 
map_location=get_distributed_device(), - weights_only=False, - ) - continue - - # Create new graph - from anemoi.graphs.create import GraphCreator - - graph_config = config.graph - dataset_reader_config = dataset_config.dataset_config - if isinstance(dataset_reader_config, dict): - dataset_source = dataset_reader_config.get("dataset") - else: - dataset_source = dataset_reader_config - if ( - dataset_source is not None - and hasattr(graph_config.nodes, "data") - and hasattr(graph_config.nodes.data.node_builder, "dataset") - ): - graph_config.nodes.data.node_builder.dataset = dataset_source - - save_path = None - if graph_path: - save_path = Path(graph_path) - if save_path.name.endswith(".pt"): - graph_name = save_path.name.replace(".pt", f"_{dataset_name}.pt") - save_path = save_path.parent / graph_name - - graphs[dataset_name] = GraphCreator(config=graph_config).create( - save_path=save_path, - overwrite=getattr(config.graph, "overwrite", False), - ) - - return graphs - - -def _build_metadata(config: DotDict, datamodule: AnemoiDatasetsDataModule) -> dict: - """Build inference/provenance metadata.""" - metadata_inference = { - "dataset_names": None, # populated by fill_metadata - "task": None, # set by train.py after build - } - md_dict = { - "version": "2.0", - "config": config, - "run_id": str(_uuid_module.uuid4()), - "dataset": None, - "data_indices": None, - "provenance_training": gather_provenance_info(), - "timestamp": datetime.datetime.now(tz=datetime.UTC), - "metadata_inference": metadata_inference, - "uuid": None, - } - datamodule.fill_metadata(md_dict) - - n_step_input = config.training.multistep_input - n_step_output = getattr(config.training, "multistep_output", 1) - for dataset_name in datamodule.dataset_names: - ts = md_dict["metadata_inference"][dataset_name]["timesteps"] - rel = ts["relative_date_indices_training"] - ts["input_relative_date_indices"] = rel[:n_step_input] - ts["output_relative_date_indices"] = rel[-n_step_output:] - - return map_config_to_primitives(md_dict) +def build_direct_model( + *, + config_bundle: ModelConfigBundle, + runtime_artifacts: ModelRuntimeArtifacts, + **_kwargs, +) -> ModelInterface: + """Create a model directly from ``config.model``.""" + return instantiate(config_bundle.model, runtime_artifacts=runtime_artifacts) diff --git a/training/src/anemoi/training/config/autoencoder.yaml b/training/src/anemoi/training/config/autoencoder.yaml index 47e410c8eb..c33d08532e 100644 --- a/training/src/anemoi/training/config/autoencoder.yaml +++ b/training/src/anemoi/training/config/autoencoder.yaml @@ -4,6 +4,7 @@ defaults: - diagnostics: evaluation - system: example - graph: encoder_decoder_only +- model_builder: anemoi - model: graphtransformer - training: autoencoder - _self_ diff --git a/training/src/anemoi/training/config/config.yaml b/training/src/anemoi/training/config/config.yaml index b1a4e59ec9..bf77587a70 100644 --- a/training/src/anemoi/training/config/config.yaml +++ b/training/src/anemoi/training/config/config.yaml @@ -4,6 +4,7 @@ defaults: - diagnostics: evaluation - system: example - graph: multi_scale +- model_builder: anemoi - model: gnn - training: default - _self_ diff --git a/training/src/anemoi/training/config/debug.yaml b/training/src/anemoi/training/config/debug.yaml index 5d21a835c1..ff2e50becc 100644 --- a/training/src/anemoi/training/config/debug.yaml +++ b/training/src/anemoi/training/config/debug.yaml @@ -4,6 +4,7 @@ defaults: - diagnostics: evaluation - system: example - graph: multi_scale +- model_builder: anemoi - model: gnn 
- training: default - _self_ diff --git a/training/src/anemoi/training/config/diffusion.yaml b/training/src/anemoi/training/config/diffusion.yaml index 6534aca132..807f9b660d 100644 --- a/training/src/anemoi/training/config/diffusion.yaml +++ b/training/src/anemoi/training/config/diffusion.yaml @@ -4,6 +4,7 @@ defaults: - diagnostics: evaluation - system: example - graph: multi_scale +- model_builder: anemoi - model: graphtransformer_diffusion - training: diffusion - _self_ diff --git a/training/src/anemoi/training/config/ensemble_crps.yaml b/training/src/anemoi/training/config/ensemble_crps.yaml index cbac4082ab..c868cfd5ee 100644 --- a/training/src/anemoi/training/config/ensemble_crps.yaml +++ b/training/src/anemoi/training/config/ensemble_crps.yaml @@ -4,6 +4,7 @@ defaults: - diagnostics: evaluation_ens - system: example - graph: encoder_decoder_only +- model_builder: anemoi - model: transformer_ens - training: ensemble - _self_ diff --git a/training/src/anemoi/training/config/hierarchical.yaml b/training/src/anemoi/training/config/hierarchical.yaml index c02fafed5b..5ea13fc4ff 100644 --- a/training/src/anemoi/training/config/hierarchical.yaml +++ b/training/src/anemoi/training/config/hierarchical.yaml @@ -4,6 +4,7 @@ defaults: - diagnostics: evaluation - system: example - graph: hierarchical_3level +- model_builder: anemoi - model: graphtransformer - training: default - _self_ diff --git a/training/src/anemoi/training/config/hierarchical_autoencoder.yaml b/training/src/anemoi/training/config/hierarchical_autoencoder.yaml index a7c5ec1bb7..0360f272f2 100644 --- a/training/src/anemoi/training/config/hierarchical_autoencoder.yaml +++ b/training/src/anemoi/training/config/hierarchical_autoencoder.yaml @@ -5,6 +5,7 @@ defaults: - datamodule: single - hardware: example - graph: hierarchical_2level_encoder_decoder_only +- model_builder: anemoi - model: graphtransformer - training: autoencoder - _self_ diff --git a/training/src/anemoi/training/config/interpolator.yaml b/training/src/anemoi/training/config/interpolator.yaml index 9025358dd1..13cc031c07 100644 --- a/training/src/anemoi/training/config/interpolator.yaml +++ b/training/src/anemoi/training/config/interpolator.yaml @@ -4,6 +4,7 @@ defaults: - diagnostics: evaluation - system: example - graph: multi_scale +- model_builder: anemoi - model: graphtransformer - training: default - _self_ diff --git a/training/src/anemoi/training/config/lam.yaml b/training/src/anemoi/training/config/lam.yaml index 23b65c9b2d..a1e92abc71 100644 --- a/training/src/anemoi/training/config/lam.yaml +++ b/training/src/anemoi/training/config/lam.yaml @@ -4,6 +4,7 @@ defaults: - diagnostics: evaluation - system: example - graph: limited_area +- model_builder: anemoi - model: graphtransformer - training: lam - _self_ diff --git a/training/src/anemoi/training/config/model/gnn.yaml b/training/src/anemoi/training/config/model/gnn.yaml index 59684ec1c2..3aa2a286c1 100644 --- a/training/src/anemoi/training/config/model/gnn.yaml +++ b/training/src/anemoi/training/config/model/gnn.yaml @@ -1,12 +1,3 @@ -_target_: anemoi.training.builder.build_anemoi_model -_recursive_: false - -training_config: ${training} -data_config: ${data} -dataloader_config: ${dataloader} -graph_config: ${graph} -system_config: ${system} - num_channels: 512 cpu_offload: False diff --git a/training/src/anemoi/training/config/model/graphtransformer.yaml b/training/src/anemoi/training/config/model/graphtransformer.yaml index f0c1879238..d5d5cea099 100644 --- 
a/training/src/anemoi/training/config/model/graphtransformer.yaml +++ b/training/src/anemoi/training/config/model/graphtransformer.yaml @@ -1,12 +1,3 @@ -_target_: anemoi.training.builder.build_anemoi_model -_recursive_: false - -training_config: ${training} -data_config: ${data} -dataloader_config: ${dataloader} -graph_config: ${graph} -system_config: ${system} - num_channels: 1024 cpu_offload: False diff --git a/training/src/anemoi/training/config/model/graphtransformer_diffusion.yaml b/training/src/anemoi/training/config/model/graphtransformer_diffusion.yaml index 8b5037fdce..0e9de541d3 100644 --- a/training/src/anemoi/training/config/model/graphtransformer_diffusion.yaml +++ b/training/src/anemoi/training/config/model/graphtransformer_diffusion.yaml @@ -1,17 +1,11 @@ -_target_: anemoi.training.builder.build_anemoi_model -_recursive_: false - -training_config: ${training} -data_config: ${data} -dataloader_config: ${dataloader} -graph_config: ${graph} -system_config: ${system} - num_channels: 1024 cpu_offload: False keep_batch_sharded: True +wrapper: + _target_: anemoi.models.models.AnemoiDiffusionModel + backbone: _target_: anemoi.models.models.AnemoiDiffusionModelEncProcDec hidden_nodes_name: "hidden" diff --git a/training/src/anemoi/training/config/model/graphtransformer_diffusiontend.yaml b/training/src/anemoi/training/config/model/graphtransformer_diffusiontend.yaml index 9891b56c18..6be707309e 100644 --- a/training/src/anemoi/training/config/model/graphtransformer_diffusiontend.yaml +++ b/training/src/anemoi/training/config/model/graphtransformer_diffusiontend.yaml @@ -1,18 +1,12 @@ -_target_: anemoi.training.builder.build_anemoi_model -_recursive_: false - -training_config: ${training} -data_config: ${data} -dataloader_config: ${dataloader} -graph_config: ${graph} -system_config: ${system} - num_channels: 1024 condition_on_residual: False cpu_offload: False keep_batch_sharded: True +wrapper: + _target_: anemoi.models.models.AnemoiDiffusionTendencyModel + backbone: _target_: anemoi.models.models.AnemoiDiffusionTendModelEncProcDec hidden_nodes_name: "hidden" diff --git a/training/src/anemoi/training/config/model/graphtransformer_ens.yaml b/training/src/anemoi/training/config/model/graphtransformer_ens.yaml index 0ad6e8a60d..a4c6db5707 100644 --- a/training/src/anemoi/training/config/model/graphtransformer_ens.yaml +++ b/training/src/anemoi/training/config/model/graphtransformer_ens.yaml @@ -1,12 +1,3 @@ -_target_: anemoi.training.builder.build_anemoi_model -_recursive_: false - -training_config: ${training} -data_config: ${data} -dataloader_config: ${dataloader} -graph_config: ${graph} -system_config: ${system} - num_channels: 1024 condition_on_residual: False output_mask: diff --git a/training/src/anemoi/training/config/model/point_wise.yaml b/training/src/anemoi/training/config/model/point_wise.yaml index e98e6d9a81..854e7b88e0 100644 --- a/training/src/anemoi/training/config/model/point_wise.yaml +++ b/training/src/anemoi/training/config/model/point_wise.yaml @@ -1,12 +1,3 @@ -_target_: anemoi.training.builder.build_anemoi_model -_recursive_: false - -training_config: ${training} -data_config: ${data} -dataloader_config: ${dataloader} -graph_config: ${graph} -system_config: ${system} - num_channels: 1024 cpu_offload: False diff --git a/training/src/anemoi/training/config/model/transformer.yaml b/training/src/anemoi/training/config/model/transformer.yaml index 6bd1597306..6fa1f37890 100644 --- a/training/src/anemoi/training/config/model/transformer.yaml +++ 
b/training/src/anemoi/training/config/model/transformer.yaml @@ -1,12 +1,3 @@ -_target_: anemoi.training.builder.build_anemoi_model -_recursive_: false - -training_config: ${training} -data_config: ${data} -dataloader_config: ${dataloader} -graph_config: ${graph} -system_config: ${system} - num_channels: 1024 cpu_offload: False diff --git a/training/src/anemoi/training/config/model/transformer_diffusion.yaml b/training/src/anemoi/training/config/model/transformer_diffusion.yaml index f4851133f6..e5f2890ac4 100644 --- a/training/src/anemoi/training/config/model/transformer_diffusion.yaml +++ b/training/src/anemoi/training/config/model/transformer_diffusion.yaml @@ -1,17 +1,11 @@ -_target_: anemoi.training.builder.build_anemoi_model -_recursive_: false - -training_config: ${training} -data_config: ${data} -dataloader_config: ${dataloader} -graph_config: ${graph} -system_config: ${system} - num_channels: 1024 cpu_offload: False keep_batch_sharded: True +wrapper: + _target_: anemoi.models.models.AnemoiDiffusionModel + backbone: _target_: anemoi.models.models.AnemoiDiffusionModelEncProcDec hidden_nodes_name: "hidden" diff --git a/training/src/anemoi/training/config/model/transformer_diffusiontend.yaml b/training/src/anemoi/training/config/model/transformer_diffusiontend.yaml index e0cd08e36a..f5d4116f72 100644 --- a/training/src/anemoi/training/config/model/transformer_diffusiontend.yaml +++ b/training/src/anemoi/training/config/model/transformer_diffusiontend.yaml @@ -1,18 +1,12 @@ -_target_: anemoi.training.builder.build_anemoi_model -_recursive_: false - -training_config: ${training} -data_config: ${data} -dataloader_config: ${dataloader} -graph_config: ${graph} -system_config: ${system} - num_channels: 1024 condition_on_residual: False cpu_offload: False keep_batch_sharded: True +wrapper: + _target_: anemoi.models.models.AnemoiDiffusionTendencyModel + backbone: _target_: anemoi.models.models.AnemoiDiffusionTendModelEncProcDec hidden_nodes_name: "hidden" diff --git a/training/src/anemoi/training/config/model/transformer_ens.yaml b/training/src/anemoi/training/config/model/transformer_ens.yaml index 4c26147251..15a9a9fffd 100644 --- a/training/src/anemoi/training/config/model/transformer_ens.yaml +++ b/training/src/anemoi/training/config/model/transformer_ens.yaml @@ -1,12 +1,3 @@ -_target_: anemoi.training.builder.build_anemoi_model -_recursive_: false - -training_config: ${training} -data_config: ${data} -dataloader_config: ${dataloader} -graph_config: ${graph} -system_config: ${system} - num_channels: 1024 condition_on_residual: False cpu_offload: False diff --git a/training/src/anemoi/training/config/model/transformer_transformermapper.yaml b/training/src/anemoi/training/config/model/transformer_transformermapper.yaml index b7a511394a..8519113529 100644 --- a/training/src/anemoi/training/config/model/transformer_transformermapper.yaml +++ b/training/src/anemoi/training/config/model/transformer_transformermapper.yaml @@ -1,12 +1,3 @@ -_target_: anemoi.training.builder.build_anemoi_model -_recursive_: false - -training_config: ${training} -data_config: ${data} -dataloader_config: ${dataloader} -graph_config: ${graph} -system_config: ${system} - num_channels: 1024 cpu_offload: False diff --git a/training/src/anemoi/training/config/model_builder/anemoi.yaml b/training/src/anemoi/training/config/model_builder/anemoi.yaml new file mode 100644 index 0000000000..0f5561130e --- /dev/null +++ b/training/src/anemoi/training/config/model_builder/anemoi.yaml @@ -0,0 +1,2 @@ +_target_: 
anemoi.training.builder.build_anemoi_model +_recursive_: false diff --git a/training/src/anemoi/training/config/model_builder/direct.yaml b/training/src/anemoi/training/config/model_builder/direct.yaml new file mode 100644 index 0000000000..77069b0e94 --- /dev/null +++ b/training/src/anemoi/training/config/model_builder/direct.yaml @@ -0,0 +1,2 @@ +_target_: anemoi.training.builder.build_direct_model +_recursive_: false diff --git a/training/src/anemoi/training/config/multi.yaml b/training/src/anemoi/training/config/multi.yaml index f348aa2d5c..2087004a02 100644 --- a/training/src/anemoi/training/config/multi.yaml +++ b/training/src/anemoi/training/config/multi.yaml @@ -4,6 +4,7 @@ defaults: - diagnostics: evaluation_multi - system: example - graph: multi +- model_builder: anemoi - model: graphtransformer - training: multi - _self_ diff --git a/training/src/anemoi/training/config/naive.yaml b/training/src/anemoi/training/config/naive.yaml new file mode 100644 index 0000000000..c4f5e7da1b --- /dev/null +++ b/training/src/anemoi/training/config/naive.yaml @@ -0,0 +1,12 @@ +defaults: +- data: zarr +- dataloader: native_grid +- diagnostics: evaluation +- system: example +- graph: multi_scale +- model_builder: direct +- model: naive +- training: default +- _self_ + +config_validation: True diff --git a/training/src/anemoi/training/config/point_wise.yaml b/training/src/anemoi/training/config/point_wise.yaml index d964b4df30..e5f2523691 100644 --- a/training/src/anemoi/training/config/point_wise.yaml +++ b/training/src/anemoi/training/config/point_wise.yaml @@ -4,6 +4,7 @@ defaults: - diagnostics: evaluation - system: example - graph: point_wise +- model_builder: anemoi - model: point_wise - training: default - _self_ diff --git a/training/src/anemoi/training/config/stretched.yaml b/training/src/anemoi/training/config/stretched.yaml index 467bd20936..3f06495b92 100644 --- a/training/src/anemoi/training/config/stretched.yaml +++ b/training/src/anemoi/training/config/stretched.yaml @@ -4,6 +4,7 @@ defaults: - diagnostics: evaluation - system: example - graph: stretched_grid +- model_builder: anemoi - model: graphtransformer - training: stretched - _self_ diff --git a/training/src/anemoi/training/config_bundle.py b/training/src/anemoi/training/config_bundle.py new file mode 100644 index 0000000000..ac493dda9c --- /dev/null +++ b/training/src/anemoi/training/config_bundle.py @@ -0,0 +1,101 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
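Assumed wiring for the new `model_builder` group (the train.py side is not part of this section): the trainer instantiates `config.model_builder` and passes in the config bundle plus the runtime artifacts it prepared. The `build_model` helper below is a hypothetical sketch under that assumption.

```python
from hydra.utils import instantiate


def build_model(config, config_bundle, runtime_artifacts):
    # model_builder/anemoi.yaml resolves to build_anemoi_model (wrapper +
    # processors); model_builder/direct.yaml resolves to build_direct_model,
    # which instantiates config.model directly. Both are _recursive_: false.
    return instantiate(
        config.model_builder,
        config_bundle=config_bundle,
        runtime_artifacts=runtime_artifacts,
    )
```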
+ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING +from typing import Any + +from omegaconf import DictConfig +from omegaconf import OmegaConf + +if TYPE_CHECKING: + from omegaconf import ListConfig + + +_MODEL_CONFIG_EXCLUDE_KEYS = { + "training_config", + "data_config", + "dataloader_config", + "graph_config", + "system_config", + "config_bundle", +} + + +def _clone_config_section(section: Any) -> DictConfig | ListConfig: + return OmegaConf.create(OmegaConf.to_container(section, resolve=True)) + + +def _clean_model_section(model_config: DictConfig) -> DictConfig: + clean_model_dict = { + key: OmegaConf.to_container(value, resolve=True) if OmegaConf.is_config(value) else value + for key, value in model_config.items() + if key not in _MODEL_CONFIG_EXCLUDE_KEYS + } + return OmegaConf.create(clean_model_dict) + + +@dataclass(frozen=True) +class ModelConfigBundle: + """Parts of the config used to create the model.""" + + training: DictConfig + data: DictConfig + model: DictConfig + + @classmethod + def from_root_config(cls, root_config: DictConfig) -> ModelConfigBundle: + return cls( + training=_clone_config_section(root_config.training), + data=_clone_config_section(root_config.data), + model=_clean_model_section(root_config.model), + ) + + def to_dictconfig(self) -> DictConfig: + return OmegaConf.create( + { + "training": OmegaConf.to_container(self.training, resolve=True), + "data": OmegaConf.to_container(self.data, resolve=True), + "model": OmegaConf.to_container(self.model, resolve=True), + }, + ) + + +@dataclass(frozen=True) +class TaskConfigBundle: + """Parts of the config used by the training task.""" + + training: DictConfig + system: DictConfig + dataloader: DictConfig + graph: DictConfig + model: DictConfig + + @classmethod + def from_root_config(cls, root_config: DictConfig) -> TaskConfigBundle: + return cls( + training=_clone_config_section(root_config.training), + system=_clone_config_section(root_config.system), + dataloader=_clone_config_section(root_config.dataloader), + graph=_clone_config_section(root_config.graph), + model=_clean_model_section(root_config.model), + ) + + def to_dictconfig(self) -> DictConfig: + return OmegaConf.create( + { + "training": OmegaConf.to_container(self.training, resolve=True), + "system": OmegaConf.to_container(self.system, resolve=True), + "dataloader": OmegaConf.to_container(self.dataloader, resolve=True), + "graph": OmegaConf.to_container(self.graph, resolve=True), + "model": OmegaConf.to_container(self.model, resolve=True), + }, + ) diff --git a/training/src/anemoi/training/diagnostics/callbacks/plot.py b/training/src/anemoi/training/diagnostics/callbacks/plot.py index 70caad49a2..4f50f520c3 100644 --- a/training/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/training/src/anemoi/training/diagnostics/callbacks/plot.py @@ -50,6 +50,22 @@ LOGGER = logging.getLogger(__name__) +def _get_plot_backbone(pl_module: pl.LightningModule) -> object: + """Return the graph model used by plotting callbacks. + + The current training path wraps graph backbones in ``AnemoiModel`` and + exposes them under ``.backbone``. 
+ """ + model = pl_module.model.module if hasattr(pl_module.model, "module") else pl_module.model + return model.backbone + + +def _get_dataset_latlons(pl_module: pl.LightningModule, dataset_name: str) -> torch.Tensor: + """Return lat/lon node coordinates for one dataset.""" + backbone = _get_plot_backbone(pl_module) + return backbone._graph_data[dataset_name].x.detach() + + class BasePlotCallback(Callback, ABC): """Factory for creating a callback that plots data to Experiment Logging.""" @@ -480,7 +496,7 @@ def _plot( for name in self.parameters } if self.latlons is None: - self.latlons = pl_module.model.backbone._graph_data[pl_module.model.backbone._graph_name_data].x.detach() + self.latlons = _get_dataset_latlons(pl_module, self.dataset_names[0]) self.latlons = np.rad2deg(self.latlons.cpu().numpy()) assert batch.shape[1] >= self.max_rollout + pl_module.n_step_input, ( @@ -1116,9 +1132,7 @@ def process( self.latlons = {} if dataset_name not in self.latlons: - self.latlons[dataset_name] = pl_module.model.backbone._graph_data[dataset_name][ - pl_module.model.backbone._graph_name_data - ].x.detach() + self.latlons[dataset_name] = _get_dataset_latlons(pl_module, dataset_name) self.latlons[dataset_name] = np.rad2deg(self.latlons[dataset_name].cpu().numpy()) # All tasks return (loss, metrics, list of per-step dicts) from _step; on_validation_batch_end enforces list. diff --git a/training/src/anemoi/training/diagnostics/callbacks/plot_ens.py b/training/src/anemoi/training/diagnostics/callbacks/plot_ens.py index 196bce101f..9574e66efa 100644 --- a/training/src/anemoi/training/diagnostics/callbacks/plot_ens.py +++ b/training/src/anemoi/training/diagnostics/callbacks/plot_ens.py @@ -22,6 +22,7 @@ from anemoi.training.diagnostics.callbacks.plot import PlotLoss as _PlotLoss from anemoi.training.diagnostics.callbacks.plot import PlotSample as _PlotSample from anemoi.training.diagnostics.callbacks.plot import PlotSpectrum as _PlotSpectrum +from anemoi.training.diagnostics.callbacks.plot import _get_dataset_latlons if TYPE_CHECKING: from typing import Any @@ -102,9 +103,7 @@ def process( members = [members] if dataset_name not in self.latlons: - self.latlons[dataset_name] = pl_module.model.backbone._graph_data[dataset_name][ - pl_module.model.backbone._graph_name_data - ].x.detach() + self.latlons[dataset_name] = _get_dataset_latlons(pl_module, dataset_name) self.latlons[dataset_name] = np.rad2deg(self.latlons[dataset_name].cpu().numpy()) total_targets = pl_module.plot_adapter.get_total_plot_targets() @@ -276,7 +275,10 @@ def _plot( else self.config.data.datasets[dataset_name].diagnostic ) plot_parameters_dict = { - pl_module.data_indices[dataset_name].model.output.name_to_index[name]: (name, name in diagnostics) + pl_module.data_indices[dataset_name].model.output.name_to_index[name]: ( + name, + name in diagnostics, + ) for name in self.parameters } diff --git a/training/src/anemoi/training/runtime.py b/training/src/anemoi/training/runtime.py new file mode 100644 index 0000000000..65c729294c --- /dev/null +++ b/training/src/anemoi/training/runtime.py @@ -0,0 +1,51 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from torch_geometric.data import HeteroData + + +@dataclass(frozen=True) +class TaskRuntimeArtifacts: + """Data prepared by the trainer and passed to the Lightning task.""" + + graph_data: HeteroData + statistics: dict + statistics_tendencies: dict | None + data_indices: dict + metadata: dict + supporting_arrays: dict + + +@dataclass(frozen=True) +class ModelRuntimeArtifacts: + """Data prepared by the trainer and passed when creating the model.""" + + graph_data: HeteroData + statistics: dict + statistics_tendencies: dict | None + data_indices: dict + metadata: dict + supporting_arrays: dict + + def to_task_runtime_artifacts(self) -> TaskRuntimeArtifacts: + """Return the same data in the form used by tasks.""" + return TaskRuntimeArtifacts( + graph_data=self.graph_data, + statistics=self.statistics, + statistics_tendencies=self.statistics_tendencies, + data_indices=self.data_indices, + metadata=self.metadata, + supporting_arrays=self.supporting_arrays, + ) diff --git a/training/src/anemoi/training/schemas/base_schema.py b/training/src/anemoi/training/schemas/base_schema.py index f76d3cc5fb..298086737b 100644 --- a/training/src/anemoi/training/schemas/base_schema.py +++ b/training/src/anemoi/training/schemas/base_schema.py @@ -98,6 +98,8 @@ class BaseSchema(SchemaCommonMixin, BaseModel): """System configuration, including filesystem and hardware specification.""" graph: BaseGraphSchema """Graph configuration.""" + model_builder: Any | None = None + """Settings for creating the model.""" model: ModelSchema """Model configuration.""" training: TrainingSchema @@ -123,6 +125,8 @@ class UnvalidatedBaseSchema(SchemaCommonMixin, PydanticBaseModel): """Hardware configuration.""" graph: Any """Graph configuration.""" + model_builder: Any = None + """Settings for creating the model.""" model: Any """Model configuration.""" training: Any diff --git a/training/src/anemoi/training/train/tasks/autoencoder.py b/training/src/anemoi/training/train/tasks/autoencoder.py index 5a4f57ba65..8eabc5fc69 100644 --- a/training/src/anemoi/training/train/tasks/autoencoder.py +++ b/training/src/anemoi/training/train/tasks/autoencoder.py @@ -20,11 +20,10 @@ from collections.abc import Mapping import torch - from omegaconf import DictConfig - from torch_geometric.data import HeteroData - from anemoi.models.data_indices.collection import IndexCollection from anemoi.models.interface import ModelInterface + from anemoi.training.config_bundle import TaskConfigBundle + from anemoi.training.runtime import TaskRuntimeArtifacts LOGGER = logging.getLogger(__name__) @@ -39,11 +38,8 @@ def __init__( self, *, model: ModelInterface, - config: DictConfig, - graph_data: HeteroData, - statistics: dict, - statistics_tendencies: dict, - data_indices: IndexCollection, + config_bundle: TaskConfigBundle, + runtime_artifacts: TaskRuntimeArtifacts, **kwargs, ) -> None: """Initialize graph neural network interpolator. @@ -51,23 +47,16 @@ def __init__( Parameters ---------- model : ModelInterface - config : DictConfig - Job configuration - graph_data : HeteroData - Graph object - statistics : dict - Statistics of the training data - data_indices : IndexCollection - Indices of the training data, + config_bundle : TaskConfigBundle + Parts of the config used by this task. + runtime_artifacts : TaskRuntimeArtifacts + Data prepared by the trainer for this task. 
""" super().__init__( model=model, - config=config, - graph_data=graph_data, - statistics=statistics, - statistics_tendencies=statistics_tendencies, - data_indices=data_indices, + config_bundle=config_bundle, + runtime_artifacts=runtime_artifacts, **kwargs, ) diff --git a/training/src/anemoi/training/train/tasks/base.py b/training/src/anemoi/training/train/tasks/base.py index ef24bbd631..7d2ec0cd09 100644 --- a/training/src/anemoi/training/train/tasks/base.py +++ b/training/src/anemoi/training/train/tasks/base.py @@ -29,7 +29,6 @@ from anemoi.models.distributed.balanced_partition import get_partition_range from anemoi.models.distributed.graph import gather_tensor from anemoi.models.distributed.shapes import apply_shard_shapes -from anemoi.models.interface import ModelInterface from anemoi.models.utils.config import get_multiple_datasets_config from anemoi.training.losses import get_loss_function from anemoi.training.losses.base import BaseLoss @@ -48,7 +47,9 @@ from torch.distributed.distributed_c10d import ProcessGroup from anemoi.models.data_indices.collection import IndexCollection - from anemoi.training.schemas.base_schema import BaseSchema + from anemoi.models.interface import ModelInterface + from anemoi.training.config_bundle import TaskConfigBundle + from anemoi.training.runtime import TaskRuntimeArtifacts LOGGER = logging.getLogger(__name__) @@ -76,20 +77,10 @@ class BaseGraphModule(pl.LightningModule, ABC): Parameters ---------- - config : BaseSchema - Configuration object defining all parameters. - graph_data : HeteroData - Graph-structured input data containing node and edge features, keyed by dataset name. - statistics : dict - Dictionary of training statistics (mean, std, etc.) used for normalization. - statistics_tendencies : dict - Statistics related to tendencies (if used). - data_indices : dict[str, IndexCollection] - Maps feature names to index ranges used for training and loss functions. - metadata : dict - Dictionary with metadata such as dataset provenance and variable descriptions. - supporting_arrays : dict - Numpy arrays (e.g., topography, masks) needed during inference and stored in checkpoints. + config_bundle : TaskConfigBundle + Parts of the config used by this task. + runtime_artifacts : TaskRuntimeArtifacts + Graph data, statistics, indices, metadata, and extra arrays prepared by the trainer. Attributes ---------- @@ -137,31 +128,28 @@ def __init__( self, *, model: ModelInterface, - config: BaseSchema, - graph_data: HeteroData, - statistics: dict, - statistics_tendencies: dict, - data_indices: dict[str, IndexCollection], - metadata: dict, - supporting_arrays: dict, + config_bundle: TaskConfigBundle, + runtime_artifacts: TaskRuntimeArtifacts, ) -> None: """Initialize graph neural network forecaster. Parameters ---------- - config : DictConfig - Job configuration - graph_data : HeteroData - Graph objects keyed by dataset name - statistics : dict - Statistics of the training data - data_indices : dict[str, IndexCollection] - Indices of the training data, - metadata : dict - Provenance and inference metadata. + config_bundle : TaskConfigBundle + Parts of the config used by this task. + runtime_artifacts : TaskRuntimeArtifacts + Graph data, statistics, indices, metadata, and extra arrays prepared by the trainer. 
""" super().__init__() + config = config_bundle.to_dictconfig() + + graph_data = runtime_artifacts.graph_data + statistics = runtime_artifacts.statistics + statistics_tendencies = runtime_artifacts.statistics_tendencies + data_indices = runtime_artifacts.data_indices + metadata = runtime_artifacts.metadata + supporting_arrays = runtime_artifacts.supporting_arrays assert isinstance(graph_data, HeteroData), "graph_data must be a HeteroData object" assert isinstance(data_indices, dict), "data_indices must be a dict keyed by dataset name" @@ -181,7 +169,11 @@ def __init__( raise AttributeError(msg) self.model = model + self.config_bundle = config_bundle self.config = config + self.task_runtime_artifacts = runtime_artifacts + self.graph_data = graph_data + self.statistics = statistics self.metadata = metadata self.supporting_arrays = supporting_arrays @@ -263,7 +255,11 @@ def __init__( self.is_first_step = True self.n_step_input = config.training.multistep_input self.n_step_output = config.training.multistep_output # defaults to 1 via pydantic - LOGGER.info("GraphModule with n_step_input=%s and n_step_output=%s", self.n_step_input, self.n_step_output) + LOGGER.info( + "GraphModule with n_step_input=%s and n_step_output=%s", + self.n_step_input, + self.n_step_output, + ) self.lr = ( config.system.hardware.num_nodes * config.system.hardware.num_gpus_per_node @@ -409,7 +405,10 @@ def _update_checkpoint_state_dict_for_load(self, checkpoint: dict[str, Any]) -> if update_states: processor_prefixes += ("model.pre_processors.", "model.post_processors.") if update_tendencies: - processor_prefixes += ("model.pre_processors_tendencies.", "model.post_processors_tendencies.") + processor_prefixes += ( + "model.pre_processors_tendencies.", + "model.post_processors_tendencies.", + ) if not processor_prefixes: return @@ -816,9 +815,13 @@ def _prepare_loss_scalers(self) -> None: self.update_scalers(callback=AvailableCallbacks.ON_BATCH_START) return - @abstractmethod def fill_metadata(self, metadata: dict) -> None: - """Fill inference metadata with task-specific timestep indices.""" + """Fill inference metadata with default input/output timestep indices.""" + for dataset_name in self.dataset_names: + ts = metadata["metadata_inference"][dataset_name]["timesteps"] + rel = ts["relative_date_indices_training"] + ts["input_relative_date_indices"] = rel[: self.n_step_input] + ts["output_relative_date_indices"] = rel[-self.n_step_output :] @abstractmethod def _step( @@ -1038,7 +1041,9 @@ def lr_scheduler_step(self, scheduler: CosineLRScheduler, metric: None = None) - def on_train_epoch_end(self) -> None: pass - def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[dict[str, Any]]]: + def configure_optimizers( + self, + ) -> tuple[list[torch.optim.Optimizer], list[dict[str, Any]]]: """Create optimizer and LR scheduler based on Hydra config.""" optimizer = self._create_optimizer_from_config(self.config.training.optimizer) scheduler = self._create_scheduler(optimizer) diff --git a/training/src/anemoi/training/train/tasks/diffusionforecaster.py b/training/src/anemoi/training/train/tasks/diffusionforecaster.py index 79392d3a2f..1ffba2f852 100644 --- a/training/src/anemoi/training/train/tasks/diffusionforecaster.py +++ b/training/src/anemoi/training/train/tasks/diffusionforecaster.py @@ -12,24 +12,35 @@ import logging from typing import TYPE_CHECKING +from typing import Any import torch from torch.utils.checkpoint import checkpoint +from anemoi.models.interface import DiffusionModelInterface +from 
anemoi.models.interface import DiffusionTendencyModelInterface from anemoi.models.preprocessing import StepwiseProcessors from .base import BaseGraphModule if TYPE_CHECKING: - from torch_geometric.data import HeteroData + from omegaconf import DictConfig - from anemoi.models.data_indices.collection import IndexCollection - from anemoi.models.interface import ModelInterface - from anemoi.training.schemas.base_schema import BaseSchema + from anemoi.training.config_bundle import TaskConfigBundle + from anemoi.training.runtime import TaskRuntimeArtifacts LOGGER = logging.getLogger(__name__) +def _get_diffusion_config(config: DictConfig) -> Any: + backbone_config = config.model.get("backbone") + if backbone_config is not None and "diffusion" in backbone_config: + return backbone_config.diffusion + + msg = "Expected diffusion settings under `model.backbone.diffusion`." + raise KeyError(msg) + + class BaseDiffusionForecaster(BaseGraphModule): """Base class for diffusion forecasters.""" @@ -38,30 +49,28 @@ class BaseDiffusionForecaster(BaseGraphModule): def __init__( self, *, - model: ModelInterface, - config: BaseSchema, - graph_data: HeteroData, - statistics: dict, - statistics_tendencies: dict, - data_indices: dict[str, IndexCollection], + model: DiffusionModelInterface, + config_bundle: TaskConfigBundle, + runtime_artifacts: TaskRuntimeArtifacts, **kwargs, ) -> None: - + """Initialize the diffusion forecaster.""" super().__init__( model=model, - config=config, - graph_data=graph_data, - statistics=statistics, - statistics_tendencies=statistics_tendencies, - data_indices=data_indices, + config_bundle=config_bundle, + runtime_artifacts=runtime_artifacts, **kwargs, ) + if not isinstance(model, DiffusionModelInterface): + msg = f"{self.__class__.__name__} requires a diffusion-capable model interface." 
+ raise TypeError(msg) - self.rho = config.model.nn.diffusion.rho + self.rho = _get_diffusion_config(self.config).rho from anemoi.training.diagnostics.callbacks.plot_adapter import DiffusionPlotAdapter self._plot_adapter = DiffusionPlotAdapter(self) + self.fill_metadata(self.metadata) def get_input(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """Get input tensor shape for diffusion model.""" @@ -98,7 +107,7 @@ def forward( y_noised: dict[str, torch.Tensor], sigma: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]: - return self.model.backbone.fwd_with_preconditioning( + return self.model.forward_with_preconditioning( x, y_noised, sigma, @@ -224,11 +233,12 @@ def _step( # get noise level and associated loss weights shapes = {k: y_.shape for k, y_ in y.items()} + sigma_max, sigma_min, sigma_data = self.model.get_diffusion_parameters() sigma, noise_weights = self._get_noise_level( shape=shapes, - sigma_max=self.model.backbone.sigma_max, - sigma_min=self.model.backbone.sigma_min, - sigma_data=self.model.backbone.sigma_data, + sigma_max=sigma_max, + sigma_min=sigma_min, + sigma_data=sigma_data, rho=self.rho, device=next(iter(batch.values())).device, ) @@ -254,23 +264,20 @@ class GraphDiffusionTendForecaster(BaseDiffusionForecaster): def __init__( self, *, - model: ModelInterface, - config: BaseSchema, - graph_data: HeteroData, - statistics: dict, - statistics_tendencies: dict, - data_indices: dict[str, IndexCollection], + model: DiffusionTendencyModelInterface, + config_bundle: TaskConfigBundle, + runtime_artifacts: TaskRuntimeArtifacts, **kwargs, ) -> None: super().__init__( model=model, - config=config, - graph_data=graph_data, - statistics=statistics, - statistics_tendencies=statistics_tendencies, - data_indices=data_indices, + config_bundle=config_bundle, + runtime_artifacts=runtime_artifacts, **kwargs, ) + if not isinstance(model, DiffusionTendencyModelInterface): + msg = f"{self.__class__.__name__} requires a diffusion-tendency model interface." + raise TypeError(msg) self._tendency_pre_processors: dict[str, object] = {} self._tendency_post_processors: dict[str, object] = {} self._validate_tendency_processors() @@ -279,12 +286,6 @@ def _validate_tendency_processors(self) -> None: stats = self.statistics_tendencies assert stats is not None, "Tendency statistics are required for diffusion tendency models." - pre_processors_tendencies = getattr(self.model, "pre_processors_tendencies", None) - post_processors_tendencies = getattr(self.model, "post_processors_tendencies", None) - assert ( - pre_processors_tendencies is not None and post_processors_tendencies is not None - ), "Per-step tendency processors are required for multi-output diffusion tendency models." - def _wrap_if_needed( kind: str, proc: object, @@ -317,16 +318,11 @@ def _wrap_if_needed( assert all( lead_time in dataset_stats for lead_time in lead_times ), "Missing tendency statistics for one or more output steps." - - assert ( - dataset_name in pre_processors_tendencies - ), "Per-step tendency processors are required for multi-output diffusion tendency models." - assert ( - dataset_name in post_processors_tendencies - ), "Per-step tendency processors are required for multi-output diffusion tendency models." 
-
-            pre_tend = pre_processors_tendencies[dataset_name]
-            post_tend = post_processors_tendencies[dataset_name]
+            try:
+                pre_tend, post_tend = self.model.get_tendency_processors(dataset_name)
+            except (AttributeError, KeyError) as exc:
+                msg = "Per-step tendency processors are required for multi-output diffusion tendency models."
+                raise AttributeError(msg) from exc
             pre_tend = _wrap_if_needed("pre", pre_tend, dataset_name, lead_times)
             post_tend = _wrap_if_needed("post", post_tend, dataset_name, lead_times)
             assert (
@@ -355,14 +351,12 @@ def _compute_tendency_target(
         for step, pre_proc in enumerate(pre_tend):
             y_step = y_dataset[:, step : step + 1]
             x_ref_step = x_ref[dataset_name].unsqueeze(1)
-            tendency_step = self.model.backbone.compute_tendency(
-                {dataset_name: y_step},
-                {dataset_name: x_ref_step},
-                {dataset_name: self.model.pre_processors[dataset_name]},
-                {dataset_name: pre_proc},
-                input_post_processor={dataset_name: self.model.post_processors[dataset_name]},
-                skip_imputation=True,
-            )[dataset_name]
+            tendency_step = self.model.compute_tendency_step(
+                dataset_name=dataset_name,
+                y_step=y_step,
+                x_ref_step=x_ref_step,
+                tendency_pre_processor=pre_proc,
+            )
             tendency_steps.append(tendency_step)
         tendencies[dataset_name] = torch.cat(tendency_steps, dim=1)
         return tendencies
@@ -380,21 +374,15 @@ def _reconstruct_state(
         for step, post_proc in enumerate(post_tend):
             x_ref_step = x_ref[dataset_name].unsqueeze(1)
             tendency_step = tendency_dataset[:, step : step + 1]
-            state_step = self.model.backbone.add_tendency_to_state(
-                {dataset_name: x_ref_step},
-                {dataset_name: tendency_step},
-                {dataset_name: self.model.post_processors[dataset_name]},
-                {dataset_name: post_proc},
-                output_pre_processor={dataset_name: self.model.pre_processors[dataset_name]},
-                skip_imputation=True,
-            )[dataset_name]
+            state_step = self.model.add_tendency_to_state_step(
+                dataset_name=dataset_name,
+                x_ref_step=x_ref_step,
+                tendency_step=tendency_step,
+                tendency_post_processor=post_proc,
+            )
             state_steps.append(state_step)
         out_dataset = torch.cat(state_steps, dim=1)
-        out_dataset = self.model.backbone._apply_imputer_inverse(
-            self.model.post_processors,
-            dataset_name,
-            out_dataset,
-        )
+        out_dataset = self.model.apply_imputer_inverse(dataset_name, out_dataset)
         states[dataset_name] = out_dataset
         return states
@@ -441,11 +429,7 @@ def compute_dataset_loss_metrics(
         y_pred_state_full, y_state_full, grid_shard_slice = self._prepare_tensors_for_loss(
             y_pred_state[dataset_name],
-            self.model.backbone._apply_imputer_inverse(
-                self.model.post_processors,
-                dataset_name,
-                y_state[dataset_name],
-            ),
+            self.model.apply_imputer_inverse(dataset_name, y_state[dataset_name]),
             validation_mode=validation_mode,
             dataset_name=dataset_name,
         )
@@ -487,15 +471,7 @@ def _step(
         x = self.get_input(batch)  # (bs, n_step_input, ens, latlon, nvar)
         y = self.get_target(batch)  # (bs, n_step_output, ens, latlon, nvar)

-        pre_processors_tendencies = getattr(self.model, "pre_processors_tendencies", None)
-        if pre_processors_tendencies is None or len(pre_processors_tendencies) == 0:
-            msg = (
-                "pre_processors_tendencies not found. This is required for tendency-based diffusion models. "
-                "Ensure that statistics_tendencies is provided during model initialization."
-            )
-            raise AttributeError(msg)
-
-        x_ref = self.model.backbone.apply_reference_state_truncation(
+        x_ref = self.model.apply_reference_state_truncation(
             x,
             self.grid_shard_shapes,
             self.model_comm_group,
@@ -507,11 +483,12 @@ def _step(
         # get noise level and associated loss weights
         shapes = {k: target.shape for k, target in tendency_target.items()}
+        sigma_max, sigma_min, sigma_data = self.model.get_diffusion_parameters()
         sigma, noise_weights = self._get_noise_level(
             shape=shapes,
-            sigma_max=self.model.backbone.sigma_max,
-            sigma_min=self.model.backbone.sigma_min,
-            sigma_data=self.model.backbone.sigma_data,
+            sigma_max=sigma_max,
+            sigma_min=sigma_min,
+            sigma_data=sigma_data,
             rho=self.rho,
             device=next(iter(batch.values())).device,
         )
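With the `model.backbone.*` attribute walks replaced by protocol methods, the diffusion tasks above only ever touch three calls. A minimal sketch of that call pattern; `TinyDiffusionModel` is a hypothetical stand-in, not a class from this PR:

```python
import torch


class TinyDiffusionModel:
    """Hypothetical stand-in exposing the DiffusionModelInterface methods."""

    sigma_max, sigma_min, sigma_data = 80.0, 0.02, 1.0

    def get_diffusion_parameters(self) -> tuple[float, float, float]:
        return self.sigma_max, self.sigma_min, self.sigma_data

    def forward_with_preconditioning(self, x, y_noised, sigma, **kwargs):
        # A real backbone denoises y_noised conditioned on x and sigma;
        # here we just damp the noise so the shapes flow through.
        return {name: y / (1.0 + sigma[name]) for name, y in y_noised.items()}


model = TinyDiffusionModel()
sigma_max, sigma_min, sigma_data = model.get_diffusion_parameters()
y_noised = {"data": torch.randn(2, 1, 1, 10, 4)}  # (bs, step, ens, latlon, nvar)
sigma = {"data": torch.full((2, 1, 1, 1, 1), sigma_min)}
denoised = model.forward_with_preconditioning({}, y_noised, sigma)
assert denoised["data"].shape == y_noised["data"].shape
```

The point of the indirection is that `_step` no longer needs to know whether the wrapper stores these values on `self.backbone` or somewhere else entirely.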
""" super().__init__( model=model, - config=config, - graph_data=graph_data, - statistics=statistics, - statistics_tendencies=statistics_tendencies, - data_indices=data_indices, + config_bundle=config_bundle, + runtime_artifacts=runtime_artifacts, **kwargs, ) # num_gpus_per_ensemble >= 1 and num_gpus_per_ensemble >= num_gpus_per_model (as per the DDP strategy) + config = self.config self.model_comm_group_size = config.system.hardware.num_gpus_per_model num_gpus_per_model = config.system.hardware.num_gpus_per_model num_gpus_per_ensemble = config.system.hardware.num_gpus_per_ensemble diff --git a/training/src/anemoi/training/train/tasks/interpolator.py b/training/src/anemoi/training/train/tasks/interpolator.py index 233af3ce8e..c4c96d98c6 100644 --- a/training/src/anemoi/training/train/tasks/interpolator.py +++ b/training/src/anemoi/training/train/tasks/interpolator.py @@ -10,27 +10,23 @@ from __future__ import annotations import logging +from dataclasses import replace from operator import itemgetter from typing import TYPE_CHECKING import torch -from omegaconf import DictConfig from omegaconf import open_dict from torch.utils.checkpoint import checkpoint -from torch_geometric.data import HeteroData -from anemoi.models.data_indices.collection import IndexCollection from anemoi.training.diagnostics.callbacks.plot_adapter import InterpolatorMultiOutPlotAdapter from anemoi.training.train.tasks.base import BaseGraphModule if TYPE_CHECKING: from collections.abc import Mapping - from omegaconf import DictConfig - from torch_geometric.data import HeteroData - - from anemoi.models.data_indices.collection import IndexCollection from anemoi.models.interface import ModelInterface + from anemoi.training.config_bundle import TaskConfigBundle + from anemoi.training.runtime import TaskRuntimeArtifacts LOGGER = logging.getLogger(__name__) @@ -45,11 +41,8 @@ def __init__( self, *, model: ModelInterface, - config: DictConfig, - graph_data: dict[str, HeteroData], - statistics: dict, - statistics_tendencies: dict, - data_indices: dict[str, IndexCollection], + config_bundle: TaskConfigBundle, + runtime_artifacts: TaskRuntimeArtifacts, **kwargs, ) -> None: """Initialize graph neural network interpolator. @@ -57,30 +50,25 @@ def __init__( Parameters ---------- model : ModelInterface - config : DictConfig - Job configuration - graph_data : dict[str, HeteroData] - Graph objects keyed by dataset name - statistics : dict - Statistics of the training data - data_indices : dict[str, IndexCollection] - Indices of the training data + config_bundle : TaskConfigBundle + Parts of the config used by this task. + runtime_artifacts : TaskRuntimeArtifacts + Data prepared by the trainer for this task. 
""" + config = config_bundle.to_dictconfig() with open_dict(config.training): config.training.multistep_output = len(config.training.explicit_times.target) + config_bundle = replace(config_bundle, training=config.training) super().__init__( model=model, - config=config, - graph_data=graph_data, - statistics=statistics, - statistics_tendencies=statistics_tendencies, - data_indices=data_indices, + config_bundle=config_bundle, + runtime_artifacts=runtime_artifacts, **kwargs, ) - self.boundary_times = config.training.explicit_times.input - self.interp_times = config.training.explicit_times.target + self.boundary_times = self.config.training.explicit_times.input + self.interp_times = self.config.training.explicit_times.target self.n_step_output = len(self.interp_times) sorted_indices = sorted(set(self.boundary_times + self.interp_times)) self.imap = {data_index: batch_index for batch_index, data_index in enumerate(sorted_indices)} diff --git a/training/src/anemoi/training/train/tasks/rollout.py b/training/src/anemoi/training/train/tasks/rollout.py index 3c80d3b0b4..09ad4c2b71 100644 --- a/training/src/anemoi/training/train/tasks/rollout.py +++ b/training/src/anemoi/training/train/tasks/rollout.py @@ -17,18 +17,15 @@ import torch -from anemoi.models.data_indices.collection import IndexCollection from anemoi.training.diagnostics.callbacks.plot_adapter import ForecasterPlotAdapter from anemoi.training.train.tasks.base import BaseGraphModule if TYPE_CHECKING: from collections.abc import Generator - from torch_geometric.data import HeteroData - - from anemoi.models.data_indices.collection import IndexCollection from anemoi.models.interface import ModelInterface - from anemoi.training.schemas.base_schema import BaseSchema + from anemoi.training.config_bundle import TaskConfigBundle + from anemoi.training.runtime import TaskRuntimeArtifacts LOGGER = logging.getLogger(__name__) @@ -41,11 +38,8 @@ def __init__( self, *, model: ModelInterface, - config: BaseSchema, - graph_data: HeteroData, - statistics: dict, - statistics_tendencies: dict, - data_indices: dict[str, IndexCollection], + config_bundle: TaskConfigBundle, + runtime_artifacts: TaskRuntimeArtifacts, **kwargs, ) -> None: """Initialize graph neural network forecaster. @@ -53,26 +47,20 @@ def __init__( Parameters ---------- model : ModelInterface - config : DictConfig - Job configuration - graph_data : HeteroData - Graph object representing the graph data - statistics : dict - Statistics of the training data - data_indices : dict[str, IndexCollection] - Indices of the training data, + config_bundle : TaskConfigBundle + Parts of the config used by this task. + runtime_artifacts : TaskRuntimeArtifacts + Data prepared by the trainer for this task. 
""" super().__init__( model=model, - config=config, - graph_data=graph_data, - statistics=statistics, - statistics_tendencies=statistics_tendencies, - data_indices=data_indices, + config_bundle=config_bundle, + runtime_artifacts=runtime_artifacts, **kwargs, ) + config = self.config self.rollout = config.training.rollout.start self.rollout_epoch_increment = config.training.rollout.epoch_increment self.rollout_max = config.training.rollout.max @@ -82,6 +70,7 @@ def __init__( LOGGER.debug("Rollout max : %d", self.rollout_max) self._plot_adapter = ForecasterPlotAdapter(self) + self.fill_metadata(self.metadata) def _advance_dataset_input( self, diff --git a/training/src/anemoi/training/train/train.py b/training/src/anemoi/training/train/train.py index c1617231c2..508b73bdbe 100644 --- a/training/src/anemoi/training/train/train.py +++ b/training/src/anemoi/training/train/train.py @@ -30,10 +30,14 @@ from torch_geometric.data import HeteroData from anemoi.models.utils.compile import mark_for_compilation +from anemoi.training.config_bundle import ModelConfigBundle +from anemoi.training.config_bundle import TaskConfigBundle from anemoi.training.data.datamodule import AnemoiDatasetsDataModule from anemoi.training.diagnostics.callbacks import get_callbacks from anemoi.training.diagnostics.logger import get_mlflow_logger from anemoi.training.diagnostics.logger import get_wandb_logger +from anemoi.training.runtime import ModelRuntimeArtifacts +from anemoi.training.runtime import TaskRuntimeArtifacts from anemoi.training.schemas.base_schema import BaseSchema from anemoi.training.schemas.base_schema import UnvalidatedBaseSchema from anemoi.training.schemas.base_schema import convert_to_omegaconf @@ -244,20 +248,20 @@ def _validate_transfer_learning_datasets( def model(self) -> pl.LightningModule: """Provide the model instance.""" model_task = get_class(self.config.training.model_task) + model_runtime_artifacts = self.runtime_artifacts + task_runtime_artifacts = self.task_runtime_artifacts + model = instantiate( + self.config.model_builder, + config_bundle=self.model_config_bundle, + runtime_artifacts=model_runtime_artifacts, + ) - model = instantiate(self.config.model) - - self.metadata["metadata_inference"]["task"] = model_task.task_type + task_runtime_artifacts.metadata["metadata_inference"]["task"] = model_task.task_type kwargs = { "model": model, - "config": self.config, - "data_indices": self.data_indices, - "graph_data": self.graph_data, - "metadata": self.metadata, - "supporting_arrays": self.supporting_arrays, - "statistics": self.datamodule.statistics, - "statistics_tendencies": self.datamodule.statistics_tendencies, + "config_bundle": self.task_config_bundle, + "runtime_artifacts": task_runtime_artifacts, } model = model_task(**kwargs) # GraphForecaster -> pl.LightningModule @@ -270,9 +274,6 @@ def model(self) -> pl.LightningModule: model = transfer_learning_loading(model, self.last_checkpoint) else: LOGGER.info("Restoring only model weights from %s", self.last_checkpoint) - # pop data_indices so that the data indices on the checkpoint do not get overwritten - # by the data indices from the new config - kwargs.pop("data_indices") model = model_task.load_from_checkpoint( self.last_checkpoint, **kwargs, @@ -280,9 +281,9 @@ def model(self) -> pl.LightningModule: weights_only=False, # required for Pytorch Lightning 2.6 ) - model.data_indices = self.data_indices + model.data_indices = task_runtime_artifacts.data_indices # Validate data indices between checkpoint and current config - 
-            self._validate_transfer_learning_datasets(model, self.data_indices)
+            self._validate_transfer_learning_datasets(model, task_runtime_artifacts.data_indices)

         if hasattr(self.config.training, "submodules_to_freeze"):
             # Freeze the chosen model weights
@@ -392,6 +393,33 @@
         return build_combined_supporting_arrays(self.config, self.graph_data, self.datamodule.supporting_arrays)

+    @cached_property
+    def runtime_artifacts(self) -> ModelRuntimeArtifacts:
+        """Data prepared by the trainer and passed when creating the model."""
+        return ModelRuntimeArtifacts(
+            graph_data=self.graph_data,
+            statistics=self.datamodule.statistics,
+            statistics_tendencies=self.datamodule.statistics_tendencies,
+            data_indices=self.data_indices,
+            metadata=self.metadata,
+            supporting_arrays=self.supporting_arrays,
+        )
+
+    @cached_property
+    def task_runtime_artifacts(self) -> TaskRuntimeArtifacts:
+        """Data prepared by the trainer and passed to the task."""
+        return self.runtime_artifacts.to_task_runtime_artifacts()
+
+    @cached_property
+    def model_config_bundle(self) -> ModelConfigBundle:
+        """Parts of the config used to create the model."""
+        return ModelConfigBundle.from_root_config(self.config)
+
+    @cached_property
+    def task_config_bundle(self) -> TaskConfigBundle:
+        """Parts of the config used by the training task."""
+        return TaskConfigBundle.from_root_config(self.config)
+
     @cached_property
     def _logger_kwargs(self) -> dict:
         """Shared keyword arguments for all loggers."""
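The trainer now builds the runtime inputs once and hands the same objects to both the model builder and the task, so the identity checks in the new tests hold by construction. The exact fields of `ModelRuntimeArtifacts` and `TaskRuntimeArtifacts` live in `anemoi.training.runtime`; the sketch below only illustrates the sharing semantics (abbreviated field lists, hypothetical `*Sketch` names):

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ModelRuntimeArtifactsSketch:
    """Everything the trainer prepares before the model exists."""

    graph_data: Any
    statistics: dict
    data_indices: dict
    metadata: dict

    def to_task_runtime_artifacts(self) -> "TaskRuntimeArtifactsSketch":
        # Narrow the model-facing container down to what tasks consume;
        # the same objects are shared, never copied.
        return TaskRuntimeArtifactsSketch(
            graph_data=self.graph_data,
            data_indices=self.data_indices,
            metadata=self.metadata,
        )


@dataclass(frozen=True)
class TaskRuntimeArtifactsSketch:
    graph_data: Any
    data_indices: dict
    metadata: dict


model_artifacts = ModelRuntimeArtifactsSketch({}, {}, {"data": object()}, {"metadata_inference": {}})
task_artifacts = model_artifacts.to_task_runtime_artifacts()
assert task_artifacts.data_indices is model_artifacts.data_indices
```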
diff --git a/training/src/anemoi/training/utils/supporting_arrays.py b/training/src/anemoi/training/utils/supporting_arrays.py
index 62674927cb..f9c76529a2 100644
--- a/training/src/anemoi/training/utils/supporting_arrays.py
+++ b/training/src/anemoi/training/utils/supporting_arrays.py
@@ -7,13 +7,20 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.

+from typing import Any
+
 from hydra.utils import instantiate
+from torch_geometric.data import HeteroData
+
+from anemoi.models.utils.config import get_multiple_datasets_config


-def build_combined_supporting_arrays(config, graph_data: dict, supporting_arrays: dict) -> dict:
+def build_combined_supporting_arrays(config: Any, graph_data: HeteroData, supporting_arrays: dict) -> dict:
     """Merge output-mask supporting arrays into supporting_arrays."""
-    combined = supporting_arrays.copy()
-    for name, data in graph_data.items():
-        mask = instantiate(config.model.output_mask, graph_data=data)
+    combined = {name: arrays.copy() for name, arrays in supporting_arrays.items()}
+    dataset_names = get_multiple_datasets_config(config.data).keys()
+    for name in dataset_names:
+        combined.setdefault(name, {})
+        mask = instantiate(config.model.output_mask, nodes=graph_data[name])
         combined[name].update(mask.supporting_arrays)
     return combined
diff --git a/training/tests/integration/aicon/test_cicd_aicon_04_icon-dream_medium.py b/training/tests/integration/aicon/test_cicd_aicon_04_icon-dream_medium.py
index d65f7864be..56d6bb5f3a 100644
--- a/training/tests/integration/aicon/test_cicd_aicon_04_icon-dream_medium.py
+++ b/training/tests/integration/aicon/test_cicd_aicon_04_icon-dream_medium.py
@@ -138,11 +138,11 @@ def test_aicon_metadata(aicon_config_with_grid: DictConfig) -> None:

     # Monitor path and setting of num_chunks
     assert (
-        trainer.model.model.model.encoder[dataset_name].proc.num_chunks
+        trainer.model.model.backbone.encoder[dataset_name].proc.num_chunks
         == aicon_config_with_grid.model.encoder.num_chunks
     )
     assert (
-        trainer.model.model.model.decoder[dataset_name].proc.num_chunks
+        trainer.model.model.backbone.decoder[dataset_name].proc.num_chunks
         == aicon_config_with_grid.model.decoder.num_chunks
     )
diff --git a/training/tests/unit/diagnostics/test_plotting_callbacks.py b/training/tests/unit/diagnostics/test_plotting_callbacks.py
index 43e9bc9af8..fdb1662ace 100644
--- a/training/tests/unit/diagnostics/test_plotting_callbacks.py
+++ b/training/tests/unit/diagnostics/test_plotting_callbacks.py
@@ -18,6 +18,7 @@
 import omegaconf
 import pytest
 import torch
+from torch_geometric.data import HeteroData

 from anemoi.training.diagnostics.callbacks.plot import GraphTrainableFeaturesPlot
 from anemoi.training.diagnostics.callbacks.plot import PlotHistogram
@@ -260,13 +261,9 @@ def _make_pl_module_forecaster(
     data_indices.model.output.name_to_index = {"a": 0, "b": 1}
     pl_module.data_indices = {"data": data_indices}
     # Latlons for graph (radians), converted to deg in process
-    pl_module.model.model._graph_data = {
-        "data": MagicMock(),
-    }
-    pl_module.model.model._graph_data["data"].__getitem__ = lambda _self, _k: MagicMock()
-    graph_data = pl_module.model.model._graph_data["data"]
-    pl_module.model.model._graph_name_data = "x"
-    graph_data.__getitem__ = lambda k: torch.zeros(nlatlon, 2) if k == pl_module.model.model._graph_name_data else None
+    graph_data = HeteroData()
+    graph_data["data"].x = torch.zeros(nlatlon, 2)
+    pl_module.model.backbone._graph_data = graph_data
     # output_mask equal to identity
     pl_module.output_mask = {"data": MagicMock()}
     pl_module.output_mask["data"].apply.side_effect = lambda x, **_kwargs: x
@@ -324,9 +321,9 @@ def _make_pl_module_interpolator(*, output_times=2, nlatlon=50) -> MagicMock:
     data_indices.data.output.full = slice(None)
     data_indices.model.output.name_to_index = {"a": 0, "b": 1}
     pl_module.data_indices = {"data": data_indices}
-    pl_module.model.model._graph_data = {"data": MagicMock()}
-    pl_module.model.model._graph_data["data"].__getitem__ = lambda _k: torch.zeros(nlatlon, 2)
-    pl_module.model.model._graph_name_data = "x"
+    graph_data = HeteroData()
+    graph_data["data"].x = torch.zeros(nlatlon, 2)
+    pl_module.model.backbone._graph_data = graph_data
     pl_module.output_mask = {"data": MagicMock()}
     pl_module.output_mask["data"].apply.side_effect = lambda x, **_kwargs: x
diff --git a/training/tests/unit/diagnostics/test_plotting_ens_callbacks.py b/training/tests/unit/diagnostics/test_plotting_ens_callbacks.py
index d8f18fef1f..390ccc4122 100644
--- a/training/tests/unit/diagnostics/test_plotting_ens_callbacks.py
+++ b/training/tests/unit/diagnostics/test_plotting_ens_callbacks.py
@@ -13,6 +13,7 @@

 import omegaconf
 import torch
+from torch_geometric.data import HeteroData

 from anemoi.training.diagnostics.callbacks.plot_ens import EnsemblePlotMixin
 from anemoi.training.diagnostics.callbacks.plot_ens import PlotEnsSample
@@ -91,11 +92,9 @@ def test_ensemble_plot_mixin_process():
     pl_module.plot_adapter.output_times = 3
     pl_module.plot_adapter.get_total_plot_targets.return_value = 3
     pl_module.plot_adapter.prepare_plot_output_tensor.side_effect = lambda x: x
-    pl_module.model.model._graph_name_data = "x"
-    pl_module.model.model._graph_data = {dataset_name: MagicMock()}
-    graph_node = MagicMock()
-    graph_node.x = torch.randn(100, 2)
-    pl_module.model.model._graph_data[dataset_name].__getitem__ = lambda _self, k: graph_node if k == "x" else None
+    graph_data = HeteroData()
+    graph_data[dataset_name].x = torch.randn(100, 2)
+    pl_module.model.backbone._graph_data = graph_data

     # data_indices: dict[dataset_name -> IndexCollection]
     data_indices = MagicMock()
diff --git a/training/tests/unit/train/test_builder.py b/training/tests/unit/train/test_builder.py
new file mode 100644
index 0000000000..b2d31288ee
--- /dev/null
+++ b/training/tests/unit/train/test_builder.py
@@ -0,0 +1,194 @@
+# (C) Copyright 2026 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING
+from typing import Any
+
+import torch
+from omegaconf import OmegaConf
+from torch_geometric.data import HeteroData
+
+from anemoi.training.builder import ModelRuntimeArtifacts
+from anemoi.training.builder import build_anemoi_model
+from anemoi.training.builder import build_direct_model
+from anemoi.training.config_bundle import ModelConfigBundle
+from anemoi.training.config_bundle import TaskConfigBundle
+from anemoi.training.runtime import TaskRuntimeArtifacts
+from anemoi.training.train.train import AnemoiTrainer
+
+if TYPE_CHECKING:
+    import pytest
+
+
+class FakeAnemoiModel:
+    def __init__(self, **kwargs: Any) -> None:
+        self.kwargs = kwargs
+
+
+class DummyTrainer(AnemoiTrainer):
+    @cached_property
+    def profiler(self) -> None:
+        return None
+
+
+def test_build_anemoi_model_uses_injected_runtime_artifacts(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    graph = HeteroData()
+    graph["data"].x = torch.zeros(1, 1)
+
+    def _fake_instantiate(config: Any, **kwargs: Any) -> Any:
+        if getattr(config, "_target_", None) == "tests.FakeAnemoiModel":
+            return FakeAnemoiModel(**kwargs)
+        return config
+
+    monkeypatch.setattr("anemoi.training.builder.instantiate", _fake_instantiate)
+
+    statistics = {"data": {}}
+    data_indices = {"data": object()}
+    metadata = {"metadata_inference": {"data": {"timesteps": {"relative_date_indices_training": [0]}}}}
+    supporting_arrays = {"data": {}}
+    runtime_artifacts = ModelRuntimeArtifacts(
+        graph_data=graph,
+        statistics=statistics,
+        statistics_tendencies=None,
+        data_indices=data_indices,
+        metadata=metadata,
+        supporting_arrays=supporting_arrays,
+    )
+
+    model = build_anemoi_model(
+        config_bundle=ModelConfigBundle(
+            training=OmegaConf.create({"multistep_input": 1, "multistep_output": 1}),
+            data=OmegaConf.create({"processors": {}}),
+            model=OmegaConf.create(
+                {
+                    "wrapper": {"_target_": "tests.FakeAnemoiModel"},
+                    "backbone": {"_target_": "unused"},
+                    "keep_batch_sharded": False,
+                    "output_mask": {"_target_": "anemoi.training.utils.masks.NoOutputMask"},
+                },
+            ),
+        ),
+        runtime_artifacts=runtime_artifacts,
+    )
+
+    assert model.kwargs["graph_data"] is graph
+    assert model.kwargs["statistics"] is statistics
+    assert model.kwargs["data_indices"] is data_indices
+    assert model.kwargs["metadata"] is metadata
+    assert model.kwargs["supporting_arrays"] == supporting_arrays
+
+
+def test_trainer_model_passes_runtime_artifacts_to_model_instantiation(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    graph = HeteroData()
+    graph["data"].x = torch.zeros(1, 1)
+
+    captured_kwargs = {}
+    fake_model = object()
+
+    def _capture_instantiate(config: Any, **kwargs: Any) -> Any:
+        captured_kwargs["target"] = config._target_
+        captured_kwargs["kwargs"] = kwargs
+        return fake_model
+
+    class FakeTask:
+        task_type = "forecaster"
+
+        def __init__(self, **kwargs: Any) -> None:
+            self.kwargs = kwargs
+
+    monkeypatch.setattr("anemoi.training.train.train.instantiate", _capture_instantiate)
+    monkeypatch.setattr("anemoi.training.train.train.get_class", lambda _path: FakeTask)
+
+    trainer = DummyTrainer.__new__(DummyTrainer)
+    runtime_artifacts = ModelRuntimeArtifacts(
+        graph_data=graph,
+        statistics={"data": {}},
+        statistics_tendencies=None,
+        data_indices={"data": object()},
+        metadata={"metadata_inference": {}},
+        supporting_arrays={"data": {}},
+    )
+    trainer.config = OmegaConf.create(
+        {
+            "training": {
+                "model_task": "unused",
+                "transfer_learning": False,
+            },
+ "data": {}, + "system": {"hardware": {}}, + "dataloader": {}, + "graph": {}, + "model_builder": { + "_target_": "anemoi.training.builder.build_direct_model", + }, + "model": { + "_target_": "anemoi.models.models.naive.NaiveModel", + }, + }, + ) + trainer.runtime_artifacts = runtime_artifacts + trainer.load_weights_only = False + + task = trainer.model + + assert captured_kwargs["target"] == "anemoi.training.builder.build_direct_model" + assert captured_kwargs["kwargs"]["runtime_artifacts"] is runtime_artifacts + assert isinstance(captured_kwargs["kwargs"]["config_bundle"], ModelConfigBundle) + assert task.kwargs["model"] is fake_model + assert isinstance(task.kwargs["runtime_artifacts"], TaskRuntimeArtifacts) + assert isinstance(task.kwargs["config_bundle"], TaskConfigBundle) + assert task.kwargs["runtime_artifacts"].graph_data is runtime_artifacts.graph_data + assert task.kwargs["runtime_artifacts"].data_indices is runtime_artifacts.data_indices + assert runtime_artifacts.metadata["metadata_inference"]["task"] == "forecaster" + + +def test_build_direct_model_uses_model_spec_from_config_bundle( + monkeypatch: pytest.MonkeyPatch, +) -> None: + graph = HeteroData() + graph["data"].x = torch.zeros(1, 1) + + captured_kwargs = {} + fake_model = object() + + def _capture_instantiate(config: Any, **kwargs: Any) -> Any: + captured_kwargs["target"] = config._target_ + captured_kwargs["kwargs"] = kwargs + return fake_model + + monkeypatch.setattr("anemoi.training.builder.instantiate", _capture_instantiate) + + runtime_artifacts = ModelRuntimeArtifacts( + graph_data=graph, + statistics={"data": {}}, + statistics_tendencies=None, + data_indices={"data": object()}, + metadata={"metadata_inference": {}}, + supporting_arrays={"data": {}}, + ) + + model = build_direct_model( + config_bundle=ModelConfigBundle( + training=OmegaConf.create({"multistep_input": 1, "multistep_output": 1}), + data=OmegaConf.create({}), + model=OmegaConf.create({"_target_": "anemoi.models.models.naive.NaiveModel", "n_input": 1}), + ), + runtime_artifacts=runtime_artifacts, + ) + + assert model is fake_model + assert captured_kwargs["target"] == "anemoi.models.models.naive.NaiveModel" + assert captured_kwargs["kwargs"]["runtime_artifacts"] is runtime_artifacts diff --git a/training/tests/unit/train/test_tasks.py b/training/tests/unit/train/test_tasks.py index e003f80071..aa716ec651 100644 --- a/training/tests/unit/train/test_tasks.py +++ b/training/tests/unit/train/test_tasks.py @@ -1,3 +1,4 @@ +import inspect from typing import Any from unittest.mock import MagicMock @@ -19,6 +20,7 @@ from anemoi.training.losses.multiscale import MultiscaleLossWrapper from anemoi.training.train.tasks.base import BaseGraphModule from anemoi.training.train.tasks.diffusionforecaster import GraphDiffusionForecaster +from anemoi.training.train.tasks.diffusionforecaster import GraphDiffusionTendForecaster from anemoi.training.train.tasks.ensforecaster import GraphEnsForecaster from anemoi.training.train.tasks.forecaster import GraphForecaster from anemoi.training.train.tasks.interpolator import GraphMultiOutInterpolator @@ -149,6 +151,18 @@ def fwd_with_preconditioning( y_noised = y_noised.unsqueeze(1) return y_noised + 0.1 * pred + def get_diffusion_parameters(self) -> tuple[float, float, float]: + return self.sigma_max, self.sigma_min, self.sigma_data + + def forward_with_preconditioning( + self, + x: torch.Tensor | dict[str, torch.Tensor], + y_noised: torch.Tensor | dict[str, torch.Tensor], + sigma: torch.Tensor | dict[str, torch.Tensor], + 
diff --git a/training/tests/unit/train/test_tasks.py b/training/tests/unit/train/test_tasks.py
index e003f80071..aa716ec651 100644
--- a/training/tests/unit/train/test_tasks.py
+++ b/training/tests/unit/train/test_tasks.py
@@ -1,3 +1,4 @@
+import inspect
 from typing import Any
 from unittest.mock import MagicMock
@@ -19,6 +20,7 @@
 from anemoi.training.losses.multiscale import MultiscaleLossWrapper
 from anemoi.training.train.tasks.base import BaseGraphModule
 from anemoi.training.train.tasks.diffusionforecaster import GraphDiffusionForecaster
+from anemoi.training.train.tasks.diffusionforecaster import GraphDiffusionTendForecaster
 from anemoi.training.train.tasks.ensforecaster import GraphEnsForecaster
 from anemoi.training.train.tasks.forecaster import GraphForecaster
 from anemoi.training.train.tasks.interpolator import GraphMultiOutInterpolator
@@ -149,6 +151,18 @@ def fwd_with_preconditioning(
             y_noised = y_noised.unsqueeze(1)
         return y_noised + 0.1 * pred

+    def get_diffusion_parameters(self) -> tuple[float, float, float]:
+        return self.sigma_max, self.sigma_min, self.sigma_data
+
+    def forward_with_preconditioning(
+        self,
+        x: torch.Tensor | dict[str, torch.Tensor],
+        y_noised: torch.Tensor | dict[str, torch.Tensor],
+        sigma: torch.Tensor | dict[str, torch.Tensor],
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.fwd_with_preconditioning(x, y_noised, sigma, **kwargs)
+

 def _make_minimal_index_collection(name_to_index: dict[str, int]) -> IndexCollection:
     cfg = DictConfig({"forcing": [], "diagnostic": [], "target": []})
@@ -273,24 +287,20 @@ def test_graphforecaster(monkeypatch: pytest.MonkeyPatch) -> None:
 _CFG_DIFFUSION = DictConfig(
     {
         "training": {"multistep_input": 1, "multistep_output": 1},
-        "model": {"model": {"diffusion": {"rho": 7.0}}},
+        "model": {"backbone": {"diffusion": {"rho": 7.0}}},
     },
 )


 def test_graphdiffusionforecaster() -> None:
-    class DummyDiffusion:
-        def __init__(self, model: DummyDiffusionModel) -> None:
-            self.model = model
-
     data_indices = _data_indices_single()
     forecaster = GraphDiffusionForecaster.__new__(GraphDiffusionForecaster)
     pl.LightningModule.__init__(forecaster)
     _set_base_task_attrs(forecaster, data_indices=data_indices, config=_CFG_DIFFUSION)
-    forecaster.model = DummyDiffusion(
-        DummyDiffusionModel(num_output_variables=len(next(iter(data_indices.values())).model.output)),
+    forecaster.model = DummyDiffusionModel(
+        num_output_variables=len(next(iter(data_indices.values())).model.output),
     )
-    forecaster.rho = _CFG_DIFFUSION.model.model.diffusion.rho
+    forecaster.rho = _CFG_DIFFUSION.model.backbone.diffusion.rho
     forecaster.is_first_step = False
     forecaster.updating_scalars = {}
     forecaster.target_dataset_names = forecaster.dataset_names
@@ -310,6 +320,13 @@ def __init__(self, model: DummyDiffusionModel) -> None:
     assert y_pred.shape == (b, 1, e, g, v)


+def test_forecaster_task_classes_are_concrete() -> None:
+    assert not inspect.isabstract(GraphForecaster)
+    assert not inspect.isabstract(GraphDiffusionForecaster)
+    assert not inspect.isabstract(GraphDiffusionTendForecaster)
+    assert not inspect.isabstract(GraphEnsForecaster)
+
+
 def test_base_compute_loss_forwards_standard_loss_kwargs() -> None:
     module = MagicMock(spec=BaseGraphModule)
     loss = CaptureLoss()
diff --git a/training/tests/unit/utils/test_supporting_arrays.py b/training/tests/unit/utils/test_supporting_arrays.py
new file mode 100644
index 0000000000..6625e486b4
--- /dev/null
+++ b/training/tests/unit/utils/test_supporting_arrays.py
@@ -0,0 +1,40 @@
+# (C) Copyright 2026 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import torch
+from omegaconf import OmegaConf
+from torch_geometric.data import HeteroData
+
+from anemoi.training.utils.supporting_arrays import build_combined_supporting_arrays
+
+
+def test_build_combined_supporting_arrays_uses_dataset_nodes_only() -> None:
+    graph = HeteroData()
+    graph["data"]["output_nodes"] = torch.tensor([True, False, True])
+    graph["hidden"].x = torch.randn(2, 1)
+
+    config = OmegaConf.create(
+        {
+            "data": {"processors": {}},
+            "model": {
+                "output_mask": {
+                    "_target_": "anemoi.training.utils.masks.Boolean1DMask",
+                    "attribute_name": "output_nodes",
+                },
+            },
+        },
+    )
+    supporting_arrays = {"data": {"orography": [1, 2, 3]}}
+
+    combined = build_combined_supporting_arrays(config, graph, supporting_arrays)
+
+    assert "output_mask" in combined["data"]
+    assert combined["data"]["output_mask"].tolist() == [True, False, True]
+    assert "hidden" not in combined
+    assert supporting_arrays == {"data": {"orography": [1, 2, 3]}}
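A closing note on the `isinstance` guards the new task constructors rely on: `@runtime_checkable` protocols only verify that the listed methods exist, not their signatures or return types, which is why the tests above still exercise the real call paths. A minimal sketch of that behaviour (the `Diffusionish` protocol here is a toy, not the real `DiffusionModelInterface`):

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Diffusionish(Protocol):
    """Toy stand-in for DiffusionModelInterface."""

    def get_diffusion_parameters(self) -> tuple[float, float, float]: ...


class GoodModel:
    def get_diffusion_parameters(self) -> tuple[float, float, float]:
        return 80.0, 0.02, 1.0


class WrongSignature:
    # Same method name, wrong shape: isinstance still passes, because
    # runtime protocol checks look only at attribute presence.
    def get_diffusion_parameters(self, extra: int) -> int:
        return extra


assert isinstance(GoodModel(), Diffusionish)
assert isinstance(WrongSignature(), Diffusionish)  # presence-only check
assert not isinstance(object(), Diffusionish)
```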