Skip to content

Commit 74247d4

Browse files
feat(spider-py): Add support for client-end task graph grouping and chaining. (#191)
Co-authored-by: Lin Zhihao <[email protected]>
1 parent 64519e9 commit 74247d4

File tree

9 files changed

+430
-40
lines changed

9 files changed

+430
-40
lines changed

python/spider-py/pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,14 @@ typing-extensions = false # Disable type extensions for 3.10 compatibility
7474
[tool.ruff.lint.per-file-ignores]
7575
"tests/**" = [
7676
"INP001", # Allow implicit namespace package for tests
77+
"PLR2004", # Allow use of magic value
7778
"S101", # Allow use of `assert` (security warning)
7879
"S603", # Allow use of `subprocess.Popen` (security warning)
7980
"T201", # Allow use of `print` (testing)
8081
]
8182

83+
[tool.ruff.lint.flake8-self]
84+
ignore-names = ["_impl"]
85+
8286
[tool.ruff.lint.pydocstyle]
8387
ignore-decorators = ["typing.override"]

python/spider-py/src/spider_py/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Spider package root."""
22

3+
from spider_py.client import chain, group, TaskContext, TaskGraph
34
from spider_py.type import Double, Float, Int8, Int16, Int32, Int64
45

56
__all__ = [
@@ -9,4 +10,8 @@
910
"Int16",
1011
"Int32",
1112
"Int64",
13+
"TaskContext",
14+
"TaskGraph",
15+
"chain",
16+
"group",
1217
]
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,11 @@
11
"""Spider python client."""
2+
3+
from .task import TaskContext
4+
from .task_graph import chain, group, TaskGraph
5+
6+
__all__ = [
7+
"TaskContext",
8+
"TaskGraph",
9+
"chain",
10+
"group",
11+
]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Spider client Data module."""
2+
3+
4+
class Data:
5+
"""Represents a spider client data."""
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""Spider client task module."""
2+
3+
import inspect
4+
from collections.abc import Callable
5+
from types import FunctionType, GenericAlias
6+
from typing import get_args, get_origin
7+
8+
from spider_py import core
9+
from spider_py.client.data import Data
10+
from spider_py.core import TaskInput, TaskOutput, TaskOutputData, TaskOutputValue
11+
from spider_py.type import to_tdl_type_str
12+
13+
14+
class TaskContext:
15+
"""Spider task context."""
16+
17+
# TODO: Implement task context for use in task executor
18+
19+
20+
# NOTE: This type alias is for clarification purposes only. It does not enforce static type checks.
21+
# Instead, we rely on the runtime check to ensure the first argument is `TaskContext`. To statically
22+
# enforce the first argument to be `TaskContext`, `Protocol` is required, which is not compatible
23+
# with `Callable` without explicit type casting.
24+
TaskFunction = Callable[..., object]
25+
26+
27+
def _is_tuple(t: type | GenericAlias) -> bool:
28+
"""
29+
:param t:
30+
:return: Whether t is a tuple.
31+
"""
32+
return get_origin(t) is tuple
33+
34+
35+
def _validate_and_convert_params(signature: inspect.Signature) -> list[TaskInput]:
36+
"""
37+
Validates the task parameters and converts them into a list of `core.TaskInput`.
38+
:param signature:
39+
:return: The converted task parameters.
40+
:raises TypeError: If the parameters are invalid.
41+
"""
42+
params = list(signature.parameters.values())
43+
inputs = []
44+
if not params or params[0].annotation is not TaskContext:
45+
msg = "First argument is not a TaskContext."
46+
raise TypeError(msg)
47+
for param in params[1:]:
48+
if param.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}:
49+
msg = "Variadic parameters are not supported."
50+
raise TypeError(msg)
51+
if param.annotation is inspect.Parameter.empty:
52+
msg = "Parameters must have type annotation."
53+
raise TypeError(msg)
54+
tdl_type_str = to_tdl_type_str(param.annotation)
55+
inputs.append(TaskInput(tdl_type_str, None))
56+
return inputs
57+
58+
59+
def _validate_and_convert_return(signature: inspect.Signature) -> list[TaskOutput]:
60+
"""
61+
Validates the task returns and converts them into a list of `core.TaskOutput`.
62+
:param signature:
63+
:return: The converted task returns.
64+
:raises TypeError: If the return type is invalid.
65+
"""
66+
returns = signature.return_annotation
67+
outputs = []
68+
if returns is inspect.Parameter.empty:
69+
msg = "Return type must have type annotation."
70+
raise TypeError(msg)
71+
72+
if not _is_tuple(returns):
73+
tdl_type_str = to_tdl_type_str(returns)
74+
if returns is Data:
75+
outputs.append(TaskOutput(tdl_type_str, TaskOutputData()))
76+
else:
77+
outputs.append(TaskOutput(tdl_type_str, TaskOutputValue()))
78+
return outputs
79+
80+
args = get_args(returns)
81+
if Ellipsis in args:
82+
msg = "Variable-length tuple return types are not supported."
83+
raise TypeError(msg)
84+
for arg in args:
85+
tdl_type_str = to_tdl_type_str(arg)
86+
if arg is Data:
87+
outputs.append(TaskOutput(tdl_type_str, TaskOutputData()))
88+
else:
89+
outputs.append(TaskOutput(tdl_type_str, TaskOutputValue()))
90+
return outputs
91+
92+
93+
def create_task(func: TaskFunction) -> core.Task:
94+
"""
95+
Creates a core Task object from the task function.
96+
:param func:
97+
:return: The created core Task object.
98+
:raise TypeError: If the function signature contains unsupported types.
99+
"""
100+
if not isinstance(func, FunctionType):
101+
msg = "`func` is not a function."
102+
raise TypeError(msg)
103+
signature = inspect.signature(func)
104+
return core.Task(
105+
function_name=func.__qualname__,
106+
task_inputs=_validate_and_convert_params(signature),
107+
task_outputs=_validate_and_convert_return(signature),
108+
)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Spider client TaskGraph module."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
from spider_py import core
8+
from spider_py.client.task import create_task, TaskFunction
9+
10+
if TYPE_CHECKING:
11+
from collections.abc import Sequence
12+
13+
14+
class TaskGraph:
15+
"""
16+
Represents a client-side task graph.
17+
18+
This class is a wrapper of `spider_py.core.Task`.
19+
"""
20+
21+
def __init__(self) -> None:
22+
"""Initializes TaskGraph."""
23+
self._impl = core.TaskGraph()
24+
25+
26+
def group(tasks: Sequence[TaskFunction | TaskGraph]) -> TaskGraph:
27+
"""
28+
Groups task functions and task graph into a single task graph.
29+
:param tasks: List of task functions or task graphs.
30+
:return: The new task graph.
31+
"""
32+
graph = TaskGraph()
33+
for task in tasks:
34+
if callable(task):
35+
graph._impl.add_task(create_task(task))
36+
else:
37+
graph._impl.merge_graph(task._impl)
38+
39+
return graph
40+
41+
42+
def chain(parent: TaskFunction | TaskGraph, child: TaskFunction | TaskGraph) -> TaskGraph:
43+
"""
44+
Chains two task functions or task graphs into a single task graph.
45+
:param parent:
46+
:param child:
47+
:return: The new task graph.
48+
:raises TypeError: If the parent outputs and child inputs do not match.
49+
"""
50+
parent_core_graph: core.TaskGraph
51+
child_core_graph: core.TaskGraph
52+
53+
if callable(parent):
54+
parent_core_graph = core.TaskGraph()
55+
parent_core_graph.add_task(create_task(parent))
56+
else:
57+
parent_core_graph = parent._impl
58+
59+
if callable(child):
60+
child_core_graph = core.TaskGraph()
61+
child_core_graph.add_task(create_task(child))
62+
else:
63+
child_core_graph = child._impl
64+
65+
graph = TaskGraph()
66+
graph._impl = core.TaskGraph.chain_graph(parent_core_graph, child_core_graph)
67+
return graph

python/spider-py/src/spider_py/core/task.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,26 @@ class TaskInputOutput:
1919

2020
TaskInputValue = bytes
2121
TaskInputData = DataId
22-
TaskInput = TaskInputOutput | TaskInputValue | TaskInputData
22+
23+
24+
@dataclass
25+
class TaskInput:
26+
"""Represents a task input"""
27+
28+
type: str
29+
value: TaskInputData | TaskInputOutput | None
30+
2331

2432
TaskOutputValue = bytes
2533
TaskOutputData = DataId
26-
TaskOutput = TaskOutputValue | TaskOutputData
34+
35+
36+
@dataclass
37+
class TaskOutput:
38+
"""Represents a task output"""
39+
40+
type: str
41+
value: TaskOutputData | TaskOutputValue
2742

2843

2944
class TaskState(IntEnum):

0 commit comments

Comments
 (0)