Skip to content

Commit 7c1b28c

Browse files
committed
📝 Added docstrings for StaticGraph
However, there are few methods yet to be documented. Kept them as TODO.
1 parent f0f44bd commit 7c1b28c

File tree

2 files changed

+64
-37
lines changed

2 files changed

+64
-37
lines changed

stgraph/graph/STGraphBase.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,6 @@ class STGraphBase(ABC):
1111
Attributes
1212
----------
1313
14-
_ndata : dict
15-
Dictionary that stores node related data
16-
17-
_forward_graph
18-
The forward graph object used for forward propagation
19-
20-
_backward_graph
21-
The backward graph object used for backward propagation
22-
2314
fwd_row_offset_ptr
2415
Pointer to the forward graphs row offset array
2516

stgraph/graph/static/StaticGraph.py

Lines changed: 64 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,37 +12,66 @@
1212

1313
from stgraph.graph.static.csr import CSR
1414

15+
1516
class StaticGraph(STGraphBase):
16-
def __init__(self, edge_list, edge_weights, num_nodes):
17+
r"""An abstract base class used to represent static graphs in STGraph.
18+
19+
This abstract class outlines the interface for defining a static graphs
20+
used in STGraph. As of now the static graph is implemented using the
21+
Compressed Sparse Row (CSR) format.
22+
23+
Example
24+
-------
25+
26+
.. code-block:: python
27+
28+
from stgraph.graph import StaticGraph
29+
from stgraph.dataset import HungaryCPDataLoader
30+
31+
hungary = HungaryCPDataLoader()
32+
33+
graph = StaticGraph(
34+
edge_list = hungary.get_edges(),
35+
edge_weights = hungary.get_edge_weights(),
36+
num_nodes = hungary.gdata["num_nodes"]
37+
)
38+
39+
"""
40+
41+
def __init__(self, edge_list, edge_weights, num_nodes):
42+
"""An abstract base class used to represent static graphs in STGraph."""
1743
super().__init__()
1844
self._num_nodes = num_nodes
1945
self._num_edges = len(set(edge_list))
20-
21-
# console.log("Building forward edge list")
46+
2247
self._prepare_edge_lst_fwd(edge_list)
23-
# console.log("Creating forward graph")
24-
self._forward_graph = CSR(self.fwd_edge_list, edge_weights, self._num_nodes, is_edge_reverse=True)
25-
26-
# console.log("Building backward edge list")
48+
self._forward_graph = CSR(
49+
self.fwd_edge_list, edge_weights, self._num_nodes, is_edge_reverse=True
50+
)
51+
2752
self._prepare_edge_lst_bwd(self.fwd_edge_list)
28-
# console.log("Creating backward graph")
2953
self._backward_graph = CSR(self.bwd_edge_list, edge_weights, self._num_nodes)
30-
31-
# console.log("Getting CSR ptrs")
54+
3255
self._get_graph_csr_ptrs()
33-
34-
def _prepare_edge_lst_fwd(self, edge_list):
56+
57+
def _prepare_edge_lst_fwd(self, edge_list):
58+
r"""TODO:"""
3559
edge_list_for_t = edge_list
36-
edge_list_for_t.sort(key = lambda x: (x[1],x[0]))
37-
edge_list_for_t = [(edge_list_for_t[j][0],edge_list_for_t[j][1],j) for j in range(len(edge_list_for_t))]
60+
edge_list_for_t.sort(key=lambda x: (x[1], x[0]))
61+
edge_list_for_t = [
62+
(edge_list_for_t[j][0], edge_list_for_t[j][1], j)
63+
for j in range(len(edge_list_for_t))
64+
]
3865
self.fwd_edge_list = edge_list_for_t
39-
40-
def _prepare_edge_lst_bwd(self, edge_list):
66+
67+
def _prepare_edge_lst_bwd(self, edge_list):
68+
r"""TODO:"""
4169
edge_list_for_t = copy.deepcopy(edge_list)
4270
edge_list_for_t.sort()
4371
self.bwd_edge_list = edge_list_for_t
44-
72+
4573
def _get_graph_csr_ptrs(self):
74+
r"""TODO:"""
4675
self.fwd_row_offset_ptr = self._forward_graph.row_offset_ptr
4776
self.fwd_column_indices_ptr = self._forward_graph.column_indices_ptr
4877
self.fwd_eids_ptr = self._forward_graph.eids_ptr
@@ -52,31 +81,38 @@ def _get_graph_csr_ptrs(self):
5281
self.bwd_column_indices_ptr = self._backward_graph.column_indices_ptr
5382
self.bwd_eids_ptr = self._backward_graph.eids_ptr
5483
self.bwd_node_ids_ptr = self._backward_graph.node_ids_ptr
55-
84+
5685
def get_num_nodes(self):
86+
r"""Return the number of nodes in the static graph."""
5787
return self._num_nodes
58-
88+
5989
def get_num_edges(self):
90+
r"""Return the number of edges in the static graph."""
6091
return self._num_edges
61-
92+
6293
def get_ndata(self, field):
94+
r"""Returns the graph metadata."""
6395
if field in self._ndata:
6496
return self._ndata[field]
6597
else:
6698
return None
6799

68100
def set_ndata(self, field, val):
101+
r"""Sets the graph metadata."""
69102
self._ndata[field] = val
70-
103+
71104
def graph_type(self):
72-
# return "csr"
105+
r"""Returns the graph type."""
73106
return "csr_unsorted"
74-
107+
75108
def in_degrees(self):
76-
return np.array(self._forward_graph.out_degrees, dtype='int32')
77-
109+
r"""Returns the graph inwards node degree array."""
110+
return np.array(self._forward_graph.out_degrees, dtype="int32")
111+
78112
def out_degrees(self):
79-
return np.array(self._forward_graph.in_degrees, dtype='int32')
80-
113+
r"""Returns the graph outwards node degree array."""
114+
return np.array(self._forward_graph.in_degrees, dtype="int32")
115+
81116
def weighted_in_degrees(self):
82-
return np.array(self._forward_graph.weighted_out_degrees, dtype='int32')
117+
r"""TODO:"""
118+
return np.array(self._forward_graph.weighted_out_degrees, dtype="int32")

0 commit comments

Comments
 (0)