Skip to content

Commit

Permalink
docs: fix docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
JesperDramsch committed Sep 16, 2024
1 parent b4cbab8 commit 7728456
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 14 deletions.
18 changes: 14 additions & 4 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __init__(self, source_name: str, target_name: str, num_nearest_neighbours: i
assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive"
self.num_nearest_neighbours = num_nearest_neighbours

def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarray):
def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarray) -> np.ndarray:
"""Compute the adjacency matrix for the KNN method.
Parameters
Expand All @@ -162,6 +162,11 @@ def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarra
The source nodes.
target_nodes : np.ndarray
The target nodes.
Returns
-------
np.ndarray
The adjacency matrix.
"""
assert self.num_nearest_neighbours is not None, "number of neighbors required for knn encoder"
LOGGER.info(
Expand Down Expand Up @@ -203,13 +208,13 @@ class CutOffEdges(BaseEdgeBuilder):
Update the graph with the edges.
"""

def __init__(self, source_name: str, target_name: str, cutoff_factor: float):
def __init__(self, source_name: str, target_name: str, cutoff_factor: float) -> None:
super().__init__(source_name, target_name)
assert isinstance(cutoff_factor, (int, float)), "Cutoff factor must be a float"
assert cutoff_factor > 0, "Cutoff factor must be positive"
self.cutoff_factor = cutoff_factor

def get_cutoff_radius(self, graph: HeteroData, mask_attr: torch.Tensor | None = None):
def get_cutoff_radius(self, graph: HeteroData, mask_attr: torch.Tensor | None = None) -> float:
"""Compute the cut-off radius.
The cut-off radius is computed as the product of the target nodes
Expand Down Expand Up @@ -238,7 +243,7 @@ def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage
self.radius = self.get_cutoff_radius(graph)
return super().prepare_node_data(graph)

def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage) -> np.ndarray:
"""Get the adjacency matrix for the cut-off method.
Parameters
Expand All @@ -247,6 +252,11 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor
The source nodes.
target_nodes : NodeStorage
The target nodes.
Returns
-------
np.ndarray
The adjacency matrix.
"""
LOGGER.info(
"Using CutOff-Edges (with radius = %.1f km) between %s and %s.",
Expand Down
15 changes: 15 additions & 0 deletions src/anemoi/graphs/generate/hexagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ def add_nodes_for_resolution(
The H3 refinement level. It can be an integer from 0 to 15.
area_kwargs: dict
Additional arguments to pass to the get_nodes_at_resolution function.
Returns
-------
graph : networkx.Graph
The graph with the added nodes.
"""

nodes = get_nodes_at_resolution(resolution, **area_kwargs)
Expand Down Expand Up @@ -173,6 +178,11 @@ def add_edges_to_children(
depth_children : Optional[int], optional
The number of resolution levels to consider for the connections of children. Defaults to 1, which includes
connections up to the next resolution level, by default None
Returns
-------
graph : nx.Graph
graph with the added edges
"""
if depth_children is None:
depth_children = len(refinement_levels)
Expand Down Expand Up @@ -215,6 +225,11 @@ def add_edge(
The H3 index of the tail of the edge.
target_node_h3_idx : str
The H3 index of the head of the edge.
Returns
-------
graph : networkx.Graph
The graph with the added edge.
"""
if not graph.has_node(source_node_h3_idx) or not graph.has_node(target_node_h3_idx):
return graph
Expand Down
11 changes: 1 addition & 10 deletions src/anemoi/graphs/generate/icosahedral.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ def create_icosahedral_nodes(
----------
resolutions : list[int]
Levels of mesh resolution to consider.
aoi_mask_builder : KNNAreaMaskBuilder
KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None.
Returns
-------
Expand Down Expand Up @@ -86,10 +84,6 @@ def add_edges_to_nx_graph(
Levels of mesh refinement levels to consider.
x_hops : int, optional
Number of hops between 2 nodes to consider them neighbours, by default 1.
aoi_mask_builder : KNNAreaMaskBuilder
NearestNeighbors with the cloud of points to limit the mesh area, by default None.
margin_radius_km : float, optional
Margin radius in km to consider when creating the processor mesh, by default 0.0.
Returns
-------
Expand Down Expand Up @@ -137,9 +131,6 @@ def get_neighbours_within_hops(tri_mesh: trimesh.Trimesh, x_hops: int) -> dict[i
The mesh to consider.
x_hops : int
Number of hops between 2 nodes to consider them neighbours.
valid_nodes : list[int], optional
List of valid nodes to consider, by default None. It is useful to consider only a subset of the nodes to save
computation time.
Returns
-------
Expand Down Expand Up @@ -178,7 +169,7 @@ def add_neigbours_edges(
A 2D array of shape (num_vertices, 2) with the planar coordinates of the mesh, in radians.
node_idx : int
The node considered.
neighbours : list[int]
neighbour_indices : list[int]
The neighbours of the node.
self_loops : bool, optional
Whether is supported to add self-loops, by default False.
Expand Down
12 changes: 12 additions & 0 deletions src/anemoi/graphs/nodes/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def compute(self, graph: HeteroData, nodes_name: str, *args, **kwargs) -> torch.
Graph.
nodes_name : str
Name of the nodes.
args : tuple
Additional arguments.
kwargs : dict
Additional keyword arguments.
Returns
-------
Expand Down Expand Up @@ -70,6 +74,10 @@ def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray:
----------
nodes : NodeStorage
Nodes of the graph.
args : tuple
Additional arguments.
kwargs : dict
Additional keyword arguments.
Returns
-------
Expand Down Expand Up @@ -111,6 +119,10 @@ def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray:
----------
nodes : NodeStorage
Nodes of the graph.
args : tuple
Additional arguments.
kwargs : dict
Additional keyword arguments.
Returns
-------
Expand Down
2 changes: 2 additions & 0 deletions src/anemoi/graphs/nodes/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ class IcosahedralNodes(BaseNodeBuilder, ABC):
----------
resolution : list[int] | int
Refinement level of the mesh.
name : str
The name of the nodes.
"""

def __init__(
Expand Down

0 comments on commit 7728456

Please sign in to comment.