Skip to content

Commit b5012ef

Browse files
committed
refactor: optimize executemany to support both binary and text protocol execution
1 parent d15998d commit b5012ef

File tree

4 files changed

+190
-199
lines changed

4 files changed

+190
-199
lines changed

mariadb/async_cursor.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -205,37 +205,51 @@ async def executemany(self, sql: str, data: Sequence[Union[Sequence[Any], dict]]
205205
if data and len(data) > 0 and not isinstance(data, (list, tuple)):
206206
raise ProgrammingError(f"wrong parameter type")
207207

208-
if (self._stmt is not None):
209-
if (self._stmt.sql != sql):
210-
await self.connection._client.close_prepared_statement(self._stmt)
211-
self._stmt = None
208+
commands = []
209+
if self._force_binary:
210+
if (self._stmt is not None):
211+
if (self._stmt.sql != sql):
212+
await self.connection._client.close_prepared_statement(self._stmt)
213+
self._stmt = None
212214

213-
if (self._stmt is None):
214-
self._stmt = await self.connection._client.prepare_statement(sql)
215+
if (self._stmt is None):
216+
self._stmt = await self.connection._client.prepare_statement(sql)
215217

216-
# Execute with parameters using ExecutePacket
217-
from .impl.message.client.execute_packet import ExecutePacket
218+
from .impl.message.client.execute_packet import ExecutePacket
219+
220+
for params in data:
221+
parameters = None
222+
if params:
223+
parameters = list(params)
224+
225+
if parameters:
226+
if len(parameters) < self._stmt.parameter_count:
227+
raise ProgrammingError(
228+
f"Parameter count mismatch: SQL has {self._stmt.parameter_count} placeholders, "
229+
f"but only {len(parameters)} parameters provided"
230+
)
231+
232+
execute_packet = ExecutePacket(self._stmt.statement_id, parameters, sql)
233+
commands.append(execute_packet)
234+
else:
235+
sql_bytes, param_positions = split_sql_parts(sql)
236+
237+
placeholder_count = len(param_positions) // 2 # Positions come in pairs
238+
239+
for params in data:
240+
parameters = None
241+
if params:
242+
parameters = list(params)
243+
244+
if parameters:
245+
if len(parameters) < placeholder_count:
246+
raise ProgrammingError(
247+
f"Parameter count mismatch: SQL has {placeholder_count} placeholders, "
248+
f"but only {len(parameters)} parameters provided"
249+
)
250+
query_packet = QueryWithParamPacket(sql_bytes, param_positions, parameters)
251+
commands.append(query_packet)
218252

219-
# Execute the statement for each parameter set
220-
commands = []
221-
for params in data:
222-
# Execute with current parameter set
223-
# Convert data to list format for parameter binding
224-
parameters = None
225-
if params:
226-
parameters = list(params)
227-
228-
# Validate parameter count matches placeholders
229-
if parameters:
230-
if len(parameters) < self._stmt.parameter_count:
231-
raise ProgrammingError(
232-
f"Parameter count mismatch: SQL has {self._stmt.parameter_count} placeholders, "
233-
f"but only {len(parameters)} parameters provided"
234-
)
235-
236-
execute_packet = ExecutePacket(self._stmt.statement_id, parameters, sql)
237-
commands.append(execute_packet)
238-
239253
completions = await self.connection._client.execute_many(commands, self._config, True, self._stmt)
240254

241255
# Process the completions - aggregate result sets with compatible metadata

0 commit comments

Comments
 (0)