Skip to content

Commit f803cc8

Browse files
9EOR9rusher
authored andcommitted
refactor: optimize query packet parameter conversion
1 parent a7563e6 commit f803cc8

File tree

1 file changed

+135
-118
lines changed

1 file changed

+135
-118
lines changed

mariadb/impl/message/client/query_packet.py

Lines changed: 135 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,24 @@
44
import array
55
import datetime
66
import decimal
7-
from typing import TYPE_CHECKING, Any, List, Optional, Union as UnionType
8-
7+
import ipaddress
8+
import uuid
9+
from typing import TYPE_CHECKING, Any, List
10+
import struct
11+
import re
912
try:
1013
import numpy
1114
HAS_NUMPY = True
1215
except ImportError:
1316
HAS_NUMPY = False
1417

1518
from ...client.context import Context
16-
from ...string_utils import StringEscaper
1719
from ...sql_parser import split_sql_parts
1820
from mariadb_shared.constants.STATUS import NO_BACKSLASH_ESCAPES
1921
from mariadb_shared.constants.INDICATOR import MrdbIndicator
2022
from ..client_message import ClientMessage
2123
from ..payload_stream import PayloadStream
2224
from ....exceptions import NotSupportedError
23-
if TYPE_CHECKING:
24-
from ...client.socket.write_stream import BaseWriteStream
2525

2626
BINARY_PREFIX: bytes = bytearray(b"_binary'")
2727
QUOTE_BYTE: int = b"'"[0]
@@ -33,7 +33,7 @@ class QueryPacket(ClientMessage):
3333
"""
3434
Simple query packet for SQL execution without parameters
3535
"""
36-
36+
3737
def __init__(self, sql: str):
3838
"""Initialize COM_QUERY packet with SQL"""
3939
self.sql = sql
@@ -52,11 +52,11 @@ class QueryWithParamPacket(ClientMessage):
5252
"""
5353
Parameterized query packet for SQL execution with parameter binding
5454
"""
55-
55+
5656
def __init__(self, sql_bytes: bytes, param_positions: List[int], parameters: List[Any]):
5757
"""
5858
Initialize COM_QUERY packet with pre-parsed SQL bytes and parameters
59-
59+
6060
Args:
6161
sql_bytes: SQL encoded as UTF-8 bytes
6262
param_positions: Byte positions (start, end) pairs where placeholders are
@@ -70,136 +70,153 @@ def payload(self, context: Context) -> bytes:
7070
"""Generate COM_QUERY packet payload with SQL and bound parameters"""
7171
stream = PayloadStream()
7272
no_backslash_escapes = context.server_status & NO_BACKSLASH_ESCAPES > 0
73-
73+
7474
# Write SQL fragments interleaved with parameters
7575
last_pos = 0
7676
param_idx = 0
77+
params = self.parameters
78+
converter = PARAM_CONVERT_TBL
79+
parts = [b''] * (len(self.param_positions) + 2 + 1)
7780

78-
stream.write_byte(COM_QUERY)
81+
parts[0] = b'\x03'
82+
count= 1
7983
# Iterate through placeholder positions (they come in pairs: start, end)
8084
for i in range(0, len(self.param_positions), 2):
8185
start_pos = self.param_positions[i]
8286
end_pos = self.param_positions[i + 1]
83-
87+
8488
# Write SQL fragment before this placeholder
8589
if start_pos > last_pos:
86-
stream.write_bytes(self.sql_bytes[last_pos:start_pos])
87-
90+
parts[count]= self.sql_bytes[last_pos:start_pos]
91+
count += 1
92+
8893
# Write parameter value
8994
if param_idx < len(self.parameters):
90-
self._write_parameter_value(stream, self.parameters[param_idx], no_backslash_escapes)
95+
parts[count] = converter.get(type(params[param_idx]), lambda v, ctx=None: str(v).encode('utf8'))(params[param_idx], no_backslash_escapes)
9196
param_idx += 1
97+
count += 1
9298
else:
93-
stream.write_string('NULL', 'ascii')
94-
99+
parts[count] = b'NULL'
100+
count += 1
101+
95102
last_pos = end_pos
96-
103+
97104
# Write remaining SQL after last placeholder
98105
if last_pos < len(self.sql_bytes):
99-
stream.write_bytes(self.sql_bytes[last_pos:])
100-
101-
return stream.get_payload()
102-
103-
def _write_parameter_value(self, stream: UnionType['BaseWriteStream', PayloadStream], param: Any, no_backslash_escapes: bool) -> None:
104-
"""
105-
Write parameter value directly as its string representation
106-
(for COM_QUERY, parameters are converted to strings)
107-
108-
Args:
109-
stream: Stream writer (BaseWriteStream or PayloadStream)
110-
param: Parameter value
111-
no_backslash_escapes: Whether to use NO_BACKSLASH_ESCAPES mode
112-
"""
113-
if param is None:
114-
stream.write_string('NULL', 'ascii')
115-
elif isinstance(param, MrdbIndicator):
116-
# Handle MariaDB indicator values
117-
if param.indicator == 1: # NULL
118-
stream.write_string('NULL', 'ascii')
119-
elif param.indicator == 2: # DEFAULT
120-
stream.write_string('DEFAULT', 'ascii')
121-
elif param.indicator == 3: # IGNORE
122-
# Skip this parameter - should be handled at a higher level
123-
pass
124-
elif param.indicator == 4: # IGNORE_ROW
125-
# Skip entire row - should be handled at a higher level
126-
pass
127-
else:
128-
# Unknown indicator, treat as NULL
129-
stream.write_bytes(NULL_BYTES)
130-
else:
131-
match param:
132-
case str():
133-
stream.write_byte(QUOTE_BYTE)
134-
stream.write_string(StringEscaper.escape_string(param, no_backslash_escapes))
135-
stream.write_byte(QUOTE_BYTE)
136-
case bytes() | bytearray():
137-
stream.write_bytes(BINARY_PREFIX)
138-
stream.write_escaped_bytes(param, no_backslash_escapes)
139-
stream.write_byte(QUOTE_BYTE)
140-
case bool():
141-
# Handle boolean before int/float since bool is a subclass of int in Python
142-
stream.write_string( '1' if param else '0', 'ascii')
143-
case int():
144-
stream.write_string( str(param), 'ascii')
145-
case float():
146-
if repr(param) in ("nan", "inf", "-inf"):
147-
raise NotSupportedError(f"Float value '{repr(param)}' is not supported.")
148-
stream.write_string( str(param), 'ascii')
149-
case datetime.datetime():
150-
# DATETIME: 'YYYY-MM-DD HH:MM:SS.ffffff'
151-
if param.microsecond:
152-
stream.write_string(f"'{param.strftime('%Y-%m-%d %H:%M:%S')}.{param.microsecond:06d}'", 'ascii')
153-
else:
154-
stream.write_string(f"'{param.strftime('%Y-%m-%d %H:%M:%S')}'", 'ascii')
155-
case datetime.date():
156-
# DATE: 'YYYY-MM-DD'
157-
stream.write_string(f"'{param.strftime('%Y-%m-%d')}'", 'ascii')
158-
case datetime.time():
159-
# TIME: 'HH:MM:SS.ffffff'
160-
if param.microsecond:
161-
stream.write_string(f"'{param.strftime('%H:%M:%S')}.{param.microsecond:06d}'", 'ascii')
162-
else:
163-
stream.write_string(f"'{param.strftime('%H:%M:%S')}'", 'ascii')
164-
case datetime.timedelta():
165-
# Convert timedelta to TIME format (can be negative)
166-
total_seconds = int(param.total_seconds())
167-
hours, remainder = divmod(abs(total_seconds), 3600)
168-
minutes, seconds = divmod(remainder, 60)
169-
microseconds = param.microseconds
170-
171-
sign = '-' if total_seconds < 0 else ''
172-
if microseconds:
173-
stream.write_string(f"'{sign}{hours:02d}:{minutes:02d}:{seconds:02d}.{microseconds:06d}'", 'ascii')
174-
else:
175-
stream.write_string(f"'{sign}{hours:02d}:{minutes:02d}:{seconds:02d}'", 'ascii')
176-
case decimal.Decimal():
177-
if param.__str__() in ("NaN", "sNaN", "Infinity", "-Infinity"):
178-
raise NotSupportedError(f"Decimal value '{param.__str__()}' is not supported.")
179-
# DECIMAL/NUMERIC: no quotes needed, just string representation
180-
stream.write_string(str(param), 'ascii')
181-
case array.array() if param.typecode == 'f':
182-
if len(param) == 0:
183-
stream.write_bytes(NULL_BYTES)
184-
return
185-
# Float array for VECTOR columns - encode as numpy float32 bytes
186-
if HAS_NUMPY:
187-
float_bytes = numpy.array(param, numpy.float32).tobytes()
188-
else:
189-
# Fallback: use array.tobytes() directly
190-
float_bytes = param.tobytes()
191-
stream.write_bytes(BINARY_PREFIX)
192-
stream.write_escaped_bytes(float_bytes, no_backslash_escapes)
193-
stream.write_byte(QUOTE_BYTE)
194-
case _:
195-
# For other types, convert to string and escape
196-
stream.write_byte(QUOTE_BYTE)
197-
stream.write_string(StringEscaper.escape_string(str(param), no_backslash_escapes))
198-
stream.write_byte(QUOTE_BYTE)
106+
parts[count] = self.sql_bytes[last_pos:]
107+
return b'' . join(parts)
199108

200109
def is_binary(self) -> bool:
201110
return False
202111

203112
def type(self) -> str:
204113
return "COM_QUERY"
205-
114+
115+
116+
#### Conversion routines should be moved to a "central" place
117+
118+
def float2bytes(value: float) -> bytes:
119+
if repr(value) in ("nan", "inf", "-inf"):
120+
raise NotSupportedError(f"Float value '{repr(value)}' is not supported.")
121+
return str(value).encode('ascii')
122+
123+
def decimal2bytes(value: float) -> bytes:
124+
if value.__str__() in ("NaN", "sNaN", "Infinity", "-Infinity"):
125+
raise NotSupportedError(f"Decimal value '{value.__str__()}' is not supported.")
126+
return str(value).encode('ascii')
127+
128+
_ESCAPE_REGEX = re.compile(r'[\\\'"\0]')
129+
_ESCAPE_MAP = {'\\': '\\\\', "'": "\\'", '"': '\\"', '\0': '\\0'}
130+
131+
def escape_str(string: str, no_backslash_escapes: bool = False) -> bytes:
132+
"""
133+
Escape a string for SQL statements
134+
"""
135+
if no_backslash_escapes:
136+
# When NO_BACKSLASH_ESCAPES is set, single quotes are escaped by doubling them
137+
escaped = string.replace("'", "''")
138+
else:
139+
# Standard escaping: backslash, quote, double quote, zero byte
140+
escaped = _ESCAPE_REGEX.sub(lambda m: _ESCAPE_MAP[m.group(0)], string)
141+
142+
return b"'" + escaped.encode(encoding="utf8") + b"'"
143+
144+
def timedelta(val: datetime.timedelta) -> bytes:
145+
total_seconds = int(val.total_seconds())
146+
is_negative = total_seconds < 0
147+
148+
# Work with absolute values
149+
abs_seconds = abs(total_seconds)
150+
hours = abs_seconds // 3600
151+
minutes = (abs_seconds % 3600) // 60
152+
seconds = abs_seconds % 60
153+
microseconds = abs(val.microseconds)
154+
155+
sign = '-' if is_negative else ''
156+
return f"'{sign}{hours}:{minutes:02d}:{seconds:02d}.{microseconds}'".encode('ascii')
157+
158+
_ESCAPE_BYTES_REGEX = re.compile(rb'[\\\'"\0]')
159+
_ESCAPE_BYTES_MAP = {b'\\': b'\\\\', b"'": b"\\'", b'"': b'\\"', b'\0': b'\\0'}
160+
161+
def escape_bytes(b : bytes, no_backslash_escapes: bool = False) -> bytes:
162+
"""
163+
Escape a string for SQL statements
164+
"""
165+
if no_backslash_escapes:
166+
# When NO_BACKSLASH_ESCAPES is set, single quotes are escaped by doubling them
167+
escaped = b.replace(b"'", b"''")
168+
else:
169+
# Standard escaping: backslash, quote, double quote, zero byte
170+
escaped = _ESCAPE_BYTES_REGEX.sub(lambda m: _ESCAPE_BYTES_MAP[m.group(0)], b)
171+
172+
return b"_binary'" + escaped + b"'"
173+
174+
def float_array_to_bytes(arr: array.array, no_backslash_escapes: bool = False) -> bytes:
175+
"""Convert float array to binary representation for VECTOR columns"""
176+
if len(arr) == 0:
177+
return b'NULL'
178+
# Float array for VECTOR columns - encode as numpy float32 bytes
179+
if HAS_NUMPY:
180+
float_bytes = numpy.array(arr, numpy.float32).tobytes()
181+
else:
182+
# Fallback: use array.tobytes() directly
183+
float_bytes = arr.tobytes()
184+
return escape_bytes(float_bytes, no_backslash_escapes)
185+
186+
def tuple_to_bytes(t: tuple, no_backslash_escapes: bool = False) -> bytes:
187+
"""Convert tuple to bytes - raises error as tuples are not directly supported"""
188+
raise NotSupportedError("Tuple parameters are not supported. Use individual values or convert to a supported type.")
189+
190+
def indicator_val(v):
191+
if v.indicator == 1:
192+
return b'NULL'
193+
elif v.indicator == 2:
194+
return b'DEFAULT'
195+
elif v.indicator == 3: # bulk only
196+
pass
197+
elif v.indicator == 4: # bulk only
198+
pass
199+
else:
200+
return b'NULL'
201+
202+
203+
PARAM_CONVERT_TBL = {
204+
int: lambda v, ctx= None: str(v).encode('ascii'),
205+
float: lambda v, ctx= None: float2bytes(v),
206+
str: lambda v, ctx: escape_str(v, ctx),
207+
bytes: lambda v, ctx= None: escape_bytes(v, ctx),
208+
bytearray: lambda v, ctx= None: escape_bytes(v, ctx),
209+
decimal.Decimal: lambda v, ctx=None: decimal2bytes(v),
210+
datetime.date: lambda v, ctx=None: b"'" + str(v).encode('ascii') + b"'",
211+
datetime.datetime: lambda v, ctx=None: b"'" + str(v).encode('ascii') + b"'",
212+
datetime.time: lambda v, ctx=None: b"'" + str(v).encode('ascii') + b"'",
213+
datetime.timedelta: lambda v, ctx=None: timedelta(v),
214+
type(None): lambda v, ctx= None: b'NULL',
215+
bool: lambda v, ctx= None: b'1' if v else b'0',
216+
MrdbIndicator: lambda v, ctx=None: indicator_val(v),
217+
ipaddress.IPv4Address: lambda v, ctx=None: b"'" +str(v).encode('ascii') + b"'",
218+
ipaddress.IPv6Address: lambda v, ctx=None: b"'" +str(v).encode('ascii') + b"'",
219+
uuid.UUID: lambda v, ctx=None: b"'" +str(v).encode('ascii') + b"'",
220+
array.array: lambda v, ctx=None: float_array_to_bytes(v, ctx),
221+
tuple: lambda v, ctx=None: tuple_to_bytes(v, ctx),
222+
}

0 commit comments

Comments
 (0)