Skip to content

Commit bd47753

Browse files
authored
new: fully support parameterized db object (#70)
* new: fully support parameterized `db` object
* fix: `hosts`
* fix: docstring
* new: support `use_gpu` algorithm parameter
* new: `test_multiple_graph_sessions`
1 parent 9f59085 commit bd47753

File tree

13 files changed

+158
-75
lines changed

13 files changed

+158
-75
lines changed

README.md

+5-2
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,16 @@ import nx_arangodb as nxadb
166166

167167
G = nxadb.Graph(name="MyGraph")
168168

169+
# Option 1: Use Global Config
169170
nx.config.backends.arangodb.use_gpu = False
170-
171171
nx.pagerank(G)
172172
nx.betweenness_centrality(G)
173173
# ...
174-
175174
nx.config.backends.arangodb.use_gpu = True
175+
176+
# Option 2: Use Local Config
177+
nx.pagerank(G, use_gpu=False)
178+
nx.betweenness_centrality(G, use_gpu=False)
176179
```
177180

178181
<p align="center">

_nx_arangodb/__init__.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,7 @@ def get_info():
7474
for key in info_keys:
7575
del d[key]
7676

77-
d["default_config"] = {
78-
"host": None,
79-
"username": None,
80-
"password": None,
81-
"db_name": None,
82-
"read_parallelism": None,
83-
"read_batch_size": None,
84-
"use_gpu": True,
85-
}
77+
d["default_config"] = {"use_gpu": True}
8678

8779
return d
8880

doc/algorithms/index.rst

+5-2
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,17 @@ You can also force-run algorithms on CPU even if ``nx-cugraph`` is installed:
4343
4444
G = nxadb.Graph(name="MyGraph")
4545
46+
# Option 1: Use Global Config
4647
nx.config.backends.arangodb.use_gpu = False
47-
4848
nx.pagerank(G)
4949
nx.betweenness_centrality(G)
5050
# ...
51-
5251
nx.config.backends.arangodb.use_gpu = True
5352
53+
# Option 2: Use Local Config
54+
nx.pagerank(G, use_gpu=False)
55+
nx.betweenness_centrality(G, use_gpu=False)
56+
5457
5558
.. image:: ../_static/dispatch.png
5659
:align: center

doc/nx_arangodb.ipynb

+1-5
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,7 @@
236236
"outputs": [],
237237
"source": [
238238
"# 5. Run an algorithm (CPU)\n",
239-
"nx.config.backends.arangodb.use_gpu = False # Optional\n",
240-
"\n",
241-
"res = nx.pagerank(G)"
239+
"res = nx.pagerank(G, use_gpu=False)"
242240
]
243241
},
244242
{
@@ -357,8 +355,6 @@
357355
"source": [
358356
"# 4. Run an algorithm (GPU)\n",
359357
"# See *Package Installation* to install nx-cugraph ^\n",
360-
"nx.config.backends.arangodb.use_gpu = True\n",
361-
"\n",
362358
"res = nx.pagerank(G)"
363359
]
364360
},

nx_arangodb/classes/dict/adj.py

+16
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ def adjlist_outer_dict_factory(
105105
db: StandardDatabase,
106106
graph: Graph,
107107
default_node_type: str,
108+
read_parallelism: int,
109+
read_batch_size: int,
108110
edge_type_key: str,
109111
edge_type_func: Callable[[str, str], str],
110112
graph_type: str,
@@ -115,6 +117,8 @@ def adjlist_outer_dict_factory(
115117
db,
116118
graph,
117119
default_node_type,
120+
read_parallelism,
121+
read_batch_size,
118122
edge_type_key,
119123
edge_type_func,
120124
graph_type,
@@ -1454,6 +1458,12 @@ class AdjListOuterDict(UserDict[str, AdjListInnerDict]):
14541458
symmetrize_edges_if_directed : bool
14551459
Whether to add the reverse edge if the graph is directed.
14561460
1461+
read_parallelism : int
1462+
The number of parallel threads to use for reading data in _fetch_all.
1463+
1464+
read_batch_size : int
1465+
The number of documents to read in each batch in _fetch_all.
1466+
14571467
Example
14581468
-------
14591469
>>> g = nxadb.Graph(name="MyGraph")
@@ -1467,6 +1477,8 @@ def __init__(
14671477
db: StandardDatabase,
14681478
graph: Graph,
14691479
default_node_type: str,
1480+
read_parallelism: int,
1481+
read_batch_size: int,
14701482
edge_type_key: str,
14711483
edge_type_func: Callable[[str, str], str],
14721484
graph_type: str,
@@ -1489,6 +1501,8 @@ def __init__(
14891501
self.edge_type_key = edge_type_key
14901502
self.edge_type_func = edge_type_func
14911503
self.default_node_type = default_node_type
1504+
self.read_parallelism = read_parallelism
1505+
self.read_batch_size = read_batch_size
14921506
self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(
14931507
db,
14941508
graph,
@@ -1853,6 +1867,8 @@ def _fetch_all(self) -> None:
18531867
is_directed=True,
18541868
is_multigraph=self.is_multigraph,
18551869
symmetrize_edges_if_directed=self.symmetrize_edges_if_directed,
1870+
read_parallelism=self.read_parallelism,
1871+
read_batch_size=self.read_batch_size,
18561872
)
18571873

18581874
# Even if the Graph is undirected,

nx_arangodb/classes/dict/node.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,20 @@
4040

4141

4242
def node_dict_factory(
43-
db: StandardDatabase, graph: Graph, default_node_type: str
43+
db: StandardDatabase,
44+
graph: Graph,
45+
default_node_type: str,
46+
read_parallelism: int,
47+
read_batch_size: int,
4448
) -> Callable[..., NodeDict]:
4549
"""Factory function for creating a NodeDict."""
46-
return lambda: NodeDict(db, graph, default_node_type)
50+
return lambda: NodeDict(
51+
db,
52+
graph,
53+
default_node_type,
54+
read_parallelism,
55+
read_batch_size,
56+
)
4757

4858

4959
def node_attr_dict_factory(
@@ -250,6 +260,12 @@ class NodeDict(UserDict[str, NodeAttrDict]):
250260
default_node_type : str
251261
The default node type for the graph.
252262
263+
read_parallelism : int
264+
The number of parallel threads to use for reading data in _fetch_all.
265+
266+
read_batch_size : int
267+
The number of documents to read in each batch in _fetch_all.
268+
253269
Example
254270
-------
255271
>>> G = nxadb.Graph("MyGraph")
@@ -262,6 +278,8 @@ def __init__(
262278
db: StandardDatabase,
263279
graph: Graph,
264280
default_node_type: str,
281+
read_parallelism: int,
282+
read_batch_size: int,
265283
*args: Any,
266284
**kwargs: Any,
267285
):
@@ -271,6 +289,9 @@ def __init__(
271289
self.db = db
272290
self.graph = graph
273291
self.default_node_type = default_node_type
292+
self.read_parallelism = read_parallelism
293+
self.read_batch_size = read_batch_size
294+
274295
self.node_attr_dict_factory = node_attr_dict_factory(self.db, self.graph)
275296

276297
self.FETCHED_ALL_DATA = False
@@ -472,6 +493,8 @@ def _fetch_all(self):
472493
is_directed=False, # not used
473494
is_multigraph=False, # not used
474495
symmetrize_edges_if_directed=False, # not used
496+
read_parallelism=self.read_parallelism,
497+
read_batch_size=self.read_batch_size,
475498
)
476499

477500
for node_id, node_data in node_dict.items():

nx_arangodb/classes/function.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def get_arangodb_graph(
4747
is_directed: bool,
4848
is_multigraph: bool,
4949
symmetrize_edges_if_directed: bool,
50+
read_parallelism: int,
51+
read_batch_size: int,
5052
) -> Tuple[
5153
NodeDict,
5254
GraphAdjDict | DiGraphAdjDict | MultiGraphAdjDict | MultiDiGraphAdjDict,
@@ -142,11 +144,10 @@ def get_arangodb_graph(
142144
if not load_adj_dict and not load_coo:
143145
metagraph["edgeCollections"] = {}
144146

145-
config = nx.config.backends.arangodb
146-
assert config.db_name
147-
assert config.host
148-
assert config.username
149-
assert config.password
147+
hosts = adb_graph._conn._hosts
148+
hosts = hosts.split(",") if type(hosts) is str else hosts
149+
db_name = adb_graph._conn._db_name
150+
username, password = adb_graph._conn._auth
150151

151152
(
152153
node_dict,
@@ -157,20 +158,20 @@ def get_arangodb_graph(
157158
vertex_ids_to_index,
158159
edge_values,
159160
) = NetworkXLoader.load_into_networkx(
160-
config.db_name,
161+
database=db_name,
161162
metagraph=metagraph,
162-
hosts=[config.host],
163-
username=config.username,
164-
password=config.password,
163+
hosts=hosts,
164+
username=username,
165+
password=password,
165166
load_adj_dict=load_adj_dict,
166167
load_coo=load_coo,
167168
load_all_vertex_attributes=load_all_vertex_attributes,
168169
load_all_edge_attributes=load_all_edge_attributes,
169170
is_directed=is_directed,
170171
is_multigraph=is_multigraph,
171172
symmetrize_edges_if_directed=symmetrize_edges_if_directed,
172-
parallelism=config.read_parallelism,
173-
batch_size=config.read_batch_size,
173+
parallelism=read_parallelism,
174+
batch_size=read_batch_size,
174175
)
175176

176177
return (

nx_arangodb/classes/graph.py

+27-32
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,13 @@ def __init__(
214214
self.use_nxcg_cache = True
215215
self.nxcg_graph = None
216216

217+
self.edge_type_key = edge_type_key
218+
self.read_parallelism = read_parallelism
219+
self.read_batch_size = read_batch_size
220+
217221
# Does not apply to undirected graphs
218222
self.symmetrize_edges = symmetrize_edges
219223

220-
self.edge_type_key = edge_type_key
221-
222224
# TODO: Consider this
223225
# if not self.__graph_name:
224226
# if incoming_graph_data is not None:
@@ -227,8 +229,8 @@ def __init__(
227229

228230
self._loaded_incoming_graph_data = False
229231
if self.graph_exists_in_db:
230-
self._set_factory_methods()
231-
self.__set_arangodb_backend_config(read_parallelism, read_batch_size)
232+
self._set_factory_methods(read_parallelism, read_batch_size)
233+
self.__set_arangodb_backend_config()
232234

233235
if overwrite_graph:
234236
logger.info("Overwriting graph...")
@@ -284,7 +286,7 @@ def __init__(
284286
# Init helper methods #
285287
#######################
286288

287-
def _set_factory_methods(self) -> None:
289+
def _set_factory_methods(self, read_parallelism: int, read_batch_size: int) -> None:
288290
"""Set the factory methods for the graph, _node, and _adj dictionaries.
289291
290292
The ArangoDB CRUD operations are handled by the modified dictionaries.
@@ -299,39 +301,29 @@ def _set_factory_methods(self) -> None:
299301
"""
300302

301303
base_args = (self.db, self.adb_graph)
304+
302305
node_args = (*base_args, self.default_node_type)
303-
adj_args = (
304-
*node_args,
305-
self.edge_type_key,
306-
self.edge_type_func,
307-
self.__class__.__name__,
306+
node_args_with_read = (*node_args, read_parallelism, read_batch_size)
307+
308+
adj_args = (self.edge_type_key, self.edge_type_func, self.__class__.__name__)
309+
adj_inner_args = (*node_args, *adj_args)
310+
adj_outer_args = (
311+
*node_args_with_read,
312+
*adj_args,
313+
self.symmetrize_edges,
308314
)
309315

310316
self.graph_attr_dict_factory = graph_dict_factory(*base_args)
311317

312-
self.node_dict_factory = node_dict_factory(*node_args)
318+
self.node_dict_factory = node_dict_factory(*node_args_with_read)
313319
self.node_attr_dict_factory = node_attr_dict_factory(*base_args)
314320

315321
self.edge_attr_dict_factory = edge_attr_dict_factory(*base_args)
316-
self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(*adj_args)
317-
self.adjlist_outer_dict_factory = adjlist_outer_dict_factory(
318-
*adj_args, self.symmetrize_edges
319-
)
320-
321-
def __set_arangodb_backend_config(
322-
self, read_parallelism: int, read_batch_size: int
323-
) -> None:
324-
if not all([self._host, self._username, self._password, self._db_name]):
325-
m = "Must set all environment variables to use the ArangoDB Backend with an existing graph" # noqa: E501
326-
raise OSError(m)
322+
self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(*adj_inner_args)
323+
self.adjlist_outer_dict_factory = adjlist_outer_dict_factory(*adj_outer_args)
327324

325+
def __set_arangodb_backend_config(self) -> None:
328326
config = nx.config.backends.arangodb
329-
config.host = self._host
330-
config.username = self._username
331-
config.password = self._password
332-
config.db_name = self._db_name
333-
config.read_parallelism = read_parallelism
334-
config.read_batch_size = read_batch_size
335327
config.use_gpu = True # Only used by default if nx-cugraph is available
336328

337329
def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None:
@@ -345,7 +337,7 @@ def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None
345337
self._edge_collections_attributes.add("_id")
346338

347339
def __set_db(self, db: Any = None) -> None:
348-
self._host = os.getenv("DATABASE_HOST")
340+
self._hosts = os.getenv("DATABASE_HOST", "").split(",")
349341
self._username = os.getenv("DATABASE_USERNAME")
350342
self._password = os.getenv("DATABASE_PASSWORD")
351343
self._db_name = os.getenv("DATABASE_NAME")
@@ -355,17 +347,20 @@ def __set_db(self, db: Any = None) -> None:
355347
m = "arango.database.StandardDatabase"
356348
raise TypeError(m)
357349

358-
db.version()
350+
db.version() # make sure the connection is valid
359351
self.__db = db
352+
self._db_name = db.name
353+
self._hosts = db._conn._hosts
354+
self._username, self._password = db._conn._auth
360355
return
361356

362-
if not all([self._host, self._username, self._password, self._db_name]):
357+
if not all([self._hosts, self._username, self._password, self._db_name]):
363358
m = "Database environment variables not set. Can't connect to the database"
364359
logger.warning(m)
365360
self.__db = None
366361
return
367362

368-
self.__db = ArangoClient(hosts=self._host, request_timeout=None).db(
363+
self.__db = ArangoClient(hosts=self._hosts, request_timeout=None).db(
369364
self._db_name, self._username, self._password, verify=True
370365
)
371366

nx_arangodb/classes/multigraph.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,8 @@ def __init__(
229229
# Init helper methods #
230230
#######################
231231

232-
def _set_factory_methods(self) -> None:
233-
super()._set_factory_methods()
232+
def _set_factory_methods(self, read_parallelism: int, read_batch_size: int) -> None:
233+
super()._set_factory_methods(read_parallelism, read_batch_size)
234234
self.edge_key_dict_factory = edge_key_dict_factory(
235235
self.db,
236236
self.adb_graph,

0 commit comments

Comments
 (0)