Skip to content

Commit 418f36b

Browse files
authored
chore: add tests for gokart parameters (#389)
* chore: add tests for gokart parameters * fix: handle parameters of `gokart.TaskOnKart`
1 parent 222c40a commit 418f36b

2 files changed

Lines changed: 24 additions & 9 deletions

File tree

gokart/mypy.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,11 @@
6060

6161
class TaskOnKartPlugin(Plugin):
6262
def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
63-
if 'gokart.task.luigi.Task' in fullname:
64-
# gather attibutes from gokart.TaskOnKart
65-
# the transformation does not affect because the class has `__init__` method
63+
# The following gathers attributes from gokart.TaskOnKart such as `workspace_directory`
64+
# the transformation does not affect because the class has `__init__` method of `gokart.TaskOnKart`.
65+
#
66+
# NOTE: `gokart.task.luigi.Task` condition is required for the release of luigi versions without py.typed
67+
if fullname in {'gokart.task.luigi.Task', 'luigi.task.Task'}:
6668
return self._task_on_kart_class_maker_callback
6769

6870
sym = self.lookup_fully_qualified(fullname)
@@ -209,7 +211,6 @@ def transform(self) -> bool:
209211
if ('__init__' not in info.names or info.names['__init__'].plugin_generated) and attributes:
210212
args = [attr.to_argument(info, of='__init__') for attr in attributes]
211213
add_method_to_class(self._api, self._cls, '__init__', args=args, return_type=NoneType())
212-
213214
info.metadata[METADATA_TAG] = {
214215
'attributes': [attr.serialize() for attr in attributes],
215216
}
@@ -330,6 +331,7 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]:
330331
info=cls.info,
331332
api=self._api,
332333
)
334+
333335
return list(found_attrs.values())
334336

335337
def _collect_parameter_args(self, expr: Expression) -> tuple[bool, dict[str, Expression]]:
@@ -404,9 +406,13 @@ def is_parameter_call(expr: Expression) -> bool:
404406
type_info = callee.node
405407
if type_info is None and isinstance(callee.expr, NameExpr):
406408
return PARAMETER_FULLNAME_MATCHER.match(f'{callee.expr.name}.{callee.name}') is not None
407-
if isinstance(type_info, TypeInfo) and PARAMETER_FULLNAME_MATCHER.match(type_info.fullname):
408-
return True
409+
elif isinstance(callee, NameExpr):
410+
type_info = callee.node
411+
else:
409412
return False
413+
414+
if isinstance(type_info, TypeInfo):
415+
return PARAMETER_FULLNAME_MATCHER.match(type_info.fullname) is not None
410416
return False
411417

412418

test/test_mypy.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@ class MyTask(gokart.TaskOnKart):
1717
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
1818
foo: int = luigi.IntParameter() # type: ignore
1919
bar: str = luigi.Parameter() # type: ignore
20+
baz: bool = gokart.ExplicitBoolParameter()
2021
21-
MyTask(foo=1, bar='bar')
22+
23+
# TaskOnKart parameters:
24+
# - `complete_check_at_run`
25+
MyTask(foo=1, bar='bar', baz=False, complete_check_at_run=False)
2226
"""
2327

2428
with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
@@ -37,15 +41,20 @@ class MyTask(gokart.TaskOnKart):
3741
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
3842
foo: int = luigi.IntParameter() # type: ignore
3943
bar: str = luigi.Parameter() # type: ignore
44+
baz: bool = gokart.ExplicitBoolParameter()
4045
4146
# issue: foo is int
4247
# not issue: bar is missing, because it can be set by config file.
43-
MyTask(foo='1')
48+
# TaskOnKart parameters:
49+
# - `complete_check_at_run`
50+
MyTask(foo='1', baz='not bool', complete_check_at_run='not bool')
4451
"""
4552

4653
with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
4754
test_file.write(test_code.encode('utf-8'))
4855
test_file.flush()
4956
result = api.run(['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
5057
self.assertIn('error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0])
51-
self.assertIn('Found 1 error in 1 file (checked 1 source file)', result[0])
58+
self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', result[0])
59+
self.assertIn('error: Argument "complete_check_at_run" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', result[0])
60+
self.assertIn('Found 3 errors in 1 file (checked 1 source file)', result[0])

0 commit comments

Comments
 (0)