Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 53 additions & 12 deletions chdb/dbapi/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# You can use it to load large dataset.
RE_INSERT_VALUES = re.compile(
r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)"
+ r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))"
+ r"(\(\s*(?:%s|%\(.+\)s|\?)\s*(?:,\s*(?:%s|%\(.+\)s|\?)\s*)*\))"
+ r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
re.IGNORECASE | re.DOTALL,
)
Expand Down Expand Up @@ -99,6 +99,42 @@ def _escape_args(self, args, conn):
# Worst case it will throw a Value error
return conn.escape(args)

def _format_query(self, query, args, conn):
"""Format query with arguments supporting ? and % placeholders."""
if args is None:
return query

if isinstance(args, (tuple, list)) and '?' in query:
escaped_args = self._escape_args(args, conn)
result = []
arg_index = 0
in_string = False
quote_char = None

for i, char in enumerate(query):
# Track string literals to avoid replacing ? inside strings
if not in_string and char in ("'", '"'):
in_string = True
quote_char = char
elif in_string and char == quote_char:
# Check if it's an escaped quote
if i == 0 or query[i-1] != '\\':
in_string = False
quote_char = None

# Only replace ? outside of string literals
if char == '?' and not in_string and arg_index < len(escaped_args):
result.append(str(escaped_args[arg_index]))
arg_index += 1
else:
result.append(char)

return ''.join(result)
elif '%' in query:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems a bit unrigorous

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The %s placeholder retained the old logic (without considering escaping, in strings, etc.). Has been fixed to be the same as the handling for the ?.

return query % self._escape_args(args, conn)

return query

def mogrify(self, query, args=None):
"""
Returns the exact string that is sent to the database by calling the
Expand All @@ -107,11 +143,7 @@ def mogrify(self, query, args=None):
This method follows the extension to the DB API 2.0 followed by Psycopg.
"""
conn = self._get_db()

if args is not None:
query = query % self._escape_args(args, conn)

return query
return self._format_query(query, args, conn)

def execute(self, query, args=None):
"""Execute a query
Expand All @@ -124,12 +156,11 @@ def execute(self, query, args=None):
:return: Number of affected rows
:rtype: int

If args is a list or tuple, %s can be used as a placeholder in the query.
If args is a list or tuple, ? can be used as a placeholder in the query.
If args is a dict, %(name)s can be used as a placeholder in the query.
Also supports %s placeholder for backward compatibility.
"""
if args is not None:
query = query % self._escape_args(args, self.connection)

query = self._format_query(query, args, self.connection)
self._cursor.execute(query)

# Get description from column names and types
Expand Down Expand Up @@ -194,13 +225,23 @@ def _do_execute_many(
postfix = postfix.encode(encoding)
sql = prefix
args = iter(args)
v = values % escape(next(args), conn)

first_arg = next(args)
if '?' in values:
v = self._format_query(values, first_arg, conn)
else:
v = values % escape(first_arg, conn)

if isinstance(v, str):
v = v.encode(encoding, "surrogateescape")
sql += v
rows = 0
for arg in args:
v = values % escape(arg, conn)
if '?' in values:
v = self._format_query(values, arg, conn)
else:
v = values % escape(arg, conn)

if isinstance(v, str):
v = v.encode(encoding, "surrogateescape")
if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
Expand Down
53 changes: 53 additions & 0 deletions tests/test_dbapi_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,59 @@ def test_persistence(self):
row = cur2.fetchone()
self.assertEqual(("he", 32), row)

def test_placeholder1(self):
    """Exercise ``?`` placeholders through execute() and executemany().

    Inserts one row with execute(), reads it back with a parameterised
    SELECT, bulk-inserts three more rows with executemany(), then checks
    the row count and a parameterised lookup.
    """
    conn = dbapi.connect(path=test_state_dir)
    try:
        cur = conn.cursor()
        try:
            cur.execute("CREATE DATABASE test ENGINE = Atomic;")
            cur.execute(
                "CREATE TABLE test.users (id UInt64, name String, age UInt32) "
                "ENGINE = MergeTree ORDER BY id;"
            )

            # Single-row insert and lookup, both using ? placeholders.
            cur.execute(
                "INSERT INTO test.users (id, name, age) VALUES (?, ?, ?)",
                (1, 'Alice', 25),
            )
            cur.execute(
                "SELECT name, age FROM test.users WHERE id = ? AND age > ?",
                (1, 20),
            )
            row = cur.fetchone()
            self.assertEqual(("Alice", 25), row)

            # Bulk insert through executemany() with the same placeholder style.
            data = [(2, 'Bob', 30), (3, 'Charlie', 35), (4, 'David', 28)]
            cur.executemany(
                "INSERT INTO test.users (id, name, age) VALUES (?, ?, ?)",
                data,
            )

            cur.execute("SELECT COUNT(*) FROM test.users WHERE id > 1")
            count = cur.fetchone()[0]
            self.assertEqual(3, count)

            cur.execute("SELECT name FROM test.users WHERE age = ? ORDER BY id", (30,))
            result = cur.fetchone()
            self.assertEqual(("Bob",), result)
        finally:
            # Close even when an assertion or statement fails, so the
            # cursor is not left open for later tests.
            cur.close()
    finally:
        conn.close()

def test_placeholder2(self):
    """Check that legacy ``%s`` placeholders still work in execute()."""
    conn = dbapi.connect(path=test_state_dir)
    try:
        cur = conn.cursor()
        try:
            # Create table
            cur.execute("CREATE DATABASE compat ENGINE = Atomic;")
            cur.execute(
                "CREATE TABLE compat.test (id UInt64, value String) "
                "ENGINE = MergeTree ORDER BY id;"
            )

            # Test %s placeholders still work
            cur.execute(
                "INSERT INTO compat.test (id, value) VALUES (%s, %s)",
                (1, 'test_value'),
            )

            cur.execute("SELECT value FROM compat.test")
            result = cur.fetchone()
            self.assertEqual(("test_value",), result)
        finally:
            # Close even when an assertion or statement fails, so the
            # cursor is not left open for later tests.
            cur.close()
    finally:
        conn.close()


if __name__ == "__main__":
unittest.main()
Loading