44import array
55import datetime
66import decimal
7- from typing import TYPE_CHECKING , Any , List , Optional , Union as UnionType
8-
7+ import ipaddress
8+ import uuid
9+ from typing import TYPE_CHECKING , Any , List
10+ import struct
11+ import re
912try :
1013 import numpy
1114 HAS_NUMPY = True
1215except ImportError :
1316 HAS_NUMPY = False
1417
1518from ...client .context import Context
16- from ...string_utils import StringEscaper
1719from ...sql_parser import split_sql_parts
1820from mariadb_shared .constants .STATUS import NO_BACKSLASH_ESCAPES
1921from mariadb_shared .constants .INDICATOR import MrdbIndicator
2022from ..client_message import ClientMessage
2123from ..payload_stream import PayloadStream
2224from ....exceptions import NotSupportedError
23- if TYPE_CHECKING :
24- from ...client .socket .write_stream import BaseWriteStream
2525
2626BINARY_PREFIX : bytes = bytearray (b"_binary'" )
2727QUOTE_BYTE : int = b"'" [0 ]
@@ -33,7 +33,7 @@ class QueryPacket(ClientMessage):
3333 """
3434 Simple query packet for SQL execution without parameters
3535 """
36-
36+
3737 def __init__ (self , sql : str ):
3838 """Initialize COM_QUERY packet with SQL"""
3939 self .sql = sql
@@ -52,11 +52,11 @@ class QueryWithParamPacket(ClientMessage):
5252 """
5353 Parameterized query packet for SQL execution with parameter binding
5454 """
55-
55+
5656 def __init__ (self , sql_bytes : bytes , param_positions : List [int ], parameters : List [Any ]):
5757 """
5858 Initialize COM_QUERY packet with pre-parsed SQL bytes and parameters
59-
59+
6060 Args:
6161 sql_bytes: SQL encoded as UTF-8 bytes
6262 param_positions: Byte positions (start, end) pairs where placeholders are
@@ -70,136 +70,153 @@ def payload(self, context: Context) -> bytes:
7070 """Generate COM_QUERY packet payload with SQL and bound parameters"""
7171 stream = PayloadStream ()
7272 no_backslash_escapes = context .server_status & NO_BACKSLASH_ESCAPES > 0
73-
73+
7474 # Write SQL fragments interleaved with parameters
7575 last_pos = 0
7676 param_idx = 0
77+ params = self .parameters
78+ converter = PARAM_CONVERT_TBL
79+ parts = [b'' ] * (len (self .param_positions ) + 2 + 1 )
7780
78- stream .write_byte (COM_QUERY )
81+ parts [0 ] = b'\x03 '
82+ count = 1
7983 # Iterate through placeholder positions (they come in pairs: start, end)
8084 for i in range (0 , len (self .param_positions ), 2 ):
8185 start_pos = self .param_positions [i ]
8286 end_pos = self .param_positions [i + 1 ]
83-
87+
8488 # Write SQL fragment before this placeholder
8589 if start_pos > last_pos :
86- stream .write_bytes (self .sql_bytes [last_pos :start_pos ])
87-
90+ parts [count ]= self .sql_bytes [last_pos :start_pos ]
91+ count += 1
92+
8893 # Write parameter value
8994 if param_idx < len (self .parameters ):
90- self . _write_parameter_value ( stream , self . parameters [param_idx ], no_backslash_escapes )
95+ parts [ count ] = converter . get ( type ( params [ param_idx ]), lambda v , ctx = None : str ( v ). encode ( 'utf8' ))( params [param_idx ], no_backslash_escapes )
9196 param_idx += 1
97+ count += 1
9298 else :
93- stream .write_string ('NULL' , 'ascii' )
94-
99+ parts [count ] = b'NULL'
100+ count += 1
101+
95102 last_pos = end_pos
96-
103+
97104 # Write remaining SQL after last placeholder
98105 if last_pos < len (self .sql_bytes ):
99- stream .write_bytes (self .sql_bytes [last_pos :])
100-
101- return stream .get_payload ()
102-
103- def _write_parameter_value (self , stream : UnionType ['BaseWriteStream' , PayloadStream ], param : Any , no_backslash_escapes : bool ) -> None :
104- """
105- Write parameter value directly as its string representation
106- (for COM_QUERY, parameters are converted to strings)
107-
108- Args:
109- stream: Stream writer (BaseWriteStream or PayloadStream)
110- param: Parameter value
111- no_backslash_escapes: Whether to use NO_BACKSLASH_ESCAPES mode
112- """
113- if param is None :
114- stream .write_string ('NULL' , 'ascii' )
115- elif isinstance (param , MrdbIndicator ):
116- # Handle MariaDB indicator values
117- if param .indicator == 1 : # NULL
118- stream .write_string ('NULL' , 'ascii' )
119- elif param .indicator == 2 : # DEFAULT
120- stream .write_string ('DEFAULT' , 'ascii' )
121- elif param .indicator == 3 : # IGNORE
122- # Skip this parameter - should be handled at a higher level
123- pass
124- elif param .indicator == 4 : # IGNORE_ROW
125- # Skip entire row - should be handled at a higher level
126- pass
127- else :
128- # Unknown indicator, treat as NULL
129- stream .write_bytes (NULL_BYTES )
130- else :
131- match param :
132- case str ():
133- stream .write_byte (QUOTE_BYTE )
134- stream .write_string (StringEscaper .escape_string (param , no_backslash_escapes ))
135- stream .write_byte (QUOTE_BYTE )
136- case bytes () | bytearray ():
137- stream .write_bytes (BINARY_PREFIX )
138- stream .write_escaped_bytes (param , no_backslash_escapes )
139- stream .write_byte (QUOTE_BYTE )
140- case bool ():
141- # Handle boolean before int/float since bool is a subclass of int in Python
142- stream .write_string ( '1' if param else '0' , 'ascii' )
143- case int ():
144- stream .write_string ( str (param ), 'ascii' )
145- case float ():
146- if repr (param ) in ("nan" , "inf" , "-inf" ):
147- raise NotSupportedError (f"Float value '{ repr (param )} ' is not supported." )
148- stream .write_string ( str (param ), 'ascii' )
149- case datetime .datetime ():
150- # DATETIME: 'YYYY-MM-DD HH:MM:SS.ffffff'
151- if param .microsecond :
152- stream .write_string (f"'{ param .strftime ('%Y-%m-%d %H:%M:%S' )} .{ param .microsecond :06d} '" , 'ascii' )
153- else :
154- stream .write_string (f"'{ param .strftime ('%Y-%m-%d %H:%M:%S' )} '" , 'ascii' )
155- case datetime .date ():
156- # DATE: 'YYYY-MM-DD'
157- stream .write_string (f"'{ param .strftime ('%Y-%m-%d' )} '" , 'ascii' )
158- case datetime .time ():
159- # TIME: 'HH:MM:SS.ffffff'
160- if param .microsecond :
161- stream .write_string (f"'{ param .strftime ('%H:%M:%S' )} .{ param .microsecond :06d} '" , 'ascii' )
162- else :
163- stream .write_string (f"'{ param .strftime ('%H:%M:%S' )} '" , 'ascii' )
164- case datetime .timedelta ():
165- # Convert timedelta to TIME format (can be negative)
166- total_seconds = int (param .total_seconds ())
167- hours , remainder = divmod (abs (total_seconds ), 3600 )
168- minutes , seconds = divmod (remainder , 60 )
169- microseconds = param .microseconds
170-
171- sign = '-' if total_seconds < 0 else ''
172- if microseconds :
173- stream .write_string (f"'{ sign } { hours :02d} :{ minutes :02d} :{ seconds :02d} .{ microseconds :06d} '" , 'ascii' )
174- else :
175- stream .write_string (f"'{ sign } { hours :02d} :{ minutes :02d} :{ seconds :02d} '" , 'ascii' )
176- case decimal .Decimal ():
177- if param .__str__ () in ("NaN" , "sNaN" , "Infinity" , "-Infinity" ):
178- raise NotSupportedError (f"Decimal value '{ param .__str__ ()} ' is not supported." )
179- # DECIMAL/NUMERIC: no quotes needed, just string representation
180- stream .write_string (str (param ), 'ascii' )
181- case array .array () if param .typecode == 'f' :
182- if len (param ) == 0 :
183- stream .write_bytes (NULL_BYTES )
184- return
185- # Float array for VECTOR columns - encode as numpy float32 bytes
186- if HAS_NUMPY :
187- float_bytes = numpy .array (param , numpy .float32 ).tobytes ()
188- else :
189- # Fallback: use array.tobytes() directly
190- float_bytes = param .tobytes ()
191- stream .write_bytes (BINARY_PREFIX )
192- stream .write_escaped_bytes (float_bytes , no_backslash_escapes )
193- stream .write_byte (QUOTE_BYTE )
194- case _:
195- # For other types, convert to string and escape
196- stream .write_byte (QUOTE_BYTE )
197- stream .write_string (StringEscaper .escape_string (str (param ), no_backslash_escapes ))
198- stream .write_byte (QUOTE_BYTE )
106+ parts [count ] = self .sql_bytes [last_pos :]
107+ return b'' . join (parts )
199108
200109 def is_binary (self ) -> bool :
201110 return False
202111
203112 def type (self ) -> str :
204113 return "COM_QUERY"
205-
114+
115+
# TODO: move the conversion routines below to a central, shared module.
117+
def float2bytes(value: float) -> bytes:
    """Render a float as an ASCII SQL numeric literal.

    Raises:
        NotSupportedError: if the value is NaN or positive/negative infinity.
    """
    text = repr(value)
    if text == "nan" or text == "inf" or text == "-inf":
        raise NotSupportedError(f"Float value '{text}' is not supported.")
    return str(value).encode('ascii')
122+
def decimal2bytes(value: decimal.Decimal) -> bytes:
    """Render a Decimal as an ASCII SQL numeric literal (unquoted).

    Args:
        value: The decimal to render.

    Returns:
        ``str(value)`` encoded as ASCII.

    Raises:
        NotSupportedError: if the value is any kind of NaN (quiet, signaling,
            signed or carrying a payload) or an infinity.
    """
    # is_finite() is False for every NaN and infinity variant, including
    # forms such as '-NaN' that a comparison against fixed strings misses.
    if not value.is_finite():
        raise NotSupportedError(f"Decimal value '{value}' is not supported.")
    return str(value).encode('ascii')
127+
_ESCAPE_REGEX = re.compile(r'[\\\'"\0]')
_ESCAPE_MAP = {'\\': '\\\\', "'": "\\'", '"': '\\"', '\0': '\\0'}


def _escape_str_match(match):
    """Return the backslash-escaped replacement for one matched character."""
    return _ESCAPE_MAP[match.group(0)]


def escape_str(string: str, no_backslash_escapes: bool = False) -> bytes:
    """
    Quote and escape a string for use as an SQL string literal.

    Args:
        string: Text to embed in the statement.
        no_backslash_escapes: When true, only single quotes are escaped, by
            doubling them; otherwise backslash, single quote, double quote
            and the zero character are backslash-escaped.

    Returns:
        The UTF-8 encoded literal, including the surrounding single quotes.
    """
    if no_backslash_escapes:
        body = string.replace("'", "''")
    else:
        body = _ESCAPE_REGEX.sub(_escape_str_match, string)

    return b"".join((b"'", body.encode(encoding="utf8"), b"'"))
143+
def timedelta(val: datetime.timedelta) -> bytes:
    """Render a timedelta as a quoted SQL TIME literal.

    The result has the form ``'[-]HH:MM:SS[.ffffff]'``; hours are not capped
    at 24 and the fractional part is written only when it is non-zero.

    Args:
        val: Duration to render; may be negative.

    Returns:
        The ASCII-encoded literal including the surrounding single quotes.
    """
    # Work on the exact integer microsecond count. A negative timedelta is
    # stored normalized (negative days, non-negative seconds/microseconds),
    # so reading val.microseconds directly gives the wrong fraction.
    total_us = (val.days * 86400 + val.seconds) * 1_000_000 + val.microseconds
    sign = '-' if total_us < 0 else ''

    total_seconds, microseconds = divmod(abs(total_us), 1_000_000)
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)

    # Zero-pad to six digits: ".5" would otherwise mean half a second.
    fraction = f".{microseconds:06d}" if microseconds else ''
    return f"'{sign}{hours:02d}:{minutes:02d}:{seconds:02d}{fraction}'".encode('ascii')
157+
_ESCAPE_BYTES_REGEX = re.compile(rb'[\\\'"\0]')
_ESCAPE_BYTES_MAP = {b'\\': b'\\\\', b"'": b"\\'", b'"': b'\\"', b'\0': b'\\0'}


def _escape_bytes_match(match):
    """Return the backslash-escaped replacement for one matched byte."""
    return _ESCAPE_BYTES_MAP[match.group(0)]


def escape_bytes(b: bytes, no_backslash_escapes: bool = False) -> bytes:
    """
    Quote and escape binary data as a ``_binary'...'`` SQL literal.

    Args:
        b: Raw bytes to embed in the statement.
        no_backslash_escapes: When true, only single quotes are escaped, by
            doubling them; otherwise backslash, single quote, double quote
            and the zero byte are backslash-escaped.

    Returns:
        The literal, including the ``_binary`` introducer and the quotes.
    """
    if no_backslash_escapes:
        body = b.replace(b"'", b"''")
    else:
        body = _ESCAPE_BYTES_REGEX.sub(_escape_bytes_match, b)

    return b"".join((b"_binary'", body, b"'"))
173+
def float_array_to_bytes(arr: array.array, no_backslash_escapes: bool = False) -> bytes:
    """Convert a numeric array to a binary SQL literal for VECTOR columns.

    Args:
        arr: Array of numbers; elements are sent as 32-bit floats.
        no_backslash_escapes: Escaping mode forwarded to ``escape_bytes``.

    Returns:
        ``b'NULL'`` for an empty array, otherwise the escaped ``_binary``
        literal of the float32 bytes.
    """
    if len(arr) == 0:
        return b'NULL'
    if HAS_NUMPY:
        float_bytes = numpy.array(arr, numpy.float32).tobytes()
    else:
        # Without numpy, convert to typecode 'f' first so the element width
        # matches the numpy path; tobytes() on e.g. a 'd' array would emit
        # 8-byte doubles instead of 4-byte floats.
        if arr.typecode != 'f':
            arr = array.array('f', arr)
        float_bytes = arr.tobytes()
    return escape_bytes(float_bytes, no_backslash_escapes)
185+
def tuple_to_bytes(t: tuple, no_backslash_escapes: bool = False) -> bytes:
    """Reject tuple parameters.

    Always raises NotSupportedError; both arguments are ignored.
    """
    message = "Tuple parameters are not supported. Use individual values or convert to a supported type."
    raise NotSupportedError(message)
189+
def indicator_val(v) -> bytes:
    """Translate an indicator object into the SQL text it stands for.

    Args:
        v: Object exposing an integer ``indicator`` attribute.

    Returns:
        ``b'NULL'`` for indicator 1, ``b'DEFAULT'`` for 2, empty bytes for
        the bulk-only indicators 3 and 4, and ``b'NULL'`` for any other value.
    """
    indicator = v.indicator
    if indicator == 2:
        return b'DEFAULT'
    if indicator in (3, 4):
        # Bulk-only indicators write no text here. Return empty bytes rather
        # than None so the result can always be joined into the payload.
        return b''
    # Indicator 1 and any unknown indicator are sent as NULL.
    return b'NULL'
201+
202+
# Maps a parameter's Python type to a converter producing its SQL literal.
# Every converter is called as ``convert(value, no_backslash_escapes)``; the
# second argument is only used by converters that escape text or bytes.
# Lambdas are kept so helper names are resolved when a converter is called.
PARAM_CONVERT_TBL = {
    int: lambda v, ctx=None: str(v).encode('ascii'),
    float: lambda v, ctx=None: float2bytes(v),
    str: lambda v, ctx=None: escape_str(v, ctx),
    bytes: lambda v, ctx=None: escape_bytes(v, ctx),
    bytearray: lambda v, ctx=None: escape_bytes(v, ctx),
    decimal.Decimal: lambda v, ctx=None: decimal2bytes(v),
    datetime.date: lambda v, ctx=None: b"'" + str(v).encode('ascii') + b"'",
    # Drop tzinfo before formatting so an aware value does not append a UTC
    # offset to the literal; naive values render exactly as str(v) does.
    # NOTE(review): confirm the target server does not accept offsets.
    datetime.datetime: lambda v, ctx=None: (
        b"'" + v.replace(tzinfo=None).isoformat(sep=' ').encode('ascii') + b"'"
    ),
    datetime.time: lambda v, ctx=None: (
        b"'" + v.replace(tzinfo=None).isoformat().encode('ascii') + b"'"
    ),
    datetime.timedelta: lambda v, ctx=None: timedelta(v),
    type(None): lambda v, ctx=None: b'NULL',
    bool: lambda v, ctx=None: b'1' if v else b'0',
    MrdbIndicator: lambda v, ctx=None: indicator_val(v),
    ipaddress.IPv4Address: lambda v, ctx=None: b"'" + str(v).encode('ascii') + b"'",
    ipaddress.IPv6Address: lambda v, ctx=None: b"'" + str(v).encode('ascii') + b"'",
    uuid.UUID: lambda v, ctx=None: b"'" + str(v).encode('ascii') + b"'",
    array.array: lambda v, ctx=None: float_array_to_bytes(v, ctx),
    tuple: lambda v, ctx=None: tuple_to_bytes(v, ctx),
}
0 commit comments