11import collections
22import inspect
33import textwrap
4- from typing import Any , List , Optional , Tuple , Union
4+ from typing import Any , Mapping , Optional , Tuple , Union
55
66import matplotlib .pyplot as plt
77import networkx as nx
88import numpy as np
9+ import numpy .typing as npt
910import pydot
1011from matplotlib .axes import Axes
1112from matplotlib .colors import rgb2hex
@@ -41,7 +42,7 @@ def child_id(self) -> str:
4142 return str (self ._child_id )
4243
4344 @property
44- def parent_id (self ) -> str :
45+ def parent_id (self ) -> str | None :
4546 if self ._parent_id :
4647 return str (self ._parent_id )
4748 return
@@ -76,11 +77,11 @@ def __init__(
7677 self ,
7778 baseclass : Any ,
7879 funcname : Optional [str ] = None ,
79- default_color : Optional [ str ] = "#000000" ,
80- func_override_color : Optional [ str ] = "#ff0000" ,
81- similarity_cutoff : Optional [ float ] = 0.75 ,
82- max_recursion_level : Optional [ int ] = 500 ,
83- classes_to_exclude : Optional [List [str ]] = None ,
80+ default_color : str = "#000000" ,
81+ func_override_color : str = "#ff0000" ,
82+ similarity_cutoff : float = 0.75 ,
83+ max_recursion_level : int = 500 ,
84+ classes_to_exclude : Optional [list [str ]] = None ,
8485 ):
8586
8687 self .baseclass = baseclass
@@ -90,20 +91,22 @@ def __init__(
9091 self ._nodenum : int = 0
9192 self ._node_list = [] # a list of unique ChildNodes
9293 self ._node_map = {} # map of global node index to node name
93- self ._override_src = collections .OrderedDict ()
94+ self ._override_src : Mapping [ int , str ] = collections .OrderedDict ()
9495 self ._override_src_files = {}
9596 self ._current_node = 1 # the current global node, must start at 1
9697 self ._default_color = default_color
9798 self ._override_color = func_override_color
9899 self ._graphviz_args_kwargs = {}
99100 self .similarity_container = None
100- self .similarity_results = None
101+ self .similarity_results : dict [ str , npt . NDArray ]
101102 self .similarity_cutoff = similarity_cutoff
102103 if classes_to_exclude is None :
103104 classes_to_exclude = []
104105 self .classes_to_exclude = classes_to_exclude
105106 self ._build ()
106- self ._node_map_r = {v : k for k , v in self ._node_map .items ()} # name to index
107+ self ._node_map_r : Mapping [str , int ] = {
108+ v : k for k , v in self ._node_map .items ()
109+ } # name to index
107110
108111 def _get_source_info (self , obj ) -> Optional [str ]:
109112 f = getattr (obj , self .funcname )
@@ -346,9 +349,9 @@ def plot_similarity(
346349 def build_interactive_graph (
347350 self ,
348351 include_similarity : bool = True ,
349- node_style : dict = None ,
350- edge_style : dict = None ,
351- similarity_edge_style : dict = None ,
352+ node_style : dict [ str , Any ] | None = None ,
353+ edge_style : dict [ str , Any ] | None = None ,
354+ similarity_edge_style : dict [ str , Any ] | None = None ,
352355 override_node_color : Union [str , tuple ] = None ,
353356 ** kwargs ,
354357 ) -> Network :
@@ -454,22 +457,22 @@ def build_interactive_graph(
454457 network_wrapper .from_nx (grph )
455458 return network_wrapper
456459
457- def get_source_code (self , node : Union [ str , int ] ) -> str :
460+ def get_source_code (self , node : int | str ) -> str :
458461 """
459462 retrieve the source code of the comparison function for a
460463 specified node
461464
462465 Parameters
463466 ----------
464- node: Union[str, int]
467+ node: int
465468 the node to fetch the source code for
466469
467470 Returns
468471 -------
469472 str
470473 a string containing the source code for the node.
471474 """
472- if node in self ._override_src :
475+ if isinstance ( node , int ) and node in self ._override_src :
473476 return self ._override_src [node ]
474477 elif isinstance (node , str ) and node in self ._node_map_r :
475478 node_id = self ._node_map_r [node ]
@@ -481,7 +484,9 @@ def get_source_code(self, node: Union[str, int]) -> str:
481484 )
482485 raise KeyError (f"Could not find node for { node } " )
483486
484- def get_multiple_source_code (self , node_1 : Union [str , int ], * args ) -> dict :
487+ def get_multiple_source_code (
488+ self , node_1 : str | int , * args
489+ ) -> Mapping [str | int , str ]:
485490 """
486491 Retrieve the source code for multiple nodes
487492
@@ -515,11 +520,11 @@ def display_code_comparison(self):
515520 display_code_compare (self )
516521
517522
518- def _validate_color (clr , default_rgb_tuple : tuple ) -> str :
523+ def _validate_color (clr , default_rgb_tuple : tuple [ float , float , float ] ) -> str :
519524 if clr is None :
520- return rgb2hex (default_rgb_tuple )
525+ return str ( rgb2hex (default_rgb_tuple ) )
521526 elif isinstance (clr , tuple ):
522- return rgb2hex (clr )
527+ return str ( rgb2hex (clr ) )
523528 elif isinstance (clr , str ):
524529 return clr
525530 msg = f"clr has unexpected type: { type (clr )} "
0 commit comments