Skip to content

Commit

Permalink
[Feature] Support node masking in edge builder (#50)
Browse files Browse the repository at this point in the history
* feat: support node masking in edge builder

---------

Co-authored-by: Helen Theissen <[email protected]>
  • Loading branch information
JPXKQX and theissenhelen committed Sep 17, 2024
1 parent 4762f3f commit fac6f36
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 19 deletions.
10 changes: 7 additions & 3 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,13 @@ def generate_graph(self) -> HeteroData:
)

for edges_cfg in self.config.get("edges", {}):
graph = instantiate(edges_cfg.edge_builder, edges_cfg.source_name, edges_cfg.target_name).update_graph(
graph, edges_cfg.get("attributes", {})
)
graph = instantiate(
edges_cfg.edge_builder,
edges_cfg.source_name,
edges_cfg.target_name,
source_mask_attr_name=edges_cfg.get("source_mask_attr_name", None),
target_mask_attr_name=edges_cfg.get("target_mask_attr_name", None),
).update_graph(graph, edges_cfg.get("attributes", {}))

return graph

Expand Down
108 changes: 92 additions & 16 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from scipy.sparse import coo_matrix
from sklearn.neighbors import NearestNeighbors
from torch_geometric.data import HeteroData
from torch_geometric.data.storage import NodeStorage
Expand All @@ -28,9 +29,17 @@
class BaseEdgeBuilder(ABC):
"""Base class for edge builders."""

def __init__(self, source_name: str, target_name: str):
def __init__(
self,
source_name: str,
target_name: str,
source_mask_attr_name: str | None = None,
target_mask_attr_name: str | None = None,
):
self.source_name = source_name
self.target_name = target_name
self.source_mask_attr_name = source_mask_attr_name
self.target_mask_attr_name = target_mask_attr_name

@property
def name(self) -> tuple[str, str, str]:
Expand Down Expand Up @@ -125,7 +134,42 @@ def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -
return graph


class KNNEdges(BaseEdgeBuilder):
class NodeMaskingMixin:
"""Mixin class for masking source/target nodes when building edges."""

def get_node_coordinates(
self, source_nodes: NodeStorage, target_nodes: NodeStorage
) -> tuple[np.ndarray, np.ndarray]:
"""Get the node coordinates."""
source_coords, target_coords = source_nodes.x.numpy(), target_nodes.x.numpy()

if self.source_mask_attr_name is not None:
source_coords = source_coords[source_nodes[self.source_mask_attr_name].squeeze()]

if self.target_mask_attr_name is not None:
target_coords = target_coords[target_nodes[self.target_mask_attr_name].squeeze()]

return source_coords, target_coords

def undo_masking(self, adj_matrix, source_nodes: NodeStorage, target_nodes: NodeStorage):
if self.target_mask_attr_name is not None:
target_mask = target_nodes[self.target_mask_attr_name].squeeze()
target_mapper = dict(zip(list(range(len(adj_matrix.row))), np.where(target_mask)[0]))
adj_matrix.row = np.vectorize(target_mapper.get)(adj_matrix.row)

if self.source_mask_attr_name is not None:
source_mask = source_nodes[self.source_mask_attr_name].squeeze()
source_mapper = dict(zip(list(range(len(adj_matrix.col))), np.where(source_mask)[0]))
adj_matrix.col = np.vectorize(source_mapper.get)(adj_matrix.col)

if self.source_mask_attr_name is not None or self.target_mask_attr_name is not None:
true_shape = target_nodes.x.shape[0], source_nodes.x.shape[0]
adj_matrix = coo_matrix((adj_matrix.data, (adj_matrix.row, adj_matrix.col)), shape=true_shape)

return adj_matrix


class KNNEdges(BaseEdgeBuilder, NodeMaskingMixin):
"""Computes KNN based edges and adds them to the graph.
Attributes
Expand All @@ -136,6 +180,10 @@ class KNNEdges(BaseEdgeBuilder):
The name of the target nodes.
num_nearest_neighbours : int
Number of nearest neighbours.
source_mask_attr_name : str | None
The name of the source mask attribute to filter edge connections.
target_mask_attr_name : str | None
The name of the target mask attribute to filter edge connections.
Methods
-------
Expand All @@ -147,22 +195,30 @@ class KNNEdges(BaseEdgeBuilder):
Update the graph with the edges.
"""

def __init__(self, source_name: str, target_name: str, num_nearest_neighbours: int):
super().__init__(source_name, target_name)
def __init__(
self,
source_name: str,
target_name: str,
num_nearest_neighbours: int,
source_mask_attr_name: str | None = None,
target_mask_attr_name: str | None = None,
):
super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name)
assert isinstance(num_nearest_neighbours, int), "Number of nearest neighbours must be an integer"
assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive"
self.num_nearest_neighbours = num_nearest_neighbours

def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarray):
def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
"""Compute the adjacency matrix for the KNN method.
Parameters
----------
source_nodes : np.ndarray
source_nodes : NodeStorage
The source nodes.
target_nodes : np.ndarray
target_nodes : NodeStorage
The target nodes.
"""
source_coords, target_coords = self.get_node_coordinates(source_nodes, target_nodes)
assert self.num_nearest_neighbours is not None, "number of neighbors required for knn encoder"
LOGGER.info(
"Using KNN-Edges (with %d nearest neighbours) between %s and %s.",
Expand All @@ -172,16 +228,20 @@ def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarra
)

nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4)
nearest_neighbour.fit(source_nodes.x.numpy())
nearest_neighbour.fit(source_coords)
adj_matrix = nearest_neighbour.kneighbors_graph(
target_nodes.x.numpy(),
target_coords,
n_neighbors=self.num_nearest_neighbours,
mode="distance",
).tocoo()

# Post-process the adjacency matrix. Add masked nodes.
adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes)

return adj_matrix


class CutOffEdges(BaseEdgeBuilder):
class CutOffEdges(BaseEdgeBuilder, NodeMaskingMixin):
"""Computes cut-off based edges and adds them to the graph.
Attributes
Expand All @@ -192,6 +252,10 @@ class CutOffEdges(BaseEdgeBuilder):
The name of the target nodes.
cutoff_factor : float
Factor to multiply the grid reference distance to get the cut-off radius.
source_mask_attr_name : str | None
The name of the source mask attribute to filter edge connections.
target_mask_attr_name : str | None
The name of the target mask attribute to filter edge connections.
Methods
-------
Expand All @@ -203,8 +267,15 @@ class CutOffEdges(BaseEdgeBuilder):
Update the graph with the edges.
"""

def __init__(self, source_name: str, target_name: str, cutoff_factor: float):
super().__init__(source_name, target_name)
def __init__(
self,
source_name: str,
target_name: str,
cutoff_factor: float,
source_mask_attr_name: str | None = None,
target_mask_attr_name: str | None = None,
):
super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name)
assert isinstance(cutoff_factor, (int, float)), "Cutoff factor must be a float"
assert cutoff_factor > 0, "Cutoff factor must be positive"
self.cutoff_factor = cutoff_factor
Expand Down Expand Up @@ -248,6 +319,7 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor
target_nodes : NodeStorage
The target nodes.
"""
source_coords, target_coords = self.get_node_coordinates(source_nodes, target_nodes)
LOGGER.info(
"Using CutOff-Edges (with radius = %.1f km) between %s and %s.",
self.radius * EARTH_RADIUS,
Expand All @@ -256,8 +328,12 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor
)

nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4)
nearest_neighbour.fit(source_nodes.x)
adj_matrix = nearest_neighbour.radius_neighbors_graph(target_nodes.x, radius=self.radius).tocoo()
nearest_neighbour.fit(source_coords)
adj_matrix = nearest_neighbour.radius_neighbors_graph(target_coords, radius=self.radius).tocoo()

# Post-process the adjacency matrix. Add masked nodes.
adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes)

return adj_matrix


Expand Down Expand Up @@ -286,7 +362,7 @@ class MultiScaleEdges(BaseEdgeBuilder):

VALID_NODES = [TriNodes, HexNodes, LimitedAreaTriNodes, LimitedAreaHexNodes]

def __init__(self, source_name: str, target_name: str, x_hops: int):
def __init__(self, source_name: str, target_name: str, x_hops: int, **kwargs):
super().__init__(source_name, target_name)
assert source_name == target_name, f"{self.__class__.__name__} requires source and target nodes to be the same."
assert isinstance(x_hops, int), "Number of x_hops must be an integer"
Expand All @@ -299,7 +375,7 @@ def add_edges_from_tri_nodes(self, nodes: NodeStorage) -> NodeStorage:
nodes["_nx_graph"],
resolutions=nodes["_resolutions"],
x_hops=self.x_hops,
aoi_mask_builder=nodes.get("_aoi_mask_builder", None),
area_mask_builder=nodes.get("_area_mask_builder", None),
)

return nodes
Expand Down

0 comments on commit fac6f36

Please sign in to comment.