1
- import io
2
1
from typing import Optional , Any
3
2
from chdb import _chdb
4
3
11
10
raise ImportError ("Failed to import pyarrow" ) from None
12
11
13
12
13
+ _arrow_format = set ({"dataframe" , "arrowtable" })
14
+ _process_result_format_funs = {
15
+ "dataframe" : lambda x : to_df (x ),
16
+ "arrowtable" : lambda x : to_arrowTable (x ),
17
+ }
18
+
19
+
20
+ # return pyarrow table
21
+ def to_arrowTable (res ):
22
+ """convert res to arrow table"""
23
+ # try import pyarrow and pandas, if failed, raise ImportError with suggestion
24
+ try :
25
+ import pyarrow as pa # noqa
26
+ import pandas as pd # noqa
27
+ except ImportError as e :
28
+ print (f"ImportError: { e } " )
29
+ print ('Please install pyarrow and pandas via "pip install pyarrow pandas"' )
30
+ raise ImportError ("Failed to import pyarrow or pandas" ) from None
31
+ if len (res ) == 0 :
32
+ return pa .Table .from_batches ([], schema = pa .schema ([]))
33
+ return pa .RecordBatchFileReader (res .bytes ()).read_all ()
34
+
35
+
36
+ # return pandas dataframe
37
+ def to_df (r ):
38
+ """convert arrow table to Dataframe"""
39
+ t = to_arrowTable (r )
40
+ return t .to_pandas (use_threads = True )
41
+
42
+
14
43
class Connection :
15
44
def __init__ (self , connection_string : str ):
16
45
# print("Connection", connection_string)
@@ -22,7 +51,13 @@ def cursor(self) -> "Cursor":
22
51
return self ._cursor
23
52
24
53
def query (self , query : str , format : str = "CSV" ) -> Any :
25
- return self ._conn .query (query , format )
54
+ lower_output_format = format .lower ()
55
+ result_func = _process_result_format_funs .get (lower_output_format , lambda x : x )
56
+ if lower_output_format in _arrow_format :
57
+ format = "Arrow"
58
+
59
+ result = self ._conn .query (query , format )
60
+ return result_func (result )
26
61
27
62
def close (self ) -> None :
28
63
# print("close")
@@ -41,17 +76,103 @@ def __init__(self, connection):
41
76
def execute (self , query : str ) -> None :
42
77
self ._cursor .execute (query )
43
78
result_mv = self ._cursor .get_memview ()
44
- # print("get_result", result_mv)
45
79
if self ._cursor .has_error ():
46
80
raise Exception (self ._cursor .error_message ())
47
81
if self ._cursor .data_size () == 0 :
48
82
self ._current_table = None
49
83
self ._current_row = 0
84
+ self ._column_names = []
85
+ self ._column_types = []
50
86
return
51
- arrow_data = result_mv .tobytes ()
52
- reader = pa .ipc .open_stream (io .BytesIO (arrow_data ))
53
- self ._current_table = reader .read_all ()
54
- self ._current_row = 0
87
+
88
+ # Parse JSON data
89
+ json_data = result_mv .tobytes ().decode ("utf-8" )
90
+ import json
91
+
92
+ try :
93
+ # First line contains column names
94
+ # Second line contains column types
95
+ # Following lines contain data
96
+ lines = json_data .strip ().split ("\n " )
97
+ if len (lines ) < 2 :
98
+ self ._current_table = None
99
+ self ._current_row = 0
100
+ self ._column_names = []
101
+ self ._column_types = []
102
+ return
103
+
104
+ self ._column_names = json .loads (lines [0 ])
105
+ self ._column_types = json .loads (lines [1 ])
106
+
107
+ # Convert data rows
108
+ rows = []
109
+ for line in lines [2 :]:
110
+ if not line .strip ():
111
+ continue
112
+ row_data = json .loads (line )
113
+ converted_row = []
114
+ for val , type_info in zip (row_data , self ._column_types ):
115
+ # Handle NULL values first
116
+ if val is None :
117
+ converted_row .append (None )
118
+ continue
119
+
120
+ # Basic type conversion
121
+ try :
122
+ if type_info .startswith ("Int" ) or type_info .startswith ("UInt" ):
123
+ converted_row .append (int (val ))
124
+ elif type_info .startswith ("Float" ):
125
+ converted_row .append (float (val ))
126
+ elif type_info == "Bool" :
127
+ converted_row .append (bool (val ))
128
+ elif type_info == "String" or type_info == "FixedString" :
129
+ converted_row .append (str (val ))
130
+ elif type_info .startswith ("DateTime" ):
131
+ from datetime import datetime
132
+
133
+ # Check if the value is numeric (timestamp)
134
+ val_str = str (val )
135
+ if val_str .replace ("." , "" ).isdigit ():
136
+ converted_row .append (datetime .fromtimestamp (float (val )))
137
+ else :
138
+ # Handle datetime string formats
139
+ if "." in val_str : # Has microseconds
140
+ converted_row .append (
141
+ datetime .strptime (
142
+ val_str , "%Y-%m-%d %H:%M:%S.%f"
143
+ )
144
+ )
145
+ else : # No microseconds
146
+ converted_row .append (
147
+ datetime .strptime (val_str , "%Y-%m-%d %H:%M:%S" )
148
+ )
149
+ elif type_info .startswith ("Date" ):
150
+ from datetime import date , datetime
151
+
152
+ # Check if the value is numeric (days since epoch)
153
+ val_str = str (val )
154
+ if val_str .isdigit ():
155
+ converted_row .append (
156
+ date .fromtimestamp (float (val ) * 86400 )
157
+ )
158
+ else :
159
+ # Handle date string format
160
+ converted_row .append (
161
+ datetime .strptime (val_str , "%Y-%m-%d" ).date ()
162
+ )
163
+ else :
164
+ # For unsupported types, keep as string
165
+ converted_row .append (str (val ))
166
+ except (ValueError , TypeError ):
167
+ # If conversion fails, keep original value as string
168
+ converted_row .append (str (val ))
169
+ rows .append (tuple (converted_row ))
170
+
171
+ self ._current_table = rows
172
+ self ._current_row = 0
173
+
174
+ except json .JSONDecodeError as e :
175
+ raise Exception (f"Failed to parse JSON data: { e } " )
55
176
56
177
def commit (self ) -> None :
57
178
self ._cursor .commit ()
@@ -60,12 +181,10 @@ def fetchone(self) -> Optional[tuple]:
60
181
if not self ._current_table or self ._current_row >= len (self ._current_table ):
61
182
return None
62
183
63
- row_dict = {
64
- col : self ._current_table .column (col )[self ._current_row ].as_py ()
65
- for col in self ._current_table .column_names
66
- }
184
+ # Now self._current_table is a list of row tuples
185
+ row = self ._current_table [self ._current_row ]
67
186
self ._current_row += 1
68
- return tuple ( row_dict . values ())
187
+ return row
69
188
70
189
def fetchmany (self , size : int = 1 ) -> tuple :
71
190
if not self ._current_table :
@@ -99,6 +218,30 @@ def __next__(self) -> tuple:
99
218
raise StopIteration
100
219
return row
101
220
221
+ def column_names (self ) -> list :
222
+ """Return a list of column names from the last executed query"""
223
+ return self ._column_names if hasattr (self , "_column_names" ) else []
224
+
225
+ def column_types (self ) -> list :
226
+ """Return a list of column types from the last executed query"""
227
+ return self ._column_types if hasattr (self , "_column_types" ) else []
228
+
229
+ @property
230
+ def description (self ) -> list :
231
+ """
232
+ Return a description of the columns as per DB-API 2.0
233
+ Returns a list of 7-item tuples, each containing:
234
+ (name, type_code, display_size, internal_size, precision, scale, null_ok)
235
+ where only name and type_code are provided
236
+ """
237
+ if not hasattr (self , "_column_names" ) or not self ._column_names :
238
+ return []
239
+
240
+ return [
241
+ (name , type_info , None , None , None , None , None )
242
+ for name , type_info in zip (self ._column_names , self ._column_types )
243
+ ]
244
+
102
245
103
246
def connect (connection_string : str = ":memory:" ) -> Connection :
104
247
"""
0 commit comments