Skip to content

Commit 095acc1

Browse files
david-yz-liuPierre-Sassoulas
authored andcommitted
Generate synthetic __init__ method for dataclasses
1 parent 249d8e2 commit 095acc1

File tree

3 files changed

+457
-25
lines changed

3 files changed

+457
-25
lines changed

ChangeLog

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ Release date: TBA
1515

1616
* ``BaseContainer`` is now public, and will replace ``_BaseContainer`` completely in astroid 3.0.
1717

18+
* Add inference for dataclass initializer method.
19+
20+
Closes PyCQA/pylint#3201
1821

1922
What's New in astroid 2.7.1?
2023
============================

astroid/brain/brain_dataclasses.py

Lines changed: 176 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
"""
44
Astroid hook for the dataclasses library
55
"""
6-
from typing import Generator, Tuple, Union
6+
from typing import Generator, List, Optional, Tuple
77

88
from astroid import context, inference_tip
99
from astroid.builder import parse
1010
from astroid.const import PY37_PLUS, PY39_PLUS
11-
from astroid.exceptions import InferenceError
11+
from astroid.exceptions import AstroidSyntaxError, InferenceError, MroError
1212
from astroid.manager import AstroidManager
1313
from astroid.nodes.node_classes import (
1414
AnnAssign,
@@ -26,6 +26,7 @@
2626
DATACLASSES_DECORATORS = frozenset(("dataclass",))
2727
FIELD_NAME = "field"
2828
DATACLASS_MODULE = "dataclasses"
29+
DEFAULT_FACTORY = "_HAS_DEFAULT_FACTORY" # based on typing.py
2930

3031

3132
def is_decorated_with_dataclass(node, decorator_names=DATACLASSES_DECORATORS):
@@ -57,17 +58,7 @@ def is_decorated_with_dataclass(node, decorator_names=DATACLASSES_DECORATORS):
5758
def dataclass_transform(node: ClassDef) -> None:
5859
"""Rewrite a dataclass to be easily understood by pylint"""
5960

60-
for assign_node in node.body:
61-
if not isinstance(assign_node, AnnAssign) or not isinstance(
62-
assign_node.target, AssignName
63-
):
64-
continue
65-
66-
if _is_class_var(assign_node.annotation) or _is_init_var(
67-
assign_node.annotation
68-
):
69-
continue
70-
61+
for assign_node in _get_dataclass_attributes(node):
7162
name = assign_node.target.name
7263

7364
rhs_node = Unknown(
@@ -78,6 +69,167 @@ def dataclass_transform(node: ClassDef) -> None:
7869
rhs_node = AstroidManager().visit_transforms(rhs_node)
7970
node.instance_attrs[name] = [rhs_node]
8071

72+
if not _check_generate_dataclass_init(node):
73+
return
74+
75+
try:
76+
reversed_mro = reversed(node.mro())
77+
except MroError:
78+
reversed_mro = [node]
79+
80+
field_assigns = {}
81+
field_order = []
82+
for klass in (k for k in reversed_mro if is_decorated_with_dataclass(k)):
83+
for assign_node in _get_dataclass_attributes(klass, init=True):
84+
name = assign_node.target.name
85+
if name not in field_assigns:
86+
field_order.append(name)
87+
field_assigns[name] = assign_node
88+
89+
init_str = _generate_dataclass_init([field_assigns[name] for name in field_order])
90+
try:
91+
init_node = parse(init_str)["__init__"]
92+
except AstroidSyntaxError:
93+
pass
94+
else:
95+
init_node.parent = node
96+
init_node.lineno, init_node.col_offset = None, None
97+
node.locals["__init__"] = [init_node]
98+
99+
root = node.root()
100+
if DEFAULT_FACTORY not in root.locals:
101+
new_assign = parse(f"{DEFAULT_FACTORY} = object()").body[0]
102+
new_assign.parent = root
103+
root.locals[DEFAULT_FACTORY] = [new_assign.targets[0]]
104+
105+
106+
def _get_dataclass_attributes(node: ClassDef, init: bool = False) -> Generator:
107+
"""Yield the AnnAssign nodes of dataclass attributes for the node.
108+
109+
If init is True, also include InitVars, but exclude attributes from calls to
110+
field where init=False.
111+
"""
112+
for assign_node in node.body:
113+
if not isinstance(assign_node, AnnAssign) or not isinstance(
114+
assign_node.target, AssignName
115+
):
116+
continue
117+
118+
if _is_class_var(assign_node.annotation):
119+
continue
120+
121+
if init:
122+
value = assign_node.value
123+
if (
124+
isinstance(value, Call)
125+
and _looks_like_dataclass_field_call(value, check_scope=False)
126+
and any(
127+
keyword.arg == "init" and not keyword.value.bool_value()
128+
for keyword in value.keywords
129+
)
130+
):
131+
continue
132+
elif _is_init_var(assign_node.annotation):
133+
continue
134+
135+
yield assign_node
136+
137+
138+
def _check_generate_dataclass_init(node: ClassDef) -> bool:
139+
"""Return True if we should generate an __init__ method for node.
140+
141+
This is True when:
142+
- node doesn't define its own __init__ method
143+
- the dataclass decorator was called *without* the keyword argument init=False
144+
"""
145+
if "__init__" in node.locals:
146+
return False
147+
148+
found = None
149+
150+
for decorator_attribute in node.decorators.nodes:
151+
if not isinstance(decorator_attribute, Call):
152+
continue
153+
154+
func = decorator_attribute.func
155+
156+
try:
157+
inferred = next(func.infer())
158+
except (InferenceError, StopIteration):
159+
continue
160+
161+
if not isinstance(inferred, FunctionDef):
162+
continue
163+
164+
if (
165+
inferred.name in DATACLASSES_DECORATORS
166+
and inferred.root().name == DATACLASS_MODULE
167+
):
168+
found = decorator_attribute
169+
170+
if found is None:
171+
return True
172+
173+
# Check for keyword arguments of the form init=False
174+
return all(
175+
keyword.arg != "init" or keyword.value.bool_value()
176+
for keyword in found.keywords
177+
)
178+
179+
180+
def _generate_dataclass_init(assigns: List[AnnAssign]) -> str:
181+
"""Return an init method for a dataclass given the targets."""
182+
target_names = []
183+
params = []
184+
assignments = []
185+
186+
for assign in assigns:
187+
name, annotation, value = assign.target.name, assign.annotation, assign.value
188+
target_names.append(name)
189+
190+
if _is_init_var(annotation):
191+
init_var = True
192+
if isinstance(annotation, Subscript):
193+
annotation = annotation.slice
194+
else:
195+
# Cannot determine type annotation for parameter from InitVar
196+
annotation = None
197+
assignment_str = ""
198+
else:
199+
init_var = False
200+
assignment_str = f"self.{name} = {name}"
201+
202+
if annotation:
203+
param_str = f"{name}: {annotation.as_string()}"
204+
else:
205+
param_str = name
206+
207+
if value:
208+
if isinstance(value, Call) and _looks_like_dataclass_field_call(
209+
value, check_scope=False
210+
):
211+
result = _get_field_default(value)
212+
213+
default_type, default_node = result
214+
if default_type == "default":
215+
param_str += f" = {default_node.as_string()}"
216+
elif default_type == "default_factory":
217+
param_str += f" = {DEFAULT_FACTORY}"
218+
assignment_str = (
219+
f"self.{name} = {default_node.as_string()} "
220+
f"if {name} is {DEFAULT_FACTORY} else {name}"
221+
)
222+
else:
223+
param_str += f" = {value.as_string()}"
224+
225+
params.append(param_str)
226+
if not init_var:
227+
assignments.append(assignment_str)
228+
229+
params = ", ".join(["self"] + params)
230+
assignments = "\n ".join(assignments) if assignments else "pass"
231+
return f"def __init__({params}) -> None:\n {assignments}"
232+
81233

82234
def infer_dataclass_attribute(
83235
node: Unknown, ctx: context.InferenceContext = None
@@ -107,17 +259,15 @@ def infer_dataclass_field_call(
107259
) -> Generator:
108260
"""Inference tip for dataclass field calls."""
109261
field_call = node.parent.value
110-
result = _get_field_default(field_call)
111-
if result is None:
262+
default_type, default = _get_field_default(field_call)
263+
if not default_type:
112264
yield Uninferable
265+
elif default_type == "default":
266+
yield from default.infer(context=ctx)
113267
else:
114-
default_type, default = result
115-
if default_type == "default":
116-
yield from default.infer(context=ctx)
117-
else:
118-
new_call = parse(default.as_string()).body[0].value
119-
new_call.parent = field_call.parent
120-
yield from new_call.infer(context=ctx)
268+
new_call = parse(default.as_string()).body[0].value
269+
new_call.parent = field_call.parent
270+
yield from new_call.infer(context=ctx)
121271

122272

123273
def _looks_like_dataclass_attribute(node: Unknown) -> bool:
@@ -160,13 +310,14 @@ def _looks_like_dataclass_field_call(node: Call, check_scope: bool = True) -> bo
160310
return inferred.name == FIELD_NAME and inferred.root().name == DATACLASS_MODULE
161311

162312

163-
def _get_field_default(field_call: Call) -> Union[Tuple[str, NodeNG], None]:
313+
def _get_field_default(field_call: Call) -> Tuple[str, Optional[NodeNG]]:
164314
"""Return a the default value of a field call, and the corresponding keyword argument name.
165315
166316
field(default=...) results in the ... node
167317
field(default_factory=...) results in a Call node with func ... and no arguments
168318
169-
If neither or both arguments are present, return None instead.
319+
If neither or both arguments are present, return ("", None) instead,
320+
indicating that there is not a valid default value.
170321
"""
171322
default, default_factory = None, None
172323
for keyword in field_call.keywords:
@@ -187,7 +338,7 @@ def _get_field_default(field_call: Call) -> Union[Tuple[str, NodeNG], None]:
187338
new_call.postinit(func=default_factory)
188339
return "default_factory", new_call
189340

190-
return None
341+
return "", None
191342

192343

193344
def _is_class_var(node: NodeNG) -> bool:

0 commit comments

Comments
 (0)