1212
1313from stgraph .graph .static .csr import CSR
1414
15+
1516class StaticGraph (STGraphBase ):
16- def __init__ (self , edge_list , edge_weights , num_nodes ):
17+ r"""An abstract base class used to represent static graphs in STGraph.
18+
19+ This abstract class outlines the interface for defining a static graphs
20+ used in STGraph. As of now the static graph is implemented using the
21+ Compressed Sparse Row (CSR) format.
22+
23+ Example
24+ -------
25+
26+ .. code-block:: python
27+
28+ from stgraph.graph import StaticGraph
29+ from stgraph.dataset import HungaryCPDataLoader
30+
31+ hungary = HungaryCPDataLoader()
32+
33+ graph = StaticGraph(
34+ edge_list = hungary.get_edges(),
35+ edge_weights = hungary.get_edge_weights(),
36+ num_nodes = hungary.gdata["num_nodes"]
37+ )
38+
39+ """
40+
41+ def __init__ (self , edge_list , edge_weights , num_nodes ):
42+ """An abstract base class used to represent static graphs in STGraph."""
1743 super ().__init__ ()
1844 self ._num_nodes = num_nodes
1945 self ._num_edges = len (set (edge_list ))
20-
21- # console.log("Building forward edge list")
46+
2247 self ._prepare_edge_lst_fwd (edge_list )
23- # console.log("Creating forward graph")
24- self . _forward_graph = CSR ( self .fwd_edge_list , edge_weights , self ._num_nodes , is_edge_reverse = True )
25-
26- # console.log("Building backward edge list")
48+ self . _forward_graph = CSR (
49+ self .fwd_edge_list , edge_weights , self ._num_nodes , is_edge_reverse = True
50+ )
51+
2752 self ._prepare_edge_lst_bwd (self .fwd_edge_list )
28- # console.log("Creating backward graph")
2953 self ._backward_graph = CSR (self .bwd_edge_list , edge_weights , self ._num_nodes )
30-
31- # console.log("Getting CSR ptrs")
54+
3255 self ._get_graph_csr_ptrs ()
33-
34- def _prepare_edge_lst_fwd (self , edge_list ):
56+
57+ def _prepare_edge_lst_fwd (self , edge_list ):
58+ r"""TODO:"""
3559 edge_list_for_t = edge_list
36- edge_list_for_t .sort (key = lambda x : (x [1 ],x [0 ]))
37- edge_list_for_t = [(edge_list_for_t [j ][0 ],edge_list_for_t [j ][1 ],j ) for j in range (len (edge_list_for_t ))]
60+ edge_list_for_t .sort (key = lambda x : (x [1 ], x [0 ]))
61+ edge_list_for_t = [
62+ (edge_list_for_t [j ][0 ], edge_list_for_t [j ][1 ], j )
63+ for j in range (len (edge_list_for_t ))
64+ ]
3865 self .fwd_edge_list = edge_list_for_t
39-
40- def _prepare_edge_lst_bwd (self , edge_list ):
66+
67+ def _prepare_edge_lst_bwd (self , edge_list ):
68+ r"""TODO:"""
4169 edge_list_for_t = copy .deepcopy (edge_list )
4270 edge_list_for_t .sort ()
4371 self .bwd_edge_list = edge_list_for_t
44-
72+
4573 def _get_graph_csr_ptrs (self ):
74+ r"""TODO:"""
4675 self .fwd_row_offset_ptr = self ._forward_graph .row_offset_ptr
4776 self .fwd_column_indices_ptr = self ._forward_graph .column_indices_ptr
4877 self .fwd_eids_ptr = self ._forward_graph .eids_ptr
@@ -52,31 +81,38 @@ def _get_graph_csr_ptrs(self):
5281 self .bwd_column_indices_ptr = self ._backward_graph .column_indices_ptr
5382 self .bwd_eids_ptr = self ._backward_graph .eids_ptr
5483 self .bwd_node_ids_ptr = self ._backward_graph .node_ids_ptr
55-
84+
5685 def get_num_nodes (self ):
86+ r"""Return the number of nodes in the static graph."""
5787 return self ._num_nodes
58-
88+
5989 def get_num_edges (self ):
90+ r"""Return the number of edges in the static graph."""
6091 return self ._num_edges
61-
92+
6293 def get_ndata (self , field ):
94+ r"""Returns the graph metadata."""
6395 if field in self ._ndata :
6496 return self ._ndata [field ]
6597 else :
6698 return None
6799
68100 def set_ndata (self , field , val ):
101+ r"""Sets the graph metadata."""
69102 self ._ndata [field ] = val
70-
103+
71104 def graph_type (self ):
72- # return "csr "
105+ r"""Returns the graph type."" "
73106 return "csr_unsorted"
74-
107+
75108 def in_degrees (self ):
76- return np .array (self ._forward_graph .out_degrees , dtype = 'int32' )
77-
109+ r"""Returns the graph inwards node degree array."""
110+ return np .array (self ._forward_graph .out_degrees , dtype = "int32" )
111+
78112 def out_degrees (self ):
79- return np .array (self ._forward_graph .in_degrees , dtype = 'int32' )
80-
113+ r"""Returns the graph outwards node degree array."""
114+ return np .array (self ._forward_graph .in_degrees , dtype = "int32" )
115+
81116 def weighted_in_degrees (self ):
82- return np .array (self ._forward_graph .weighted_out_degrees , dtype = 'int32' )
117+ r"""TODO:"""
118+ return np .array (self ._forward_graph .weighted_out_degrees , dtype = "int32" )
0 commit comments