Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 51 additions & 13 deletions chdb/dbapi/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# You can use it to load large dataset.
RE_INSERT_VALUES = re.compile(
r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)"
+ r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))"
+ r"(\(\s*(?:%s|%\(.+\)s|\?)\s*(?:,\s*(?:%s|%\(.+\)s|\?)\s*)*\))"
+ r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
re.IGNORECASE | re.DOTALL,
)
Expand Down Expand Up @@ -99,6 +99,49 @@ def _escape_args(self, args, conn):
# Worst case it will throw a Value error
return conn.escape(args)

def _format_query(self, query, args, conn):
"""Format query with arguments supporting ? and % placeholders."""
if args is None or ('?' not in query and '%' not in query):
return query

escaped_args = self._escape_args(args, conn)
if not isinstance(escaped_args, (tuple, list)):
escaped_args = (escaped_args,)

result = []
arg_index = 0
max_args = len(escaped_args)
i = 0
query_len = len(query)
in_string = False
quote_char = None

while i < query_len:
char = query[i]
if not in_string:
if char in ("'", '"'):
in_string = True
quote_char = char
elif arg_index < max_args:
if char == '?':
result.append(str(escaped_args[arg_index]))
arg_index += 1
i += 1
continue
elif char == '%' and i + 1 < query_len and query[i + 1] == 's':
result.append(str(escaped_args[arg_index]))
arg_index += 1
i += 2
continue
elif char == quote_char and (i == 0 or query[i - 1] != '\\'):
in_string = False
quote_char = None

result.append(char)
i += 1

return ''.join(result)

def mogrify(self, query, args=None):
"""
Returns the exact string that is sent to the database by calling the
Expand All @@ -107,11 +150,7 @@ def mogrify(self, query, args=None):
This method follows the extension to the DB API 2.0 followed by Psycopg.
"""
conn = self._get_db()

if args is not None:
query = query % self._escape_args(args, conn)

return query
return self._format_query(query, args, conn)

def execute(self, query, args=None):
"""Execute a query
Expand All @@ -124,12 +163,11 @@ def execute(self, query, args=None):
:return: Number of affected rows
:rtype: int

If args is a list or tuple, %s can be used as a placeholder in the query.
If args is a list or tuple, ? can be used as a placeholder in the query.
If args is a dict, %(name)s can be used as a placeholder in the query.
Also supports %s placeholder for backward compatibility.
"""
if args is not None:
query = query % self._escape_args(args, self.connection)

query = self._format_query(query, args, self.connection)
self._cursor.execute(query)

# Get description from column names and types
Expand Down Expand Up @@ -187,20 +225,20 @@ def _do_execute_many(
self, prefix, values, postfix, args, max_stmt_length, encoding
):
conn = self._get_db()
escape = self._escape_args
if isinstance(prefix, str):
prefix = prefix.encode(encoding)
if isinstance(postfix, str):
postfix = postfix.encode(encoding)
sql = prefix
args = iter(args)
v = values % escape(next(args), conn)

v = self._format_query(values, next(args), conn)
if isinstance(v, str):
v = v.encode(encoding, "surrogateescape")
sql += v
rows = 0
for arg in args:
v = values % escape(arg, conn)
v = self._format_query(values, arg, conn)
if isinstance(v, str):
v = v.encode(encoding, "surrogateescape")
if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
Expand Down
51 changes: 51 additions & 0 deletions tests/test_dbapi_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,57 @@ def test_persistence(self):
row = cur2.fetchone()
self.assertEqual(("he", 32), row)

def test_placeholder1(self):
conn = dbapi.connect(path=test_state_dir)
cur = conn.cursor()

cur.execute("CREATE DATABASE test ENGINE = Atomic;")
cur.execute(
"CREATE TABLE test.users (id UInt64, name String, age UInt32) "
"ENGINE = MergeTree ORDER BY id;"
)

cur.execute("INSERT INTO test.users (id, name, age) VALUES (?, ?, ?)",
(1, 'Alice', 25))

cur.execute("SELECT name, age FROM test.users WHERE id = ? AND age > ?",
(1, 20))
row = cur.fetchone()
self.assertEqual(("Alice", 25), row)

data = [(2, 'Bob', 30), (3, 'Charlie', 35), (4, 'David', 28)]
cur.executemany("INSERT INTO test.users (id, name, age) VALUES (?, ?, ?)",
data)

cur.execute("SELECT COUNT(*) FROM test.users WHERE id > 1")
count = cur.fetchone()[0]
self.assertEqual(3, count)
cur.execute("SELECT name FROM test.users WHERE age = ? ORDER BY id", (30,))
result = cur.fetchone()
self.assertEqual(("Bob",), result)
cur.close()
conn.close()

def test_placeholder2(self):
conn = dbapi.connect(path=test_state_dir)
cur = conn.cursor()

# Create table
cur.execute("CREATE DATABASE compat ENGINE = Atomic;")
cur.execute(
"CREATE TABLE compat.test (id UInt64, value String) "
"ENGINE = MergeTree ORDER BY id;"
)
cur.execute("INSERT INTO compat.test (id, value) VALUES (%s, %s)",
(1, 'test_value'))

cur.execute("SELECT value FROM compat.test")
result = cur.fetchone()
self.assertEqual(("test_value",), result)

cur.close()
conn.close()


if __name__ == "__main__":
unittest.main()
Loading