Skip to content

Commit 83c61f7

Browse files
authored
feat: support numpy decoder (#78)
Introduce functionality to convert numpy arrays to lists when the target type is a list. This enhancement checks for numpy installation and integrates the conversion into existing type decoding processes.
1 parent 5e5d4df commit 83c61f7

File tree

6 files changed

+422
-10
lines changed

6 files changed

+422
-10
lines changed

sqlspec/_typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ async def insert_returning(self, conn: Any, query_name: str, sql: str, parameter
606606

607607

608608
FSSPEC_INSTALLED = bool(find_spec("fsspec"))
609+
NUMPY_INSTALLED = bool(find_spec("numpy"))
609610
OBSTORE_INSTALLED = bool(find_spec("obstore"))
610611
PGVECTOR_INSTALLED = bool(find_spec("pgvector"))
611612

@@ -617,6 +618,7 @@ async def insert_returning(self, conn: Any, query_name: str, sql: str, parameter
617618
"FSSPEC_INSTALLED",
618619
"LITESTAR_INSTALLED",
619620
"MSGSPEC_INSTALLED",
621+
"NUMPY_INSTALLED",
620622
"OBSTORE_INSTALLED",
621623
"OPENTELEMETRY_INSTALLED",
622624
"PGVECTOR_INSTALLED",

sqlspec/driver/mixins/_result_tools.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sqlspec.exceptions import SQLSpecError
1616
from sqlspec.typing import (
1717
CATTRS_INSTALLED,
18+
NUMPY_INSTALLED,
1819
ModelDTOT,
1920
ModelT,
2021
attrs_asdict,
@@ -41,12 +42,36 @@
4142

4243

4344
_DATETIME_TYPES: Final[set[type]] = {datetime.datetime, datetime.date, datetime.time}
45+
46+
47+
def _is_list_type_target(target_type: Any) -> bool:
48+
"""Check if target type is a list type (e.g., list[float])."""
49+
try:
50+
return hasattr(target_type, "__origin__") and target_type.__origin__ is list
51+
except (AttributeError, TypeError):
52+
return False
53+
54+
55+
def _convert_numpy_to_list(target_type: Any, value: Any) -> Any:
56+
"""Convert numpy array to list if target is a list type."""
57+
if not NUMPY_INSTALLED:
58+
return value
59+
60+
import numpy as np
61+
62+
if isinstance(value, np.ndarray) and _is_list_type_target(target_type):
63+
return value.tolist()
64+
65+
return value
66+
67+
4468
_DEFAULT_TYPE_DECODERS: Final[list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]] = [
4569
(lambda x: x is UUID, lambda t, v: t(v.hex)),
4670
(lambda x: x is datetime.datetime, lambda t, v: t(v.isoformat())),
4771
(lambda x: x is datetime.date, lambda t, v: t(v.isoformat())),
4872
(lambda x: x is datetime.time, lambda t, v: t(v.isoformat())),
4973
(lambda x: x is Enum, lambda t, v: t(v.value)),
74+
(_is_list_type_target, _convert_numpy_to_list),
5075
]
5176

5277

@@ -63,6 +88,13 @@ def _default_msgspec_deserializer(
6388
Returns:
6489
Converted value or original value if conversion not applicable
6590
"""
91+
# Handle numpy arrays first for list types
92+
if NUMPY_INSTALLED:
93+
import numpy as np
94+
95+
if isinstance(value, np.ndarray) and _is_list_type_target(target_type):
96+
return value.tolist()
97+
6698
if type_decoders:
6799
for predicate, decoder in type_decoders:
68100
if predicate(target_type):
@@ -71,17 +103,19 @@ def _default_msgspec_deserializer(
71103
if target_type is UUID and isinstance(value, UUID):
72104
return value.hex
73105

74-
if target_type in _DATETIME_TYPES:
75-
try:
76-
return value.isoformat()
77-
except AttributeError:
78-
pass
106+
if target_type in _DATETIME_TYPES and hasattr(value, "isoformat"):
107+
return value.isoformat() # pyright: ignore
79108

80109
if isinstance(target_type, type) and issubclass(target_type, Enum) and isinstance(value, Enum):
81110
return value.value
82111

83-
if isinstance(value, target_type):
84-
return value
112+
# Check if value is already the correct type (but avoid parameterized generics)
113+
try:
114+
if isinstance(target_type, type) and isinstance(value, target_type):
115+
return value
116+
except TypeError:
117+
# Handle parameterized generics like list[int] which can't be used with isinstance
118+
pass
85119

86120
if isinstance(target_type, type):
87121
try:
@@ -190,6 +224,25 @@ def to_schema(data: Any, *, schema_type: "Optional[type[ModelDTOT]]" = None) ->
190224
logger.debug("Field name transformation failed for msgspec schema: %s", e)
191225
transformed_data = data
192226

227+
# Pre-process numpy arrays to lists before msgspec conversion
228+
if NUMPY_INSTALLED:
229+
try:
230+
import numpy as np
231+
232+
def _convert_numpy_arrays_in_data(obj: Any) -> Any:
233+
"""Recursively convert numpy arrays to lists in data structures."""
234+
if isinstance(obj, np.ndarray):
235+
return obj.tolist()
236+
if isinstance(obj, dict):
237+
return {k: _convert_numpy_arrays_in_data(v) for k, v in obj.items()}
238+
if isinstance(obj, (list, tuple)):
239+
return type(obj)(_convert_numpy_arrays_in_data(item) for item in obj)
240+
return obj
241+
242+
transformed_data = _convert_numpy_arrays_in_data(transformed_data)
243+
except ImportError:
244+
pass
245+
193246
if not isinstance(transformed_data, Sequence):
194247
return convert(obj=transformed_data, type=schema_type, from_attributes=True, dec_hook=deserializer)
195248
return convert(obj=transformed_data, type=list[schema_type], from_attributes=True, dec_hook=deserializer) # type: ignore[valid-type]

sqlspec/typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
FSSPEC_INSTALLED,
1313
LITESTAR_INSTALLED,
1414
MSGSPEC_INSTALLED,
15+
NUMPY_INSTALLED,
1516
OBSTORE_INSTALLED,
1617
OPENTELEMETRY_INSTALLED,
1718
PGVECTOR_INSTALLED,
@@ -187,6 +188,7 @@ class StorageMixin(MixinOf(DriverProtocol)): ...
187188
"FSSPEC_INSTALLED",
188189
"LITESTAR_INSTALLED",
189190
"MSGSPEC_INSTALLED",
191+
"NUMPY_INSTALLED",
190192
"OBSTORE_INSTALLED",
191193
"OPENTELEMETRY_INSTALLED",
192194
"PGVECTOR_INSTALLED",

tests/unit/test_driver/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)