@@ -214,11 +214,13 @@ def __init__(
214
214
self .use_nxcg_cache = True
215
215
self .nxcg_graph = None
216
216
217
+ self .edge_type_key = edge_type_key
218
+ self .read_parallelism = read_parallelism
219
+ self .read_batch_size = read_batch_size
220
+
217
221
# Does not apply to undirected graphs
218
222
self .symmetrize_edges = symmetrize_edges
219
223
220
- self .edge_type_key = edge_type_key
221
-
222
224
# TODO: Consider this
223
225
# if not self.__graph_name:
224
226
# if incoming_graph_data is not None:
@@ -227,8 +229,8 @@ def __init__(
227
229
228
230
self ._loaded_incoming_graph_data = False
229
231
if self .graph_exists_in_db :
230
- self ._set_factory_methods ()
231
- self .__set_arangodb_backend_config (read_parallelism , read_batch_size )
232
+ self ._set_factory_methods (read_parallelism , read_batch_size )
233
+ self .__set_arangodb_backend_config ()
232
234
233
235
if overwrite_graph :
234
236
logger .info ("Overwriting graph..." )
@@ -284,7 +286,7 @@ def __init__(
284
286
# Init helper methods #
285
287
#######################
286
288
287
- def _set_factory_methods (self ) -> None :
289
+ def _set_factory_methods (self , read_parallelism : int , read_batch_size : int ) -> None :
288
290
"""Set the factory methods for the graph, _node, and _adj dictionaries.
289
291
290
292
The ArangoDB CRUD operations are handled by the modified dictionaries.
@@ -299,39 +301,29 @@ def _set_factory_methods(self) -> None:
299
301
"""
300
302
301
303
base_args = (self .db , self .adb_graph )
304
+
302
305
node_args = (* base_args , self .default_node_type )
303
- adj_args = (
304
- * node_args ,
305
- self .edge_type_key ,
306
- self .edge_type_func ,
307
- self .__class__ .__name__ ,
306
+ node_args_with_read = (* node_args , read_parallelism , read_batch_size )
307
+
308
+ adj_args = (self .edge_type_key , self .edge_type_func , self .__class__ .__name__ )
309
+ adj_inner_args = (* node_args , * adj_args )
310
+ adj_outer_args = (
311
+ * node_args_with_read ,
312
+ * adj_args ,
313
+ self .symmetrize_edges ,
308
314
)
309
315
310
316
self .graph_attr_dict_factory = graph_dict_factory (* base_args )
311
317
312
- self .node_dict_factory = node_dict_factory (* node_args )
318
+ self .node_dict_factory = node_dict_factory (* node_args_with_read )
313
319
self .node_attr_dict_factory = node_attr_dict_factory (* base_args )
314
320
315
321
self .edge_attr_dict_factory = edge_attr_dict_factory (* base_args )
316
- self .adjlist_inner_dict_factory = adjlist_inner_dict_factory (* adj_args )
317
- self .adjlist_outer_dict_factory = adjlist_outer_dict_factory (
318
- * adj_args , self .symmetrize_edges
319
- )
320
-
321
- def __set_arangodb_backend_config (
322
- self , read_parallelism : int , read_batch_size : int
323
- ) -> None :
324
- if not all ([self ._host , self ._username , self ._password , self ._db_name ]):
325
- m = "Must set all environment variables to use the ArangoDB Backend with an existing graph" # noqa: E501
326
- raise OSError (m )
322
+ self .adjlist_inner_dict_factory = adjlist_inner_dict_factory (* adj_inner_args )
323
+ self .adjlist_outer_dict_factory = adjlist_outer_dict_factory (* adj_outer_args )
327
324
325
+ def __set_arangodb_backend_config (self ) -> None :
328
326
config = nx .config .backends .arangodb
329
- config .host = self ._host
330
- config .username = self ._username
331
- config .password = self ._password
332
- config .db_name = self ._db_name
333
- config .read_parallelism = read_parallelism
334
- config .read_batch_size = read_batch_size
335
327
config .use_gpu = True # Only used by default if nx-cugraph is available
336
328
337
329
def __set_edge_collections_attributes (self , attributes : set [str ] | None ) -> None :
@@ -345,7 +337,7 @@ def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None
345
337
self ._edge_collections_attributes .add ("_id" )
346
338
347
339
def __set_db (self , db : Any = None ) -> None :
348
- self ._host = os .getenv ("DATABASE_HOST" )
340
+ self ._hosts = os .getenv ("DATABASE_HOST" , "" ). split ( ", " )
349
341
self ._username = os .getenv ("DATABASE_USERNAME" )
350
342
self ._password = os .getenv ("DATABASE_PASSWORD" )
351
343
self ._db_name = os .getenv ("DATABASE_NAME" )
@@ -355,17 +347,20 @@ def __set_db(self, db: Any = None) -> None:
355
347
m = "arango.database.StandardDatabase"
356
348
raise TypeError (m )
357
349
358
- db .version ()
350
+ db .version () # make sure the connection is valid
359
351
self .__db = db
352
+ self ._db_name = db .name
353
+ self ._hosts = db ._conn ._hosts
354
+ self ._username , self ._password = db ._conn ._auth
360
355
return
361
356
362
- if not all ([self ._host , self ._username , self ._password , self ._db_name ]):
357
+ if not all ([self ._hosts , self ._username , self ._password , self ._db_name ]):
363
358
m = "Database environment variables not set. Can't connect to the database"
364
359
logger .warning (m )
365
360
self .__db = None
366
361
return
367
362
368
- self .__db = ArangoClient (hosts = self ._host , request_timeout = None ).db (
363
+ self .__db = ArangoClient (hosts = self ._hosts , request_timeout = None ).db (
369
364
self ._db_name , self ._username , self ._password , verify = True
370
365
)
371
366
0 commit comments