1515from sqlspec .exceptions import SQLSpecError
1616from sqlspec .typing import (
1717 CATTRS_INSTALLED ,
18+ NUMPY_INSTALLED ,
1819 ModelDTOT ,
1920 ModelT ,
2021 attrs_asdict ,
4142
4243
4344_DATETIME_TYPES : Final [set [type ]] = {datetime .datetime , datetime .date , datetime .time }
45+
46+
47+ def _is_list_type_target (target_type : Any ) -> bool :
48+ """Check if target type is a list type (e.g., list[float])."""
49+ try :
50+ return hasattr (target_type , "__origin__" ) and target_type .__origin__ is list
51+ except (AttributeError , TypeError ):
52+ return False
53+
54+
55+ def _convert_numpy_to_list (target_type : Any , value : Any ) -> Any :
56+ """Convert numpy array to list if target is a list type."""
57+ if not NUMPY_INSTALLED :
58+ return value
59+
60+ import numpy as np
61+
62+ if isinstance (value , np .ndarray ) and _is_list_type_target (target_type ):
63+ return value .tolist ()
64+
65+ return value
66+
67+
4468_DEFAULT_TYPE_DECODERS : Final [list [tuple [Callable [[Any ], bool ], Callable [[Any , Any ], Any ]]]] = [
4569 (lambda x : x is UUID , lambda t , v : t (v .hex )),
4670 (lambda x : x is datetime .datetime , lambda t , v : t (v .isoformat ())),
4771 (lambda x : x is datetime .date , lambda t , v : t (v .isoformat ())),
4872 (lambda x : x is datetime .time , lambda t , v : t (v .isoformat ())),
4973 (lambda x : x is Enum , lambda t , v : t (v .value )),
74+ (_is_list_type_target , _convert_numpy_to_list ),
5075]
5176
5277
@@ -63,6 +88,13 @@ def _default_msgspec_deserializer(
6388 Returns:
6489 Converted value or original value if conversion not applicable
6590 """
91+ # Handle numpy arrays first for list types
92+ if NUMPY_INSTALLED :
93+ import numpy as np
94+
95+ if isinstance (value , np .ndarray ) and _is_list_type_target (target_type ):
96+ return value .tolist ()
97+
6698 if type_decoders :
6799 for predicate , decoder in type_decoders :
68100 if predicate (target_type ):
@@ -71,17 +103,19 @@ def _default_msgspec_deserializer(
71103 if target_type is UUID and isinstance (value , UUID ):
72104 return value .hex
73105
74- if target_type in _DATETIME_TYPES :
75- try :
76- return value .isoformat ()
77- except AttributeError :
78- pass
106+ if target_type in _DATETIME_TYPES and hasattr (value , "isoformat" ):
107+ return value .isoformat () # pyright: ignore
79108
80109 if isinstance (target_type , type ) and issubclass (target_type , Enum ) and isinstance (value , Enum ):
81110 return value .value
82111
83- if isinstance (value , target_type ):
84- return value
112+ # Check if value is already the correct type (but avoid parameterized generics)
113+ try :
114+ if isinstance (target_type , type ) and isinstance (value , target_type ):
115+ return value
116+ except TypeError :
117+ # Handle parameterized generics like list[int] which can't be used with isinstance
118+ pass
85119
86120 if isinstance (target_type , type ):
87121 try :
@@ -190,6 +224,25 @@ def to_schema(data: Any, *, schema_type: "Optional[type[ModelDTOT]]" = None) ->
190224 logger .debug ("Field name transformation failed for msgspec schema: %s" , e )
191225 transformed_data = data
192226
227+ # Pre-process numpy arrays to lists before msgspec conversion
228+ if NUMPY_INSTALLED :
229+ try :
230+ import numpy as np
231+
232+ def _convert_numpy_arrays_in_data (obj : Any ) -> Any :
233+ """Recursively convert numpy arrays to lists in data structures."""
234+ if isinstance (obj , np .ndarray ):
235+ return obj .tolist ()
236+ if isinstance (obj , dict ):
237+ return {k : _convert_numpy_arrays_in_data (v ) for k , v in obj .items ()}
238+ if isinstance (obj , (list , tuple )):
239+ return type (obj )(_convert_numpy_arrays_in_data (item ) for item in obj )
240+ return obj
241+
242+ transformed_data = _convert_numpy_arrays_in_data (transformed_data )
243+ except ImportError :
244+ pass
245+
193246 if not isinstance (transformed_data , Sequence ):
194247 return convert (obj = transformed_data , type = schema_type , from_attributes = True , dec_hook = deserializer )
195248 return convert (obj = transformed_data , type = list [schema_type ], from_attributes = True , dec_hook = deserializer ) # type: ignore[valid-type]
0 commit comments