@@ -82,7 +82,6 @@ def node_count(self) -> int:
8282 """
8383 Returns:
8484 the number of nodes in the graph
85-
8685 """
8786 return self ._graph_info (["nodeCount" ]) # type: ignore
8887
@@ -191,7 +190,6 @@ def drop(self, failIfMissing: bool = False) -> "Series[str]":
191190
192191 Returns:
193192 the result of the drop operation
194-
195193 """
196194 result = self ._query_runner .call_procedure (
197195 endpoint = "gds.graph.drop" ,
@@ -205,7 +203,6 @@ def creation_time(self) -> Any: # neo4j.time.DateTime not exported
205203 """
206204 Returns:
207205 the creation time of the graph
208-
209206 """
210207 return self ._graph_info (["creationTime" ])
211208
@@ -236,12 +233,36 @@ def __repr__(self) -> str:
236233
237234 def visualize (
238235 self ,
239- notebook : bool = True ,
240236 node_count : int = 100 ,
237+ directed : bool = True ,
241238 center_nodes : Optional [List [int ]] = None ,
242- include_node_properties : List [str ] = None ,
239+ include_node_properties : Optional [ List [str ] ] = None ,
243240 color_property : Optional [str ] = None ,
241+ size_property : Optional [str ] = None ,
242+ rel_weight_property : Optional [str ] = None ,
243+ notebook : bool = True ,
244+ px_height : int = 750 ,
245+ theme : str = "dark" ,
244246 ) -> Any :
247+ """
248+ Visualize the `Graph` in an interactive graphical interface.
249+ The graph will be sampled down to specified `node_count` to limit computationally expensive rendering.
250+
251+ Args:
252+ node_count: number of nodes in the graph to be visualized
253+ directed: whether or not to display relationships as directed
254+ center_nodes: nodes around subgraph will be sampled, if sampling is necessary
255+ include_node_properties: node properties to include for mouse-over inspection
256+ color_property: node property that determines node categories for coloring. Default is to use node labels
257+ size_property: node property that determines the size of nodes. Default is to compute a page rank for this
258+ rel_weight_property: relationship property that determines width of relationships
259+ notebook: whether or not the code is run in a notebook
260+ px_height: the height of the graphic containing output the visualization
261+ theme: coloring theme for the visualization. "light" or "dark"
262+
263+ Returns:
264+ an interactive graphical visualization of the specified graph
265+ """
245266 visual_graph = self ._name
246267 if self .node_count () > node_count :
247268 visual_graph = str (uuid4 ())
@@ -256,14 +277,19 @@ def visualize(
256277 custom_error = False ,
257278 )
258279
259- pr_prop = str (uuid4 ())
260- self ._query_runner .call_procedure (
261- endpoint = "gds.pageRank.mutate" ,
262- params = CallParameters (graph_name = visual_graph , config = dict (mutateProperty = pr_prop )),
263- custom_error = False ,
264- )
280+ # Make sure we always have at least a size property so that we can run `gds.graph.nodeProperties.stream`
281+ if size_property is None :
282+ size_property = str (uuid4 ())
283+ self ._query_runner .call_procedure (
284+ endpoint = "gds.pageRank.mutate" ,
285+ params = CallParameters (graph_name = visual_graph , config = dict (mutateProperty = size_property )),
286+ custom_error = False ,
287+ )
288+ clean_up_size_prop = True
289+ else :
290+ clean_up_size_prop = False
265291
266- node_properties = [pr_prop ]
292+ node_properties = [size_property ]
267293 if include_node_properties is not None :
268294 node_properties .extend (include_node_properties )
269295
@@ -295,11 +321,18 @@ def visualize(
295321 result .columns .name = None
296322 node_properties_df = result
297323
298- relationships_df = self ._query_runner .call_procedure (
299- endpoint = "gds.graph.relationships.stream" ,
300- params = CallParameters (graph_name = visual_graph ),
301- custom_error = False ,
302- )
324+ if rel_weight_property is None :
325+ relationships_df = self ._query_runner .call_procedure (
326+ endpoint = "gds.graph.relationships.stream" ,
327+ params = CallParameters (graph_name = visual_graph ),
328+ custom_error = False ,
329+ )
330+ else :
331+ relationships_df = self ._query_runner .call_procedure (
332+ endpoint = "gds.graph.relationshipProperty.stream" ,
333+ params = CallParameters (graph_name = visual_graph , properties = rel_weight_property ),
334+ custom_error = False ,
335+ )
303336
304337 # Clean up
305338 if visual_graph != self ._name :
@@ -308,10 +341,10 @@ def visualize(
308341 params = CallParameters (graph_name = visual_graph ),
309342 custom_error = False ,
310343 )
311- else :
344+ elif clean_up_size_prop :
312345 self ._query_runner .call_procedure (
313346 endpoint = "gds.graph.nodeProperties.drop" ,
314- params = CallParameters (graph_name = visual_graph , nodeProperties = pr_prop ),
347+ params = CallParameters (graph_name = visual_graph , nodeProperties = size_property ),
315348 custom_error = False ,
316349 )
317350
@@ -320,19 +353,21 @@ def visualize(
320353 net = Network (
321354 notebook = True if notebook else False ,
322355 cdn_resources = "remote" if notebook else "local" ,
323- bgcolor = "#222222" , # Dark background
324- font_color = "white" ,
325- height = "750px" , # Modify according to your screen size
356+ directed = directed ,
357+ bgcolor = "#222222" if theme == "dark" else "#FDFDFD" ,
358+ font_color = "white" if theme == "dark" else "black" ,
359+ height = f"{ px_height } px" ,
326360 width = "100%" ,
327361 )
328362
329363 if color_property is None :
330- color_map = {label : self ._random_bright_color () for label in self .node_labels ()}
364+ color_map = {label : self ._random_bright_color (theme ) for label in self .node_labels ()}
331365 else :
332366 color_map = {
333- prop_val : self ._random_bright_color () for prop_val in node_properties_df [color_property ].unique ()
367+ prop_val : self ._random_bright_color (theme ) for prop_val in node_properties_df [color_property ].unique ()
334368 }
335369
370+ # Add all the nodes
336371 for _ , node in node_properties_df .iterrows ():
337372 title = f"Node ID: { node ['nodeId' ]} \n Labels: { node ['nodeLabels' ]} "
338373 if include_node_properties is not None :
@@ -347,17 +382,22 @@ def visualize(
347382
348383 net .add_node (
349384 int (node ["nodeId" ]),
350- value = node [pr_prop ],
385+ value = node [size_property ],
351386 color = color ,
352387 title = title ,
353388 )
354389
355390 # Add all the relationships
356- net .add_edges (zip (relationships_df ["sourceNodeId" ], relationships_df ["targetNodeId" ]))
391+ for _ , rel in relationships_df .iterrows ():
392+ if rel_weight_property is None :
393+ net .add_edge (rel ["sourceNodeId" ], rel ["targetNodeId" ], title = f"Type: { rel ['relationshipType' ]} " )
394+ else :
395+ title = f"Type: { rel ['relationshipType' ]} \n { rel_weight_property } = { rel ['rel_weight_property' ]} "
396+ net .add_edge (rel ["sourceNodeId" ], rel ["targetNodeId" ], title = title , value = rel [rel_weight_property ])
357397
358398 return net .show (f"{ self ._name } .html" )
359399
360400 @staticmethod
361- def _random_bright_color () -> str :
362- h = random . randint ( 0 , 255 ) / 255.0
363- return "#%02X%02X%02X" % tuple (map (lambda x : int (x * 255 ), colorsys .hls_to_rgb (h , 0.7 , 1.0 )))
401+ def _random_bright_color (theme ) -> str :
402+ l = 0.7 if theme == "dark" else 0.4
403+ return "#%02X%02X%02X" % tuple (map (lambda x : int (x * 255 ), colorsys .hls_to_rgb (random . random (), l , 1.0 )))
0 commit comments