Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions procrastinate/psycopg_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,18 +171,20 @@ async def close_async(self) -> None:
await self._async_pool.close()
self._async_pool = None

def _wrap_value(self, value: Any) -> Any:
@classmethod
def _wrap_value(cls, value: Any) -> Any:
if isinstance(value, dict):
return psycopg.types.json.Jsonb(value)
elif isinstance(value, list):
return [self._wrap_value(item) for item in value]
return [cls._wrap_value(item) for item in value]
elif isinstance(value, tuple):
return tuple([self._wrap_value(item) for item in value])
return tuple([cls._wrap_value(item) for item in value])
else:
return value

def _wrap_json(self, arguments: dict[str, Any]):
return {key: self._wrap_value(value) for key, value in arguments.items()}
@classmethod
def wrap_json(cls, arguments: dict[str, Any]) -> dict[str, Any]:
return {key: cls._wrap_value(value) for key, value in arguments.items()}

@contextlib.asynccontextmanager
async def _get_cursor(
Expand All @@ -204,14 +206,14 @@ async def _get_cursor(
@wrap_exceptions()
async def execute_query_async(self, query: LiteralString, **arguments: Any) -> None:
async with self._get_cursor() as cursor:
await cursor.execute(query, self._wrap_json(arguments))
await cursor.execute(query, self.wrap_json(arguments))

@wrap_exceptions()
async def execute_query_one_async(
self, query: LiteralString, **arguments: Any
) -> dict[str, Any]:
async with self._get_cursor() as cursor:
await cursor.execute(query, self._wrap_json(arguments))
await cursor.execute(query, self.wrap_json(arguments))

result = await cursor.fetchone()

Expand All @@ -224,7 +226,7 @@ async def execute_query_all_async(
self, query: LiteralString, **arguments: Any
) -> list[dict[str, Any]]:
async with self._get_cursor() as cursor:
await cursor.execute(query, self._wrap_json(arguments))
await cursor.execute(query, self.wrap_json(arguments))

return await cursor.fetchall()

Expand Down
41 changes: 39 additions & 2 deletions tests/integration/test_psycopg_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@
import asgiref.sync
import attr
import pytest

from procrastinate import exceptions, manager, psycopg_connector, sync_psycopg_connector
from psycopg.types.json import Jsonb

from procrastinate import (
PsycopgConnector,
exceptions,
manager,
psycopg_connector,
sync_psycopg_connector,
)


@pytest.fixture
Expand Down Expand Up @@ -283,3 +290,33 @@ async def test_get_sync_connector__not_open(not_opened_psycopg_connector):
assert isinstance(sync, sync_psycopg_connector.SyncPsycopgConnector)
assert not_opened_psycopg_connector.get_sync_connector() is sync
assert sync._pool_args == not_opened_psycopg_connector._pool_args


@pytest.mark.parametrize(
"arguments, expected",
[
({"a": "a"}, {"a": "a"}),
({"a": ("a", "b")}, ({"a": ("a", "b")})),
({"a": ["a", "b"]}, ({"a": ["a", "b"]})),
],
)
def test_wrap_json_makes_correct_psycopg_dict__simple(arguments, expected):
result = PsycopgConnector.wrap_json(arguments)

assert result == expected


def test_wrap_json_makes_correct_psycopg_dict__inner_dict():
result = PsycopgConnector.wrap_json({"a": {"b": "c"}})

assert set(result.keys()) == {"a"}
assert isinstance(result["a"], Jsonb)
assert result["a"].obj == {"b": "c"}


def test_wrap_json_makes_correct_psycopg_dict__inner_list():
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests are pretty simple and straightforward. But if you know how to make it better, i will do

result = PsycopgConnector.wrap_json({"a": [{"b": "c"}]})

assert set(result.keys()) == {"a"}
assert isinstance(result["a"][0], Jsonb)
assert result["a"][0].obj == {"b": "c"}