Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support node masking in edge builder #50

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,13 @@ def generate_graph(self) -> HeteroData:
)

for edges_cfg in self.config.get("edges", {}):
graph = instantiate(edges_cfg.edge_builder, edges_cfg.source_name, edges_cfg.target_name).update_graph(
graph, edges_cfg.get("attributes", {})
)
graph = instantiate(
edges_cfg.edge_builder,
edges_cfg.source_name,
edges_cfg.target_name,
source_mask_attr_name=edges_cfg.get("source_mask_attr_name", None),
target_mask_attr_name=edges_cfg.get("target_mask_attr_name", None),
).update_graph(graph, edges_cfg.get("attributes", {}))

return graph

Expand Down
108 changes: 92 additions & 16 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from scipy.sparse import coo_matrix
from sklearn.neighbors import NearestNeighbors
from torch_geometric.data import HeteroData
from torch_geometric.data.storage import NodeStorage
Expand All @@ -28,9 +29,17 @@
class BaseEdgeBuilder(ABC):
"""Base class for edge builders."""

def __init__(self, source_name: str, target_name: str):
def __init__(
    self,
    source_name: str,
    target_name: str,
    source_mask_attr_name: str | None = None,
    target_mask_attr_name: str | None = None,
):
    """Record which node sets the edges connect and the optional node masks.

    Parameters
    ----------
    source_name : str
        The name of the source nodes.
    target_name : str
        The name of the target nodes.
    source_mask_attr_name : str | None
        The name of the source mask attribute to filter edge connections.
    target_mask_attr_name : str | None
        The name of the target mask attribute to filter edge connections.
    """
    self.source_name, self.target_name = source_name, target_name
    self.source_mask_attr_name, self.target_mask_attr_name = (
        source_mask_attr_name,
        target_mask_attr_name,
    )

@property
def name(self) -> tuple[str, str, str]:
Expand Down Expand Up @@ -125,7 +134,42 @@ def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -
return graph


class KNNEdges(BaseEdgeBuilder):
class NodeMaskingMixin:
    """Mixin class for masking source/target nodes when building edges.

    The host class must define ``source_mask_attr_name`` and
    ``target_mask_attr_name``; each is either ``None`` (no masking) or the
    name of a node attribute holding the mask for that node set.
    """

    @staticmethod
    def _mask_indices(nodes: "NodeStorage", mask_attr_name: str) -> np.ndarray:
        """Return the positions (in the full node set) of the nodes kept by the mask.

        Parameters
        ----------
        nodes : NodeStorage
            The nodes holding the mask attribute.
        mask_attr_name : str
            The name of the mask attribute.

        Returns
        -------
        np.ndarray
            Sorted 1-D array with the indices of the non-zero mask entries.
        """
        # Flatten so that both (N,) and (N, 1) shaped masks are accepted.
        mask = np.asarray(nodes[mask_attr_name]).reshape(-1)
        return np.flatnonzero(mask)

    def get_node_coordinates(
        self, source_nodes: "NodeStorage", target_nodes: "NodeStorage"
    ) -> tuple[np.ndarray, np.ndarray]:
        """Get the node coordinates, restricted to the masked nodes if a mask is set.

        Parameters
        ----------
        source_nodes : NodeStorage
            The source nodes.
        target_nodes : NodeStorage
            The target nodes.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            The (possibly filtered) source and target coordinates.
        """
        source_coords, target_coords = source_nodes.x.numpy(), target_nodes.x.numpy()

        if self.source_mask_attr_name is not None:
            source_coords = source_coords[self._mask_indices(source_nodes, self.source_mask_attr_name)]

        if self.target_mask_attr_name is not None:
            target_coords = target_coords[self._mask_indices(target_nodes, self.target_mask_attr_name)]

        return source_coords, target_coords

    def undo_masking(self, adj_matrix, source_nodes: "NodeStorage", target_nodes: "NodeStorage"):
        """Re-index a masked adjacency matrix to the full node sets.

        Rows of ``adj_matrix`` refer to target nodes and columns to source
        nodes, both numbered within the masked subsets. The returned matrix
        uses the original node numbering and the full
        ``(num_target_nodes, num_source_nodes)`` shape.

        Parameters
        ----------
        adj_matrix : scipy.sparse.coo_matrix
            The adjacency matrix computed on the masked coordinates.
        source_nodes : NodeStorage
            The source nodes.
        target_nodes : NodeStorage
            The target nodes.

        Returns
        -------
        scipy.sparse.coo_matrix
            The adjacency matrix over all nodes. When no mask is configured,
            ``adj_matrix`` is returned unchanged.
        """
        if self.source_mask_attr_name is None and self.target_mask_attr_name is None:
            return adj_matrix

        rows, cols = adj_matrix.row, adj_matrix.col

        # Translate compact (masked) indices to original indices by direct
        # lookup. The lookup table is sized by the number of masked nodes, not
        # by the number of edges, and it also works when there are no edges.
        if self.target_mask_attr_name is not None:
            rows = self._mask_indices(target_nodes, self.target_mask_attr_name)[rows]

        if self.source_mask_attr_name is not None:
            cols = self._mask_indices(source_nodes, self.source_mask_attr_name)[cols]

        true_shape = target_nodes.x.shape[0], source_nodes.x.shape[0]
        return coo_matrix((adj_matrix.data, (rows, cols)), shape=true_shape)


class KNNEdges(BaseEdgeBuilder, NodeMaskingMixin):
"""Computes KNN based edges and adds them to the graph.

Attributes
Expand All @@ -136,6 +180,10 @@ class KNNEdges(BaseEdgeBuilder):
The name of the target nodes.
num_nearest_neighbours : int
Number of nearest neighbours.
source_mask_attr_name : str | None
The name of the source mask attribute to filter edge connections.
target_mask_attr_name : str | None
The name of the target mask attribute to filter edge connections.

Methods
-------
Expand All @@ -147,22 +195,30 @@ class KNNEdges(BaseEdgeBuilder):
Update the graph with the edges.
"""

def __init__(self, source_name: str, target_name: str, num_nearest_neighbours: int):
super().__init__(source_name, target_name)
def __init__(
    self,
    source_name: str,
    target_name: str,
    num_nearest_neighbours: int,
    source_mask_attr_name: str | None = None,
    target_mask_attr_name: str | None = None,
):
    """Set up a KNN edge builder.

    Parameters
    ----------
    source_name : str
        The name of the source nodes.
    target_name : str
        The name of the target nodes.
    num_nearest_neighbours : int
        Number of nearest neighbours. Must be a positive integer.
    source_mask_attr_name : str | None
        The name of the source mask attribute to filter edge connections.
    target_mask_attr_name : str | None
        The name of the target mask attribute to filter edge connections.
    """
    super().__init__(
        source_name,
        target_name,
        source_mask_attr_name=source_mask_attr_name,
        target_mask_attr_name=target_mask_attr_name,
    )

    # Validate after the base initialisation, then keep the value.
    is_integer = isinstance(num_nearest_neighbours, int)
    assert is_integer, "Number of nearest neighbours must be an integer"
    assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive"

    self.num_nearest_neighbours = num_nearest_neighbours

def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarray):
def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
"""Compute the adjacency matrix for the KNN method.

Parameters
----------
source_nodes : np.ndarray
source_nodes : NodeStorage
The source nodes.
target_nodes : np.ndarray
target_nodes : NodeStorage
The target nodes.
"""
source_coords, target_coords = self.get_node_coordinates(source_nodes, target_nodes)
assert self.num_nearest_neighbours is not None, "number of neighbors required for knn encoder"
LOGGER.info(
"Using KNN-Edges (with %d nearest neighbours) between %s and %s.",
Expand All @@ -172,16 +228,20 @@ def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarra
)

nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4)
nearest_neighbour.fit(source_nodes.x.numpy())
nearest_neighbour.fit(source_coords)
adj_matrix = nearest_neighbour.kneighbors_graph(
target_nodes.x.numpy(),
target_coords,
n_neighbors=self.num_nearest_neighbours,
mode="distance",
).tocoo()

# Post-process the adjacency matrix. Add masked nodes.
adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes)

return adj_matrix


class CutOffEdges(BaseEdgeBuilder):
class CutOffEdges(BaseEdgeBuilder, NodeMaskingMixin):
"""Computes cut-off based edges and adds them to the graph.

Attributes
Expand All @@ -192,6 +252,10 @@ class CutOffEdges(BaseEdgeBuilder):
The name of the target nodes.
cutoff_factor : float
Factor to multiply the grid reference distance to get the cut-off radius.
source_mask_attr_name : str | None
The name of the source mask attribute to filter edge connections.
target_mask_attr_name : str | None
The name of the target mask attribute to filter edge connections.

Methods
-------
Expand All @@ -203,8 +267,15 @@ class CutOffEdges(BaseEdgeBuilder):
Update the graph with the edges.
"""

def __init__(self, source_name: str, target_name: str, cutoff_factor: float):
super().__init__(source_name, target_name)
def __init__(
    self,
    source_name: str,
    target_name: str,
    cutoff_factor: float,
    source_mask_attr_name: str | None = None,
    target_mask_attr_name: str | None = None,
):
    """Set up a cut-off edge builder.

    Parameters
    ----------
    source_name : str
        The name of the source nodes.
    target_name : str
        The name of the target nodes.
    cutoff_factor : float
        Factor to multiply the grid reference distance to get the cut-off
        radius. Must be a positive number.
    source_mask_attr_name : str | None
        The name of the source mask attribute to filter edge connections.
    target_mask_attr_name : str | None
        The name of the target mask attribute to filter edge connections.
    """
    super().__init__(
        source_name,
        target_name,
        source_mask_attr_name=source_mask_attr_name,
        target_mask_attr_name=target_mask_attr_name,
    )

    # Integers are accepted as well as floats.
    is_number = isinstance(cutoff_factor, (int, float))
    assert is_number, "Cutoff factor must be a float"
    assert cutoff_factor > 0, "Cutoff factor must be positive"

    self.cutoff_factor = cutoff_factor
Expand Down Expand Up @@ -248,6 +319,7 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor
target_nodes : NodeStorage
The target nodes.
"""
source_coords, target_coords = self.get_node_coordinates(source_nodes, target_nodes)
LOGGER.info(
"Using CutOff-Edges (with radius = %.1f km) between %s and %s.",
self.radius * EARTH_RADIUS,
Expand All @@ -256,8 +328,12 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor
)

nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4)
nearest_neighbour.fit(source_nodes.x)
adj_matrix = nearest_neighbour.radius_neighbors_graph(target_nodes.x, radius=self.radius).tocoo()
nearest_neighbour.fit(source_coords)
adj_matrix = nearest_neighbour.radius_neighbors_graph(target_coords, radius=self.radius).tocoo()

# Post-process the adjacency matrix. Add masked nodes.
adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes)

return adj_matrix


Expand Down Expand Up @@ -286,7 +362,7 @@ class MultiScaleEdges(BaseEdgeBuilder):

VALID_NODES = [TriNodes, HexNodes, LimitedAreaTriNodes, LimitedAreaHexNodes]

def __init__(self, source_name: str, target_name: str, x_hops: int):
def __init__(self, source_name: str, target_name: str, x_hops: int, **kwargs):
super().__init__(source_name, target_name)
assert source_name == target_name, f"{self.__class__.__name__} requires source and target nodes to be the same."
assert isinstance(x_hops, int), "Number of x_hops must be an integer"
Expand All @@ -299,7 +375,7 @@ def add_edges_from_tri_nodes(self, nodes: NodeStorage) -> NodeStorage:
nodes["_nx_graph"],
resolutions=nodes["_resolutions"],
x_hops=self.x_hops,
aoi_mask_builder=nodes.get("_aoi_mask_builder", None),
area_mask_builder=nodes.get("_area_mask_builder", None),
)

return nodes
Expand Down
Loading