diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index 51de9a4..17a46f7 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -42,9 +42,13 @@ def generate_graph(self) -> HeteroData: ) for edges_cfg in self.config.get("edges", {}): - graph = instantiate(edges_cfg.edge_builder, edges_cfg.source_name, edges_cfg.target_name).update_graph( - graph, edges_cfg.get("attributes", {}) - ) + graph = instantiate( + edges_cfg.edge_builder, + edges_cfg.source_name, + edges_cfg.target_name, + source_mask_attr_name=edges_cfg.get("source_mask_attr_name", None), + target_mask_attr_name=edges_cfg.get("target_mask_attr_name", None), + ).update_graph(graph, edges_cfg.get("attributes", {})) return graph diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 79e4fa1..e43d005 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -9,6 +9,7 @@ import torch from anemoi.utils.config import DotDict from hydra.utils import instantiate +from scipy.sparse import coo_matrix from sklearn.neighbors import NearestNeighbors from torch_geometric.data import HeteroData from torch_geometric.data.storage import NodeStorage @@ -28,9 +29,17 @@ class BaseEdgeBuilder(ABC): """Base class for edge builders.""" - def __init__(self, source_name: str, target_name: str): + def __init__( + self, + source_name: str, + target_name: str, + source_mask_attr_name: str | None = None, + target_mask_attr_name: str | None = None, + ): self.source_name = source_name self.target_name = target_name + self.source_mask_attr_name = source_mask_attr_name + self.target_mask_attr_name = target_mask_attr_name @property def name(self) -> tuple[str, str, str]: @@ -125,7 +134,42 @@ def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) - return graph -class KNNEdges(BaseEdgeBuilder): +class NodeMaskingMixin: + """Mixin class for masking source/target nodes when building edges.""" + + def get_node_coordinates( + self, source_nodes: NodeStorage, target_nodes: NodeStorage + ) -> tuple[np.ndarray, np.ndarray]: + """Get the node coordinates.""" + source_coords, target_coords = source_nodes.x.numpy(), target_nodes.x.numpy() + + if self.source_mask_attr_name is not None: + source_coords = source_coords[source_nodes[self.source_mask_attr_name].squeeze()] + + if self.target_mask_attr_name is not None: + target_coords = target_coords[target_nodes[self.target_mask_attr_name].squeeze()] + + return source_coords, target_coords + + def undo_masking(self, adj_matrix, source_nodes: NodeStorage, target_nodes: NodeStorage): + if self.target_mask_attr_name is not None: + target_mask = target_nodes[self.target_mask_attr_name].squeeze() + target_mapper = dict(zip(list(range(len(adj_matrix.row))), np.where(target_mask)[0])) + adj_matrix.row = np.vectorize(target_mapper.get)(adj_matrix.row) + + if self.source_mask_attr_name is not None: + source_mask = source_nodes[self.source_mask_attr_name].squeeze() + source_mapper = dict(zip(list(range(len(adj_matrix.col))), np.where(source_mask)[0])) + adj_matrix.col = np.vectorize(source_mapper.get)(adj_matrix.col) + + if self.source_mask_attr_name is not None or self.target_mask_attr_name is not None: + true_shape = target_nodes.x.shape[0], source_nodes.x.shape[0] + adj_matrix = coo_matrix((adj_matrix.data, (adj_matrix.row, adj_matrix.col)), shape=true_shape) + + return adj_matrix + + +class KNNEdges(BaseEdgeBuilder, NodeMaskingMixin): """Computes KNN based edges and adds them to the graph. Attributes @@ -136,6 +180,10 @@ class KNNEdges(BaseEdgeBuilder): The name of the target nodes. num_nearest_neighbours : int Number of nearest neighbours. + source_mask_attr_name : str | None + The name of the source mask attribute to filter edge connections. + target_mask_attr_name : str | None + The name of the target mask attribute to filter edge connections. Methods ------- @@ -147,22 +195,30 @@ class KNNEdges(BaseEdgeBuilder): Update the graph with the edges. """ - def __init__(self, source_name: str, target_name: str, num_nearest_neighbours: int): - super().__init__(source_name, target_name) + def __init__( + self, + source_name: str, + target_name: str, + num_nearest_neighbours: int, + source_mask_attr_name: str | None = None, + target_mask_attr_name: str | None = None, + ): + super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name) assert isinstance(num_nearest_neighbours, int), "Number of nearest neighbours must be an integer" assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive" self.num_nearest_neighbours = num_nearest_neighbours - def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarray): + def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): """Compute the adjacency matrix for the KNN method. Parameters ---------- - source_nodes : np.ndarray + source_nodes : NodeStorage The source nodes. - target_nodes : np.ndarray + target_nodes : NodeStorage The target nodes. """ + source_coords, target_coords = self.get_node_coordinates(source_nodes, target_nodes) assert self.num_nearest_neighbours is not None, "number of neighbors required for knn encoder" LOGGER.info( "Using KNN-Edges (with %d nearest neighbours) between %s and %s.", @@ -172,16 +228,20 @@ def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarra ) nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) - nearest_neighbour.fit(source_nodes.x.numpy()) + nearest_neighbour.fit(source_coords) adj_matrix = nearest_neighbour.kneighbors_graph( - target_nodes.x.numpy(), + target_coords, n_neighbors=self.num_nearest_neighbours, mode="distance", ).tocoo() + + # Post-process the adjacency matrix. Add masked nodes. + adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes) + return adj_matrix -class CutOffEdges(BaseEdgeBuilder): +class CutOffEdges(BaseEdgeBuilder, NodeMaskingMixin): """Computes cut-off based edges and adds them to the graph. Attributes @@ -192,6 +252,10 @@ class CutOffEdges(BaseEdgeBuilder): The name of the target nodes. cutoff_factor : float Factor to multiply the grid reference distance to get the cut-off radius. + source_mask_attr_name : str | None + The name of the source mask attribute to filter edge connections. + target_mask_attr_name : str | None + The name of the target mask attribute to filter edge connections. Methods ------- @@ -203,8 +267,15 @@ class CutOffEdges(BaseEdgeBuilder): Update the graph with the edges. """ - def __init__(self, source_name: str, target_name: str, cutoff_factor: float): - super().__init__(source_name, target_name) + def __init__( + self, + source_name: str, + target_name: str, + cutoff_factor: float, + source_mask_attr_name: str | None = None, + target_mask_attr_name: str | None = None, + ): + super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name) assert isinstance(cutoff_factor, (int, float)), "Cutoff factor must be a float" assert cutoff_factor > 0, "Cutoff factor must be positive" self.cutoff_factor = cutoff_factor @@ -248,6 +319,7 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor target_nodes : NodeStorage The target nodes. """ + source_coords, target_coords = self.get_node_coordinates(source_nodes, target_nodes) LOGGER.info( "Using CutOff-Edges (with radius = %.1f km) between %s and %s.", self.radius * EARTH_RADIUS, @@ -256,8 +328,12 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor ) nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) - nearest_neighbour.fit(source_nodes.x) - adj_matrix = nearest_neighbour.radius_neighbors_graph(target_nodes.x, radius=self.radius).tocoo() + nearest_neighbour.fit(source_coords) + adj_matrix = nearest_neighbour.radius_neighbors_graph(target_coords, radius=self.radius).tocoo() + + # Post-process the adjacency matrix. Add masked nodes. + adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes) + return adj_matrix @@ -286,7 +362,7 @@ class MultiScaleEdges(BaseEdgeBuilder): VALID_NODES = [TriNodes, HexNodes, LimitedAreaTriNodes, LimitedAreaHexNodes] - def __init__(self, source_name: str, target_name: str, x_hops: int): + def __init__(self, source_name: str, target_name: str, x_hops: int, **kwargs): super().__init__(source_name, target_name) assert source_name == target_name, f"{self.__class__.__name__} requires source and target nodes to be the same." assert isinstance(x_hops, int), "Number of x_hops must be an integer" @@ -299,7 +375,7 @@ def add_edges_from_tri_nodes(self, nodes: NodeStorage) -> NodeStorage: nodes["_nx_graph"], resolutions=nodes["_resolutions"], x_hops=self.x_hops, - aoi_mask_builder=nodes.get("_aoi_mask_builder", None), + area_mask_builder=nodes.get("_area_mask_builder", None), ) return nodes