diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 389faec..7f34390 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -153,7 +153,7 @@ def __init__(self, source_name: str, target_name: str, num_nearest_neighbours: i assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive" self.num_nearest_neighbours = num_nearest_neighbours - def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarray): + def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarray) -> np.ndarray: """Compute the adjacency matrix for the KNN method. Parameters @@ -162,6 +162,11 @@ def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarra The source nodes. target_nodes : np.ndarray The target nodes. + + Returns + ------- + np.ndarray + The adjacency matrix. """ assert self.num_nearest_neighbours is not None, "number of neighbors required for knn encoder" LOGGER.info( @@ -203,13 +208,13 @@ class CutOffEdges(BaseEdgeBuilder): Update the graph with the edges. """ - def __init__(self, source_name: str, target_name: str, cutoff_factor: float): + def __init__(self, source_name: str, target_name: str, cutoff_factor: float) -> None: super().__init__(source_name, target_name) assert isinstance(cutoff_factor, (int, float)), "Cutoff factor must be a float" assert cutoff_factor > 0, "Cutoff factor must be positive" self.cutoff_factor = cutoff_factor - def get_cutoff_radius(self, graph: HeteroData, mask_attr: torch.Tensor | None = None): + def get_cutoff_radius(self, graph: HeteroData, mask_attr: torch.Tensor | None = None) -> float: """Compute the cut-off radius. The cut-off radius is computed as the product of the target nodes @@ -238,7 +243,7 @@ def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage self.radius = self.get_cutoff_radius(graph) return super().prepare_node_data(graph) - def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): + def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage) -> np.ndarray: """Get the adjacency matrix for the cut-off method. Parameters @@ -247,6 +252,11 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor The source nodes. target_nodes : NodeStorage The target nodes. + + Returns + ------- + np.ndarray + The adjacency matrix. """ LOGGER.info( "Using CutOff-Edges (with radius = %.1f km) between %s and %s.", diff --git a/src/anemoi/graphs/generate/hexagonal.py b/src/anemoi/graphs/generate/hexagonal.py index 9e1bf29..2a9cfe3 100644 --- a/src/anemoi/graphs/generate/hexagonal.py +++ b/src/anemoi/graphs/generate/hexagonal.py @@ -62,6 +62,11 @@ def add_nodes_for_resolution( The H3 refinement level. It can be an integer from 0 to 15. area_kwargs: dict Additional arguments to pass to the get_nodes_at_resolution function. + + Returns + ------- + graph : networkx.Graph + The graph with the added nodes. """ nodes = get_nodes_at_resolution(resolution, **area_kwargs) @@ -173,6 +178,11 @@ def add_edges_to_children( depth_children : Optional[int], optional The number of resolution levels to consider for the connections of children. Defaults to 1, which includes connections up to the next resolution level, by default None + + Returns + ------- + graph : nx.Graph + graph with the added edges """ if depth_children is None: depth_children = len(refinement_levels) @@ -215,6 +225,11 @@ def add_edge( The H3 index of the tail of the edge. target_node_h3_idx : str The H3 index of the head of the edge. + + Returns + ------- + graph : networkx.Graph + The graph with the added edge. """ if not graph.has_node(source_node_h3_idx) or not graph.has_node(target_node_h3_idx): return graph diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index ed88621..357676a 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -22,8 +22,6 @@ def create_icosahedral_nodes( ---------- resolutions : list[int] Levels of mesh resolution to consider. - aoi_mask_builder : KNNAreaMaskBuilder - KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. Returns ------- @@ -86,10 +84,6 @@ def add_edges_to_nx_graph( Levels of mesh refinement levels to consider. x_hops : int, optional Number of hops between 2 nodes to consider them neighbours, by default 1. - aoi_mask_builder : KNNAreaMaskBuilder - NearestNeighbors with the cloud of points to limit the mesh area, by default None. - margin_radius_km : float, optional - Margin radius in km to consider when creating the processor mesh, by default 0.0. Returns ------- @@ -137,9 +131,6 @@ def get_neighbours_within_hops(tri_mesh: trimesh.Trimesh, x_hops: int) -> dict[i The mesh to consider. x_hops : int Number of hops between 2 nodes to consider them neighbours. - valid_nodes : list[int], optional - List of valid nodes to consider, by default None. It is useful to consider only a subset of the nodes to save - computation time. Returns ------- @@ -178,7 +169,7 @@ def add_neigbours_edges( A 2D array of shape (num_vertices, 2) with the planar coordinates of the mesh, in radians. node_idx : int The node considered. - neighbours : list[int] + neighbour_indices : list[int] The neighbours of the node. self_loops : bool, optional Whether is supported to add self-loops, by default False. diff --git a/src/anemoi/graphs/nodes/attributes.py b/src/anemoi/graphs/nodes/attributes.py index 11009d5..16a577d 100644 --- a/src/anemoi/graphs/nodes/attributes.py +++ b/src/anemoi/graphs/nodes/attributes.py @@ -43,6 +43,10 @@ def compute(self, graph: HeteroData, nodes_name: str, *args, **kwargs) -> torch. Graph. nodes_name : str Name of the nodes. + args : tuple + Additional arguments. + kwargs : dict + Additional keyword arguments. Returns ------- @@ -70,6 +74,10 @@ def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: ---------- nodes : NodeStorage Nodes of the graph. + args : tuple + Additional arguments. + kwargs : dict + Additional keyword arguments. Returns ------- @@ -111,6 +119,10 @@ def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: ---------- nodes : NodeStorage Nodes of the graph. + args : tuple + Additional arguments. + kwargs : dict + Additional keyword arguments. Returns ------- diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 54753c4..12c818c 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -204,6 +204,8 @@ class IcosahedralNodes(BaseNodeBuilder, ABC): ---------- resolution : list[int] | int Refinement level of the mesh. + name : str + The name of the nodes. """ def __init__(