Skip to content

Commit ae670ba

Browse files
authored
Merge pull request #308 from chdb-io/fixDfArrowTableOutput
Fix df arrow table output
2 parents 85c6bda + fb4c992 commit ae670ba

File tree

6 files changed

+333
-286
lines changed

6 files changed

+333
-286
lines changed

Diff for: chdb/dbapi/connections.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def cursor(self, cursor=None):
5757
return Cursor(self)
5858
return Cursor(self)
5959

60-
def query(self, sql, fmt="ArrowStream"):
60+
def query(self, sql, fmt="CSV"):
6161
"""Execute a query and return the raw result."""
6262
if self._closed:
6363
raise err.InterfaceError("Connection closed")

Diff for: chdb/dbapi/cursors.py

+30-18
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
# executemany only supports simple bulk insert.
66
# You can use it to load large dataset.
77
RE_INSERT_VALUES = re.compile(
8-
r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)" +
9-
r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
10-
r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
11-
re.IGNORECASE | re.DOTALL)
8+
r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)"
9+
+ r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))"
10+
+ r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
11+
re.IGNORECASE | re.DOTALL,
12+
)
1213

1314

1415
class Cursor(object):
@@ -131,13 +132,17 @@ def execute(self, query, args=None):
131132

132133
self._cursor.execute(query)
133134

134-
# Get description from Arrow schema
135-
if self._cursor._current_table is not None:
135+
# Get description from column names and types
136+
if hasattr(self._cursor, "_column_names") and self._cursor._column_names:
136137
self.description = [
137-
(field.name, field.type.to_pandas_dtype(), None, None, None, None, None)
138-
for field in self._cursor._current_table.schema
138+
(name, type_info, None, None, None, None, None)
139+
for name, type_info in zip(
140+
self._cursor._column_names, self._cursor._column_types
141+
)
139142
]
140-
self.rowcount = self._cursor._current_table.num_rows
143+
self.rowcount = (
144+
len(self._cursor._current_table) if self._cursor._current_table else -1
145+
)
141146
else:
142147
self.description = None
143148
self.rowcount = -1
@@ -164,16 +169,23 @@ def executemany(self, query, args):
164169
if m:
165170
q_prefix = m.group(1) % ()
166171
q_values = m.group(2).rstrip()
167-
q_postfix = m.group(3) or ''
168-
assert q_values[0] == '(' and q_values[-1] == ')'
169-
return self._do_execute_many(q_prefix, q_values, q_postfix, args,
170-
self.max_stmt_length,
171-
self._get_db().encoding)
172+
q_postfix = m.group(3) or ""
173+
assert q_values[0] == "(" and q_values[-1] == ")"
174+
return self._do_execute_many(
175+
q_prefix,
176+
q_values,
177+
q_postfix,
178+
args,
179+
self.max_stmt_length,
180+
self._get_db().encoding,
181+
)
172182

173183
self.rowcount = sum(self.execute(query, arg) for arg in args)
174184
return self.rowcount
175185

176-
def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
186+
def _do_execute_many(
187+
self, prefix, values, postfix, args, max_stmt_length, encoding
188+
):
177189
conn = self._get_db()
178190
escape = self._escape_args
179191
if isinstance(prefix, str):
@@ -184,18 +196,18 @@ def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encod
184196
args = iter(args)
185197
v = values % escape(next(args), conn)
186198
if isinstance(v, str):
187-
v = v.encode(encoding, 'surrogateescape')
199+
v = v.encode(encoding, "surrogateescape")
188200
sql += v
189201
rows = 0
190202
for arg in args:
191203
v = values % escape(arg, conn)
192204
if isinstance(v, str):
193-
v = v.encode(encoding, 'surrogateescape')
205+
v = v.encode(encoding, "surrogateescape")
194206
if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
195207
rows += self.execute(sql + postfix)
196208
sql = prefix
197209
else:
198-
sql += ','.encode(encoding)
210+
sql += ",".encode(encoding)
199211
sql += v
200212
rows += self.execute(sql + postfix)
201213
self.rowcount = rows

Diff for: chdb/state/sqlitelike.py

+155-12
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import io
21
from typing import Optional, Any
32
from chdb import _chdb
43

@@ -11,6 +10,36 @@
1110
raise ImportError("Failed to import pyarrow") from None
1211

1312

13+
_arrow_format = set({"dataframe", "arrowtable"})
14+
_process_result_format_funs = {
15+
"dataframe": lambda x: to_df(x),
16+
"arrowtable": lambda x: to_arrowTable(x),
17+
}
18+
19+
20+
# return pyarrow table
21+
def to_arrowTable(res):
22+
"""convert res to arrow table"""
23+
# try import pyarrow and pandas, if failed, raise ImportError with suggestion
24+
try:
25+
import pyarrow as pa # noqa
26+
import pandas as pd # noqa
27+
except ImportError as e:
28+
print(f"ImportError: {e}")
29+
print('Please install pyarrow and pandas via "pip install pyarrow pandas"')
30+
raise ImportError("Failed to import pyarrow or pandas") from None
31+
if len(res) == 0:
32+
return pa.Table.from_batches([], schema=pa.schema([]))
33+
return pa.RecordBatchFileReader(res.bytes()).read_all()
34+
35+
36+
# return pandas dataframe
37+
def to_df(r):
38+
"""convert arrow table to Dataframe"""
39+
t = to_arrowTable(r)
40+
return t.to_pandas(use_threads=True)
41+
42+
1443
class Connection:
1544
def __init__(self, connection_string: str):
1645
# print("Connection", connection_string)
@@ -22,7 +51,13 @@ def cursor(self) -> "Cursor":
2251
return self._cursor
2352

2453
def query(self, query: str, format: str = "CSV") -> Any:
25-
return self._conn.query(query, format)
54+
lower_output_format = format.lower()
55+
result_func = _process_result_format_funs.get(lower_output_format, lambda x: x)
56+
if lower_output_format in _arrow_format:
57+
format = "Arrow"
58+
59+
result = self._conn.query(query, format)
60+
return result_func(result)
2661

2762
def close(self) -> None:
2863
# print("close")
@@ -41,17 +76,103 @@ def __init__(self, connection):
4176
def execute(self, query: str) -> None:
4277
self._cursor.execute(query)
4378
result_mv = self._cursor.get_memview()
44-
# print("get_result", result_mv)
4579
if self._cursor.has_error():
4680
raise Exception(self._cursor.error_message())
4781
if self._cursor.data_size() == 0:
4882
self._current_table = None
4983
self._current_row = 0
84+
self._column_names = []
85+
self._column_types = []
5086
return
51-
arrow_data = result_mv.tobytes()
52-
reader = pa.ipc.open_stream(io.BytesIO(arrow_data))
53-
self._current_table = reader.read_all()
54-
self._current_row = 0
87+
88+
# Parse JSON data
89+
json_data = result_mv.tobytes().decode("utf-8")
90+
import json
91+
92+
try:
93+
# First line contains column names
94+
# Second line contains column types
95+
# Following lines contain data
96+
lines = json_data.strip().split("\n")
97+
if len(lines) < 2:
98+
self._current_table = None
99+
self._current_row = 0
100+
self._column_names = []
101+
self._column_types = []
102+
return
103+
104+
self._column_names = json.loads(lines[0])
105+
self._column_types = json.loads(lines[1])
106+
107+
# Convert data rows
108+
rows = []
109+
for line in lines[2:]:
110+
if not line.strip():
111+
continue
112+
row_data = json.loads(line)
113+
converted_row = []
114+
for val, type_info in zip(row_data, self._column_types):
115+
# Handle NULL values first
116+
if val is None:
117+
converted_row.append(None)
118+
continue
119+
120+
# Basic type conversion
121+
try:
122+
if type_info.startswith("Int") or type_info.startswith("UInt"):
123+
converted_row.append(int(val))
124+
elif type_info.startswith("Float"):
125+
converted_row.append(float(val))
126+
elif type_info == "Bool":
127+
converted_row.append(bool(val))
128+
elif type_info == "String" or type_info == "FixedString":
129+
converted_row.append(str(val))
130+
elif type_info.startswith("DateTime"):
131+
from datetime import datetime
132+
133+
# Check if the value is numeric (timestamp)
134+
val_str = str(val)
135+
if val_str.replace(".", "").isdigit():
136+
converted_row.append(datetime.fromtimestamp(float(val)))
137+
else:
138+
# Handle datetime string formats
139+
if "." in val_str: # Has microseconds
140+
converted_row.append(
141+
datetime.strptime(
142+
val_str, "%Y-%m-%d %H:%M:%S.%f"
143+
)
144+
)
145+
else: # No microseconds
146+
converted_row.append(
147+
datetime.strptime(val_str, "%Y-%m-%d %H:%M:%S")
148+
)
149+
elif type_info.startswith("Date"):
150+
from datetime import date, datetime
151+
152+
# Check if the value is numeric (days since epoch)
153+
val_str = str(val)
154+
if val_str.isdigit():
155+
converted_row.append(
156+
date.fromtimestamp(float(val) * 86400)
157+
)
158+
else:
159+
# Handle date string format
160+
converted_row.append(
161+
datetime.strptime(val_str, "%Y-%m-%d").date()
162+
)
163+
else:
164+
# For unsupported types, keep as string
165+
converted_row.append(str(val))
166+
except (ValueError, TypeError):
167+
# If conversion fails, keep original value as string
168+
converted_row.append(str(val))
169+
rows.append(tuple(converted_row))
170+
171+
self._current_table = rows
172+
self._current_row = 0
173+
174+
except json.JSONDecodeError as e:
175+
raise Exception(f"Failed to parse JSON data: {e}")
55176

56177
def commit(self) -> None:
57178
self._cursor.commit()
@@ -60,12 +181,10 @@ def fetchone(self) -> Optional[tuple]:
60181
if not self._current_table or self._current_row >= len(self._current_table):
61182
return None
62183

63-
row_dict = {
64-
col: self._current_table.column(col)[self._current_row].as_py()
65-
for col in self._current_table.column_names
66-
}
184+
# Now self._current_table is a list of row tuples
185+
row = self._current_table[self._current_row]
67186
self._current_row += 1
68-
return tuple(row_dict.values())
187+
return row
69188

70189
def fetchmany(self, size: int = 1) -> tuple:
71190
if not self._current_table:
@@ -99,6 +218,30 @@ def __next__(self) -> tuple:
99218
raise StopIteration
100219
return row
101220

221+
def column_names(self) -> list:
222+
"""Return a list of column names from the last executed query"""
223+
return self._column_names if hasattr(self, "_column_names") else []
224+
225+
def column_types(self) -> list:
226+
"""Return a list of column types from the last executed query"""
227+
return self._column_types if hasattr(self, "_column_types") else []
228+
229+
@property
230+
def description(self) -> list:
231+
"""
232+
Return a description of the columns as per DB-API 2.0
233+
Returns a list of 7-item tuples, each containing:
234+
(name, type_code, display_size, internal_size, precision, scale, null_ok)
235+
where only name and type_code are provided
236+
"""
237+
if not hasattr(self, "_column_names") or not self._column_names:
238+
return []
239+
240+
return [
241+
(name, type_info, None, None, None, None, None)
242+
for name, type_info in zip(self._column_names, self._column_types)
243+
]
244+
102245

103246
def connect(connection_string: str = ":memory:") -> Connection:
104247
"""

Diff for: programs/local/LocalChdb.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -321,9 +321,9 @@ void cursor_wrapper::execute(const std::string & query_str)
321321
release_result();
322322
global_query_obj = findQueryableObjFromQuery(query_str);
323323

324-
// Always use Arrow format internally
324+
// Use JSONCompactEachRowWithNamesAndTypes format for better type support
325325
py::gil_scoped_release release;
326-
current_result = query_conn(conn->get_conn(), query_str.c_str(), "ArrowStream");
326+
current_result = query_conn(conn->get_conn(), query_str.c_str(), "JSONCompactEachRowWithNamesAndTypes");
327327
}
328328

329329

0 commit comments

Comments (0)