Skip to content

Commit 253624d

Browse files
committed
fix: TaskiqAdminMiddleware work with dataclasses
1 parent 02f338a commit 253624d

File tree

7 files changed

+204
-109
lines changed

7 files changed

+204
-109
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ dev = [
6868
"freezegun>=1.5.5",
6969
"tzdata>=2025.2; sys_platform == 'win32'",
7070
"opentelemetry-test-utils (>=0.59b0,<1)",
71+
"polyfactory>=3.1.0",
7172
]
7273

7374
[project.urls]
@@ -172,8 +173,8 @@ lint.ignore = [
172173
"PLR0913", # Too many arguments for function call
173174
"D106", # Missing docstring in public nested class
174175
]
175-
exclude = [".venv/"]
176176
lint.mccabe = { max-complexity = 10 }
177+
exclude = [".venv/"]
177178
line-length = 88
178179

179180
[tool.ruff.lint.per-file-ignores]

taskiq/middlewares/taskiq_admin_middleware.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import aiohttp
88

99
from taskiq.abc.middleware import TaskiqMiddleware
10+
from taskiq.compat import model_dump
1011
from taskiq.message import TaskiqMessage
1112
from taskiq.result import TaskiqResult
1213

@@ -115,12 +116,13 @@ async def post_send(self, message: TaskiqMessage) -> None:
115116
116117
:param message: kicked message.
117118
"""
119+
dict_message: dict[str, Any] = model_dump(message)
118120
await self._spawn_request(
119121
f"/api/tasks/{message.task_id}/queued",
120122
{
121-
"args": message.args,
122-
"kwargs": message.kwargs,
123-
"labels": message.labels,
123+
"args": dict_message["args"],
124+
"kwargs": dict_message["kwargs"],
125+
"labels": dict_message["labels"],
124126
"queuedAt": self._now_iso(),
125127
"taskName": message.task_name,
126128
"worker": self.__ta_broker_name,
@@ -137,12 +139,13 @@ async def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
137139
:param message: incoming parsed taskiq message.
138140
:return: modified message.
139141
"""
142+
dict_message: dict[str, Any] = model_dump(message)
140143
await self._spawn_request(
141144
f"/api/tasks/{message.task_id}/started",
142145
{
143-
"args": message.args,
144-
"kwargs": message.kwargs,
145-
"labels": message.labels,
146+
"args": dict_message["args"],
147+
"kwargs": dict_message["kwargs"],
148+
"labels": dict_message["labels"],
146149
"startedAt": self._now_iso(),
147150
"taskName": message.task_name,
148151
"worker": self.__ta_broker_name,
@@ -164,12 +167,13 @@ async def post_execute(
164167
:param message: incoming message.
165168
:param result: result of execution for current task.
166169
"""
170+
dict_result: dict[str, Any] = model_dump(result)
167171
await self._spawn_request(
168172
f"/api/tasks/{message.task_id}/executed",
169173
{
170174
"finishedAt": self._now_iso(),
171175
"executionTime": result.execution_time,
172176
"error": None if result.error is None else repr(result.error),
173-
"returnValue": {"return_value": result.return_value},
177+
"returnValue": {"return_value": dict_result["return_value"]},
174178
},
175179
)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import pytest
2+
from aiohttp import web
3+
from aiohttp.test_utils import TestServer
4+
from typing_extensions import AsyncGenerator
5+
6+
from taskiq.brokers.inmemory_broker import InMemoryBroker
7+
from taskiq.brokers.shared_broker import async_shared_broker
8+
from taskiq.middlewares import TaskiqAdminMiddleware
9+
from tests.middlewares.admin_middleware.dto import (
10+
DataclassDTO,
11+
PydanticDTO,
12+
TypedDictDTO,
13+
)
14+
15+
16+
@pytest.fixture(scope="session")
17+
async def admin_api_server() -> AsyncGenerator[TestServer, None]:
18+
async def handle_queued(request: web.Request) -> web.Response:
19+
return web.json_response({"status": "ok"}, status=200)
20+
21+
async def handle_started(request: web.Request) -> web.Response:
22+
return web.json_response({"status": "ok"}, status=200)
23+
24+
async def handle_executed(request: web.Request) -> web.Response:
25+
return web.json_response({"status": "ok"}, status=200)
26+
27+
app = web.Application()
28+
app.router.add_post("/api/tasks/{task_id}/queued", handle_queued)
29+
app.router.add_post("/api/tasks/{task_id}/started", handle_started)
30+
app.router.add_post("/api/tasks/{task_id}/executed", handle_executed)
31+
32+
server = TestServer(app)
33+
await server.start_server()
34+
35+
yield server
36+
37+
# Останавливаем сервер после теста
38+
await server.close()
39+
40+
41+
@pytest.fixture
42+
async def broker_with_admin_middleware(
43+
admin_api_server: TestServer,
44+
) -> AsyncGenerator[InMemoryBroker, None]:
45+
broker = InMemoryBroker().with_middlewares(
46+
TaskiqAdminMiddleware(
47+
str(admin_api_server.make_url("/")), # URL тестового сервера
48+
"supersecret",
49+
taskiq_broker_name="InMemory",
50+
),
51+
)
52+
53+
broker.register_task(task_with_dataclass, task_name="task_with_dataclass")
54+
broker.register_task(task_with_typed_dict, task_name="task_with_typed_dict")
55+
broker.register_task(task_with_pydantic_model, task_name="task_with_pydantic_model")
56+
async_shared_broker.default_broker(broker)
57+
58+
await broker.startup()
59+
yield broker
60+
await broker.shutdown()
61+
62+
63+
async def task_with_dataclass(dto: DataclassDTO) -> None:
64+
assert dto
65+
66+
67+
async def task_with_typed_dict(dto: TypedDictDTO) -> None:
68+
assert dto
69+
70+
71+
async def task_with_pydantic_model(dto: PydanticDTO) -> None:
72+
assert dto
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from dataclasses import dataclass
2+
from typing import TypedDict
3+
4+
import pydantic
5+
6+
7+
@dataclass(frozen=True, slots=True)
8+
class DataclassNestedDTO:
9+
id: int
10+
name: str
11+
12+
13+
@dataclass(frozen=True, slots=True)
14+
class DataclassDTO:
15+
nested: DataclassNestedDTO
16+
recipients: list[str]
17+
subject: str
18+
attachments: list[str] | None = None
19+
text: str | None = None
20+
html: str | None = None
21+
22+
23+
class PydanticDTO(pydantic.BaseModel):
24+
number: int
25+
text: str
26+
flag: bool
27+
list: list[float]
28+
dictionary: dict[str, str] | None = None
29+
30+
31+
class TypedDictDTO(TypedDict):
32+
id: int
33+
name: str
34+
active: bool
35+
scores: list[int]
36+
metadata: dict[str, str] | None
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import pytest
2+
from polyfactory.factories import BaseFactory, DataclassFactory, TypedDictFactory
3+
from polyfactory.factories.pydantic_factory import ModelFactory
4+
5+
from taskiq.brokers.inmemory_broker import InMemoryBroker
6+
from tests.middlewares.admin_middleware.dto import (
7+
DataclassDTO,
8+
PydanticDTO,
9+
TypedDictDTO,
10+
)
11+
12+
13+
class DataclassDTOFactory(DataclassFactory[DataclassDTO]):
14+
__model__ = DataclassDTO
15+
16+
17+
class TypedDictDTOFactory(TypedDictFactory[TypedDictDTO]):
18+
__model__ = TypedDictDTO
19+
20+
21+
class PydanticDTOFactory(ModelFactory[PydanticDTO]):
22+
__model__ = PydanticDTO
23+
24+
25+
class TestArgumentsFormattingInAdminMiddleware:
26+
@pytest.mark.parametrize(
27+
"dto_factory, task_name",
28+
[
29+
pytest.param(DataclassDTOFactory, "task_with_dataclass", id="dataclass"),
30+
pytest.param(TypedDictDTOFactory, "task_with_typed_dict", id="typeddict"),
31+
pytest.param(PydanticDTOFactory, "task_with_pydantic_model", id="pydantic"),
32+
],
33+
)
34+
async def test_when_task_dto_passed__then_middleware_successfully_send_request(
35+
self,
36+
broker_with_admin_middleware: InMemoryBroker,
37+
dto_factory: type[BaseFactory], # type: ignore[type-arg]
38+
task_name: str,
39+
) -> None:
40+
# given
41+
task_arguments = dto_factory.build()
42+
task = broker_with_admin_middleware.find_task(task_name)
43+
assert task is not None, f"Task {task_name} should be registered in the broker"
44+
45+
# when
46+
kicked_task = await task.kiq(task_arguments)
47+
await broker_with_admin_middleware.wait_all()
48+
49+
# then
50+
result = await kicked_task.get_result()
51+
# we just expect no errors during post_send/pre_execute/post_execute
52+
assert result.error is None

tests/middlewares/test_taskiq_admin_middleware.py

Lines changed: 0 additions & 101 deletions
This file was deleted.

0 commit comments

Comments
 (0)