33"""
44Astroid hook for the dataclasses library
55"""
6- from typing import Generator , Tuple , Union
6+ from typing import Generator , List , Optional , Tuple
77
88from astroid import context , inference_tip
99from astroid .builder import parse
1010from astroid .const import PY37_PLUS , PY39_PLUS
11- from astroid .exceptions import InferenceError
11+ from astroid .exceptions import AstroidSyntaxError , InferenceError , MroError
1212from astroid .manager import AstroidManager
1313from astroid .nodes .node_classes import (
1414 AnnAssign ,
2626DATACLASSES_DECORATORS = frozenset (("dataclass" ,))
2727FIELD_NAME = "field"
2828DATACLASS_MODULE = "dataclasses"
29+ DEFAULT_FACTORY = "_HAS_DEFAULT_FACTORY" # based on typing.py
2930
3031
3132def is_decorated_with_dataclass (node , decorator_names = DATACLASSES_DECORATORS ):
@@ -57,17 +58,7 @@ def is_decorated_with_dataclass(node, decorator_names=DATACLASSES_DECORATORS):
5758def dataclass_transform (node : ClassDef ) -> None :
5859 """Rewrite a dataclass to be easily understood by pylint"""
5960
60- for assign_node in node .body :
61- if not isinstance (assign_node , AnnAssign ) or not isinstance (
62- assign_node .target , AssignName
63- ):
64- continue
65-
66- if _is_class_var (assign_node .annotation ) or _is_init_var (
67- assign_node .annotation
68- ):
69- continue
70-
61+ for assign_node in _get_dataclass_attributes (node ):
7162 name = assign_node .target .name
7263
7364 rhs_node = Unknown (
@@ -78,6 +69,167 @@ def dataclass_transform(node: ClassDef) -> None:
7869 rhs_node = AstroidManager ().visit_transforms (rhs_node )
7970 node .instance_attrs [name ] = [rhs_node ]
8071
72+ if not _check_generate_dataclass_init (node ):
73+ return
74+
75+ try :
76+ reversed_mro = reversed (node .mro ())
77+ except MroError :
78+ reversed_mro = [node ]
79+
80+ field_assigns = {}
81+ field_order = []
82+ for klass in (k for k in reversed_mro if is_decorated_with_dataclass (k )):
83+ for assign_node in _get_dataclass_attributes (klass , init = True ):
84+ name = assign_node .target .name
85+ if name not in field_assigns :
86+ field_order .append (name )
87+ field_assigns [name ] = assign_node
88+
89+ init_str = _generate_dataclass_init ([field_assigns [name ] for name in field_order ])
90+ try :
91+ init_node = parse (init_str )["__init__" ]
92+ except AstroidSyntaxError :
93+ pass
94+ else :
95+ init_node .parent = node
96+ init_node .lineno , init_node .col_offset = None , None
97+ node .locals ["__init__" ] = [init_node ]
98+
99+ root = node .root ()
100+ if DEFAULT_FACTORY not in root .locals :
101+ new_assign = parse (f"{ DEFAULT_FACTORY } = object()" ).body [0 ]
102+ new_assign .parent = root
103+ root .locals [DEFAULT_FACTORY ] = [new_assign .targets [0 ]]
104+
105+
106+ def _get_dataclass_attributes (node : ClassDef , init : bool = False ) -> Generator :
107+ """Yield the AnnAssign nodes of dataclass attributes for the node.
108+
109+ If init is True, also include InitVars, but exclude attributes from calls to
110+ field where init=False.
111+ """
112+ for assign_node in node .body :
113+ if not isinstance (assign_node , AnnAssign ) or not isinstance (
114+ assign_node .target , AssignName
115+ ):
116+ continue
117+
118+ if _is_class_var (assign_node .annotation ):
119+ continue
120+
121+ if init :
122+ value = assign_node .value
123+ if (
124+ isinstance (value , Call )
125+ and _looks_like_dataclass_field_call (value , check_scope = False )
126+ and any (
127+ keyword .arg == "init" and not keyword .value .bool_value ()
128+ for keyword in value .keywords
129+ )
130+ ):
131+ continue
132+ elif _is_init_var (assign_node .annotation ):
133+ continue
134+
135+ yield assign_node
136+
137+
138+ def _check_generate_dataclass_init (node : ClassDef ) -> bool :
139+ """Return True if we should generate an __init__ method for node.
140+
141+ This is True when:
142+ - node doesn't define its own __init__ method
143+ - the dataclass decorator was called *without* the keyword argument init=False
144+ """
145+ if "__init__" in node .locals :
146+ return False
147+
148+ found = None
149+
150+ for decorator_attribute in node .decorators .nodes :
151+ if not isinstance (decorator_attribute , Call ):
152+ continue
153+
154+ func = decorator_attribute .func
155+
156+ try :
157+ inferred = next (func .infer ())
158+ except (InferenceError , StopIteration ):
159+ continue
160+
161+ if not isinstance (inferred , FunctionDef ):
162+ continue
163+
164+ if (
165+ inferred .name in DATACLASSES_DECORATORS
166+ and inferred .root ().name == DATACLASS_MODULE
167+ ):
168+ found = decorator_attribute
169+
170+ if found is None :
171+ return True
172+
173+ # Check for keyword arguments of the form init=False
174+ return all (
175+ keyword .arg != "init" or keyword .value .bool_value ()
176+ for keyword in found .keywords
177+ )
178+
179+
180+ def _generate_dataclass_init (assigns : List [AnnAssign ]) -> str :
181+ """Return an init method for a dataclass given the targets."""
182+ target_names = []
183+ params = []
184+ assignments = []
185+
186+ for assign in assigns :
187+ name , annotation , value = assign .target .name , assign .annotation , assign .value
188+ target_names .append (name )
189+
190+ if _is_init_var (annotation ):
191+ init_var = True
192+ if isinstance (annotation , Subscript ):
193+ annotation = annotation .slice
194+ else :
195+ # Cannot determine type annotation for parameter from InitVar
196+ annotation = None
197+ assignment_str = ""
198+ else :
199+ init_var = False
200+ assignment_str = f"self.{ name } = { name } "
201+
202+ if annotation :
203+ param_str = f"{ name } : { annotation .as_string ()} "
204+ else :
205+ param_str = name
206+
207+ if value :
208+ if isinstance (value , Call ) and _looks_like_dataclass_field_call (
209+ value , check_scope = False
210+ ):
211+ result = _get_field_default (value )
212+
213+ default_type , default_node = result
214+ if default_type == "default" :
215+ param_str += f" = { default_node .as_string ()} "
216+ elif default_type == "default_factory" :
217+ param_str += f" = { DEFAULT_FACTORY } "
218+ assignment_str = (
219+ f"self.{ name } = { default_node .as_string ()} "
220+ f"if { name } is { DEFAULT_FACTORY } else { name } "
221+ )
222+ else :
223+ param_str += f" = { value .as_string ()} "
224+
225+ params .append (param_str )
226+ if not init_var :
227+ assignments .append (assignment_str )
228+
229+ params = ", " .join (["self" ] + params )
230+ assignments = "\n " .join (assignments ) if assignments else "pass"
231+ return f"def __init__({ params } ) -> None:\n { assignments } "
232+
81233
82234def infer_dataclass_attribute (
83235 node : Unknown , ctx : context .InferenceContext = None
@@ -107,17 +259,15 @@ def infer_dataclass_field_call(
107259) -> Generator :
108260 """Inference tip for dataclass field calls."""
109261 field_call = node .parent .value
110- result = _get_field_default (field_call )
111- if result is None :
262+ default_type , default = _get_field_default (field_call )
263+ if not default_type :
112264 yield Uninferable
265+ elif default_type == "default" :
266+ yield from default .infer (context = ctx )
113267 else :
114- default_type , default = result
115- if default_type == "default" :
116- yield from default .infer (context = ctx )
117- else :
118- new_call = parse (default .as_string ()).body [0 ].value
119- new_call .parent = field_call .parent
120- yield from new_call .infer (context = ctx )
268+ new_call = parse (default .as_string ()).body [0 ].value
269+ new_call .parent = field_call .parent
270+ yield from new_call .infer (context = ctx )
121271
122272
123273def _looks_like_dataclass_attribute (node : Unknown ) -> bool :
@@ -160,13 +310,14 @@ def _looks_like_dataclass_field_call(node: Call, check_scope: bool = True) -> bo
160310 return inferred .name == FIELD_NAME and inferred .root ().name == DATACLASS_MODULE
161311
162312
163- def _get_field_default (field_call : Call ) -> Union [ Tuple [str , NodeNG ], None ]:
313+ def _get_field_default (field_call : Call ) -> Tuple [str , Optional [ NodeNG ]]:
164314 """Return a the default value of a field call, and the corresponding keyword argument name.
165315
166316 field(default=...) results in the ... node
167317 field(default_factory=...) results in a Call node with func ... and no arguments
168318
169- If neither or both arguments are present, return None instead.
319+ If neither or both arguments are present, return ("", None) instead,
320+ indicating that there is not a valid default value.
170321 """
171322 default , default_factory = None , None
172323 for keyword in field_call .keywords :
@@ -187,7 +338,7 @@ def _get_field_default(field_call: Call) -> Union[Tuple[str, NodeNG], None]:
187338 new_call .postinit (func = default_factory )
188339 return "default_factory" , new_call
189340
190- return None
341+ return "" , None
191342
192343
193344def _is_class_var (node : NodeNG ) -> bool :
0 commit comments