Skip to content

Commit d15998d

Browse files
committed
refactor: optimize executemany to use prepared statements and batch execution
1 parent: 9b249b1 · commit: d15998d

File tree

9 files changed

+284
-280
lines changed

9 files changed

+284
-280
lines changed

mariadb/async_cursor.py

Lines changed: 22 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -176,27 +176,6 @@ async def execute(self, sql: str, data: Optional[Union[Sequence[Any], dict]] = N
176176
)
177177

178178
async def executemany(self, sql: str, data: Sequence[Union[Sequence[Any], dict]], buffered: Optional[bool] = None) -> None:
179-
"""
180-
Execute a SQL statement multiple times with different parameter sets
181-
182-
More efficient than calling execute() multiple times as it can
183-
batch operations and reduce round trips.
184-
185-
Args:
186-
sql: SQL statement to execute (typically INSERT, UPDATE, DELETE)
187-
data: Sequence of parameter sequences, one for each execution
188-
buffered: Override cursor's buffered setting
189-
190-
Raises:
191-
ProgrammingError: If cursor is closed
192-
DatabaseError: If execution fails
193-
194-
Example:
195-
>>> await cursor.executemany(
196-
... "INSERT INTO users VALUES (?, ?)",
197-
... [(1, 'Alice'), (2, 'Bob'), (3, 'Charlie')]
198-
... )
199-
"""
200179
"""
201180
Execute a statement multiple times with different parameter sets
202181
@@ -211,38 +190,34 @@ async def executemany(self, sql: str, data: Sequence[Union[Sequence[Any], dict]]
211190
if not isinstance(sql, str):
212191
raise TypeError("SQL statement must be a string")
213192

214-
# Check if data is None or not an array-like type
215-
if data is None or not hasattr(data, '__iter__') or isinstance(data, (str, bytes)):
216-
raise ProgrammingError("No data provided")
217-
218193
# Consume any pending streaming results before executing new query
219194
if self._result is not None and self._result.streaming():
220195
await self._result.fetch_remaining()
221196

222-
# If data is an empty list/tuple, return early with rowcount=0
223-
if not data:
224-
self._rowcount = 0
225-
return
197+
# Check if data is None or not an array-like type
198+
if data is None or not hasattr(data, '__iter__') or isinstance(data, (str, bytes)):
199+
raise ProgrammingError("No data provided")
226200

227201
# Reset result state
228202
self._result = None
229203

230204
try:
231-
completions = list()
232205
if data and len(data) > 0 and not isinstance(data, (list, tuple)):
233206
raise ProgrammingError(f"wrong parameter type")
234207

235-
# Pre-parse SQL once for optimization (avoid re-parsing for each row)
236-
sql_bytes = None
237-
param_positions = None
238-
param_names = None
239-
placeholder_count = 0
240-
241-
# For positional parameters, parse SQL once
242-
sql_bytes, param_positions = split_sql_parts(sql)
243-
placeholder_count = len(param_positions) // 2
244-
208+
if (self._stmt is not None):
209+
if (self._stmt.sql != sql):
210+
await self.connection._client.close_prepared_statement(self._stmt)
211+
self._stmt = None
212+
213+
if (self._stmt is None):
214+
self._stmt = await self.connection._client.prepare_statement(sql)
215+
216+
# Execute with parameters using ExecutePacket
217+
from .impl.message.client.execute_packet import ExecutePacket
218+
245219
# Execute the statement for each parameter set
220+
commands = []
246221
for params in data:
247222
# Execute with current parameter set
248223
# Convert data to list format for parameter binding
@@ -252,18 +227,16 @@ async def executemany(self, sql: str, data: Sequence[Union[Sequence[Any], dict]]
252227

253228
# Validate parameter count matches placeholders
254229
if parameters:
255-
if len(parameters) < placeholder_count:
230+
if len(parameters) < self._stmt.parameter_count:
256231
raise ProgrammingError(
257-
f"Parameter count mismatch: SQL has {placeholder_count} placeholders, "
232+
f"Parameter count mismatch: SQL has {self._stmt.parameter_count} placeholders, "
258233
f"but only {len(parameters)} parameters provided"
259234
)
260235

261-
# Create query packet and execute with bytes
262-
query_packet = QueryWithParamPacket(sql_bytes, param_positions, parameters)
263-
# Use provided buffered parameter or fall back to cursor default
264-
effective_buffered = buffered if buffered is not None else self._buffered
265-
compl = await self.connection._client.execute(query_packet, self._config, effective_buffered)
266-
completions.extend(compl)
236+
execute_packet = ExecutePacket(self._stmt.statement_id, parameters, sql)
237+
commands.append(execute_packet)
238+
239+
completions = await self.connection._client.execute_many(commands, self._config, True, self._stmt)
267240

268241
# Process the completions - aggregate result sets with compatible metadata
269242
self._process_executemany_completions(completions)
@@ -276,6 +249,7 @@ async def executemany(self, sql: str, data: Sequence[Union[Sequence[Any], dict]]
276249
errno=2013,
277250
sql_state='HY000'
278251
)
252+
279253

280254
# =========================================================================
281255
# Result Fetching Methods

mariadb/base_cursor.py

Lines changed: 18 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@ def __init__(self, connection: TConnection, **kwargs):
6767
self._exception_factory = ExceptionFactory()
6868
self._buffered: bool = bool(kwargs.pop('buffered', True))
6969
self._result: Optional[TResult] = None
70-
self._executemany_mode: bool = False
71-
self._executemany_rowcount: int = 0
72-
self._executemany_lastrowid: Optional[int] = None
7370
self._force_binary: bool = False
7471
self._stmt: Optional[PrepareStmtPacket] = None
7572
if kwargs:
@@ -91,9 +88,6 @@ def __init__(self, connection: TConnection, **kwargs):
9188
@property
9289
def rowcount(self) -> int:
9390
"""Get the number of rows (read-only property)"""
94-
# For executemany, return the aggregated rowcount
95-
if self._executemany_mode:
96-
return self._executemany_rowcount
9791
# Get the current completion
9892
if self._completion_index < len(self._completions):
9993
completion = self._completions[self._completion_index]
@@ -147,8 +141,6 @@ def sp_outparams(self) -> bool:
147141
def lastrowid(self) -> Optional[int]:
148142
"""Get the last insert ID from the current completion"""
149143
# For executemany, return the last insert ID from all executions
150-
if self._executemany_mode:
151-
return self._executemany_lastrowid
152144
if self._completion_index < len(self._completions):
153145
completion = self._completions[self._completion_index]
154146
return completion.insert_id or None
@@ -287,11 +279,6 @@ def _process_completions(self, completions: List[Any]) -> None:
287279
Args:
288280
completions: List of completion objects
289281
"""
290-
# Reset executemany mode for regular execute
291-
self._executemany_mode = False
292-
self._executemany_rowcount = 0
293-
self._executemany_lastrowid = None
294-
295282
# Store all completions for nextset() functionality
296283
self._completions = completions
297284
self._completion_index = 0
@@ -321,53 +308,33 @@ def _process_rows_set_completion(self, result_set: Result) -> None:
321308
"""
322309
self._result = result_set
323310

324-
def _process_executemany_completions(self, completions: List[Any]) -> None:
311+
def _process_executemany_completions(self, completions: List[List[Completion]]) -> None:
325312
"""
326313
Process completions from executemany - aggregate all result sets.
327314
Since executemany runs the same query multiple times, all completions have identical metadata.
328-
"""
329-
# Enable executemany mode and calculate aggregated values
330-
self._executemany_mode = True
331-
self._executemany_rowcount = 0
332-
self._executemany_lastrowid = None
333-
334-
# Calculate total affected rows and last insert ID from all completions
335-
for c in completions:
336-
if c.affected_rows >= 0:
337-
self._executemany_rowcount += c.affected_rows
338-
if c.insert_id is not None and c.insert_id > 0:
339-
self._executemany_lastrowid = c.insert_id
340315
316+
Args:
317+
completions: List[List[Completion]] - one list per executed message
318+
"""
341319
if not completions:
342320
self._result = None
343321
return
344-
345-
# Find completions with result sets
346-
result_set_completions = [c for c in completions if c.has_result_set()]
347-
348-
if not result_set_completions:
349-
# No result sets - just update counts (e.g., INSERT/UPDATE/DELETE)
350-
# The aggregated values are already set above
322+
323+
firstCompletion = completions[0]
324+
if not firstCompletion:
351325
self._result = None
352326
return
353-
354-
# All completions have identical metadata - use the first one
355-
first_rs = result_set_completions[0].get_result_set()
356-
first_columns = first_rs.columns
357-
358-
# Aggregate all rows from all result sets
359-
aggregated_rows = []
360-
for completion in result_set_completions:
361-
rs = completion.get_result_set()
362-
rows = rs.rows if hasattr(rs, 'rows') else rs.get('rows', [])
363-
aggregated_rows.extend(rows)
364-
365-
# Build result from aggregated data (description will be computed on-demand)
366-
self._result = self._create_complete_result(
367-
columns=first_columns,
368-
column_count=len(first_columns),
369-
rows=aggregated_rows
370-
)
327+
328+
for u in range(1, len(completions)):
329+
unit_completions = completions[u]
330+
for i, c in enumerate(unit_completions):
331+
if c.affected_rows >= 0:
332+
firstCompletion[i].affected_rows += c.affected_rows
333+
if c.insert_id is not None and c.insert_id > 0:
334+
firstCompletion[i].insert_id = c.insert_id
335+
if c.has_result_set():
336+
firstCompletion[i].result_set.rows.extend(c.result_set.rows)
337+
self._process_completions(firstCompletion)
371338

372339
def _build_description(self, columns: List[ColumnDefinitionPacket]) -> Optional[tuple]:
373340
"""Build cursor description tuple from column definitions"""

0 commit comments

Comments (0)