Skip to content

Commit 6584172

Browse files
committed
refactor: optimize server packet parsing with pre-compiled structs and fast paths
1 parent 25b3174 commit 6584172

File tree

4 files changed

+89
-78
lines changed

4 files changed

+89
-78
lines changed

mariadb/impl/message/server/column_definition_packet.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,29 @@
99

1010
import struct
1111
from typing import TYPE_CHECKING, Optional, Tuple
12-
# No longer need PacketBuffer import
12+
1313
if TYPE_CHECKING:
1414
from ...client.context import Context
1515

16+
# Pre-compile struct formats for faster unpacking
17+
_STRUCT_UINT16 = struct.Struct('<H')
18+
_STRUCT_FIXED_FIELDS = struct.Struct('<HIBHB') # charset(H), column_length(I), type(B), flags(H), decimals(B)
19+
1620

1721
def read_small_length_encoded_bytes(data: memoryview, pos: int) -> Tuple[bytes, int]:
1822
"""Read length-encoded bytes and advance position"""
1923
length = data[pos]
2024
pos += 1
2125

22-
if length >= 251:
23-
length = struct.unpack('<H', data[pos:pos+2])[0]
24-
pos += 2
25-
26-
result = data[pos:pos+length].tobytes()
26+
# Fast path: most lengths are < 251
27+
if length < 251:
28+
result = bytes(data[pos:pos+length])
29+
return result, pos + length
30+
31+
# Slow path: 2-byte length
32+
length = _STRUCT_UINT16.unpack_from(data, pos)[0]
33+
pos += 2
34+
result = bytes(data[pos:pos+length])
2735
return result, pos + length
2836

2937

@@ -128,7 +136,7 @@ def org_name(self) -> str:
128136

129137
@staticmethod
130138
def decode(data: memoryview, context: 'Context') -> 'ColumnDefinitionPacket':
131-
"""Decode column definition packet from bytearray with context"""
139+
"""Decode column definition packet"""
132140

133141
pos = 0
134142

@@ -139,35 +147,35 @@ def decode(data: memoryview, context: 'Context') -> 'ColumnDefinitionPacket':
139147
name_bytes, pos = read_small_length_encoded_bytes(data, pos)
140148
org_name_bytes, pos = read_small_length_encoded_bytes(data, pos)
141149

142-
# Handle extended info only if EXTENDED_METADATA capability is enabled
150+
# Fast path: no extended metadata (most common case)
143151
ext_type_name = None
144152
ext_type_format = None
145153
special_format = False
146-
147-
# Check if we have the length field (0x0C) or extended metadata
154+
148155
if context.hasExtendedMetadata():
149-
# Has extended info - read length-encoded buffer
150156
ext_length = data[pos]
151157
pos += 1
152-
ext_end = pos + ext_length
153-
while pos < ext_end and pos < len(data):
154-
ext_type = data[pos]
155-
pos += 1
158+
159+
if ext_length > 0:
156160
special_format = True
157-
if ext_type == 0:
158-
# Extended type name
159-
ext_type_name, pos = read_small_length_encoded_bytes(data, pos)
160-
elif ext_type == 1:
161-
# Extended type format
162-
ext_type_format, pos = read_small_length_encoded_bytes(data, pos)
163-
else:
164-
# Skip unknown extended data
165-
_, pos = read_small_length_encoded_bytes(data, pos)
161+
ext_end = pos + ext_length
162+
163+
while pos < ext_end:
164+
ext_type = data[pos]
165+
pos += 1
166+
167+
if ext_type == 0:
168+
ext_type_name, pos = read_small_length_encoded_bytes(data, pos)
169+
elif ext_type == 1:
170+
ext_type_format, pos = read_small_length_encoded_bytes(data, pos)
171+
else:
172+
# Skip unknown extended data
173+
_, pos = read_small_length_encoded_bytes(data, pos)
166174

167-
# Skip length field (always 0x0C = 12 for fixed fields)
175+
# Skip length field (0x0C) and unpack fixed fields using pre-compiled struct
168176
pos += 1
169-
charset, column_length, type, flags, decimals = struct.unpack('<HIBHB', data[pos:pos+10])
170-
177+
charset, column_length, type, flags, decimals = _STRUCT_FIXED_FIELDS.unpack_from(data, pos)
178+
171179
return ColumnDefinitionPacket(
172180
catalog_bytes,
173181
schema_bytes,

mariadb/impl/message/server/eof_packet.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,16 @@
99

1010
from typing import TYPE_CHECKING
1111
from ...completion import Completion
12-
# No longer need PacketBuffer import
12+
1313
if TYPE_CHECKING:
1414
from ...client.context import Context
1515

1616
from mariadb_shared import constants
1717

18+
# Pre-compute constant
19+
_PS_OUT_PARAMS_MASK = constants.STATUS.PS_OUT_PARAMS
20+
21+
1822
class EofPacket(Completion):
1923
"""
2024
EOF Packet from MariaDB server
@@ -29,27 +33,26 @@ class EofPacket(Completion):
2933
In that case, use OkPacket.decode() instead.
3034
"""
3135
__slots__ = (
32-
'warning_count',
3336
'server_status',
3437
)
38+
3539
def __init__(
3640
self,
37-
warning_count: int = 0,
38-
server_status: int = 0,
39-
is_output_parameters: bool = False
41+
warning_count: int,
42+
server_status: int,
4043
):
4144
"""Initialize EOF packet with warning count and server status"""
4245
self.affected_rows = 0
4346
self.insert_id = 0
4447
self.warning_count = warning_count
4548
self.result_set = None
4649

47-
# EofPacket-specific fields
50+
# EofPacket-specific field
4851
self.server_status = server_status
4952

5053
def is_output_parameters(self) -> bool:
5154
"""Check if completion has output parameters"""
52-
return (self.server_status & constants.STATUS.PS_OUT_PARAMS) != 0
55+
return (self.server_status & _PS_OUT_PARAMS_MASK) != 0
5356

5457
@staticmethod
5558
def decode(data: memoryview, context: 'Context') -> 'EofPacket':
@@ -61,12 +64,4 @@ def decode(data: memoryview, context: 'Context') -> 'EofPacket':
6164
context.server_status = server_status
6265
context.warning_count = warning_count
6366

64-
# Check if this marks output parameters (PS_OUT_PARAMS flag)
65-
is_output_parameters = (server_status & constants.STATUS.PS_OUT_PARAMS) != 0
66-
67-
return EofPacket(
68-
warning_count,
69-
server_status,
70-
is_output_parameters
71-
)
72-
67+
return EofPacket(warning_count, server_status)

mariadb/impl/message/server/error_packet.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77
Based on MySQL/MariaDB protocol error packet structure.
88
"""
99

10-
import struct
1110
from typing import TYPE_CHECKING, Optional
1211

1312
if TYPE_CHECKING:
1413
from ...client.context import Context
1514
from ...client.exception_factory import ExceptionFactory
1615
from ...client.socket.payload_parser import PayloadParser
17-
# No longer need PacketBuffer import
16+
17+
# Pre-compute constants
18+
_HASH_MARKER = 0x23 # '#'
19+
_DEFAULT_SQL_STATE = "HY000"
1820

1921

2022
class ErrorPacket:
@@ -35,11 +37,12 @@ class ErrorPacket:
3537
'sql_state',
3638
'error_message',
3739
)
40+
3841
def __init__(
3942
self,
4043
error_code: int,
41-
sql_state: str = "HY000",
42-
error_message: str = "",
44+
sql_state: str,
45+
error_message: str,
4346
):
4447
"""Initialize error packet with error code, SQL state, and message"""
4548
self.error_code = error_code
@@ -54,34 +57,41 @@ def is_output_parameters(self) -> bool:
5457
def decode(data: memoryview, context: Optional['Context'] = None) -> 'ErrorPacket':
5558
"""Decode error packet from bytearray with optional context"""
5659
parser = PayloadParser(data)
57-
parser.read_byte()
60+
parser.skip(1) # Skip error marker (0xFF) - skip is faster than read_byte if we don't use value
5861
error_code = parser.read_uint16()
59-
sql_state = "HY000" # Default SQL state
6062

61-
# Check for SQL state marker '#' (0x23)
62-
if parser.has_remaining() and parser.get_byte() == 0x23: # '#' symbol
63-
parser.read_byte() # Skip '#' marker
63+
# Fast path: check for SQL state marker
64+
if parser.has_remaining() and parser.get_byte() == _HASH_MARKER:
65+
parser.skip(1) # Skip '#' marker
66+
6467
# SQL state (5 bytes)
6568
if parser.remaining_bytes() >= 5:
66-
sql_state = bytes(parser.read_bytes(5)).decode('ascii')
69+
sql_state = parser.read_bytes(5).decode('ascii')
6770
else:
6871
raise IOError("Invalid error packet: SQL state truncated")
72+
else:
73+
sql_state = _DEFAULT_SQL_STATE
6974

70-
error_message = bytes(parser.read_remaining()).decode('utf-8', errors='replace')
71-
return ErrorPacket(
72-
error_code,
73-
sql_state,
74-
error_message
75-
)
76-
75+
# Decode error message - remove unnecessary bytes() wrapper if read_remaining returns bytes
76+
error_message = parser.read_remaining().decode('utf-8', errors='replace')
77+
78+
return ErrorPacket(error_code, sql_state, error_message)
7779

7880
def toError(self, exception_factory: 'ExceptionFactory', sql: Optional[str] = None):
79-
return exception_factory.create_exception(self.error_message, self.sql_state, self.error_code, sql)
81+
"""Convert to exception"""
82+
return exception_factory.create_exception(
83+
self.error_message,
84+
self.sql_state,
85+
self.error_code,
86+
sql
87+
)
8088

8189
def __repr__(self) -> str:
90+
"""String representation for debugging"""
8291
return (f"ErrorPacket(error_code={self.error_code}, "
8392
f"sql_state='{self.sql_state}', "
8493
f"error_message='{self.error_message}')")
8594

8695
def __str__(self) -> str:
87-
return f"[{self.error_code}] ({self.sql_state}): {self.error_message}"
96+
"""Human-readable string representation"""
97+
return f"[{self.error_code}] ({self.sql_state}): {self.error_message}"

mariadb/impl/message/server/prepare_stmt_packet.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,18 @@
77
Based on MySQL/MariaDB protocol COM_STMT_PREPARE response structure.
88
"""
99

10-
from typing import TYPE_CHECKING, Optional
10+
import struct
11+
from typing import TYPE_CHECKING, List, Optional
1112

1213
from .column_definition_packet import ColumnDefinitionPacket
13-
from ...client.socket.payload_parser import PayloadParser
14-
# No longer need PacketBuffer import
14+
1515
if TYPE_CHECKING:
1616
from ...client.context import Context
1717

18+
# Pre-compile struct format for faster unpacking
19+
# Format: skip(B), statement_id(I), column_count(H), parameter_count(H), reserved(B), warning_count(H)
20+
_STRUCT_PREPARE_RESPONSE = struct.Struct('<BIHHBH')
21+
1822

1923
class PrepareStmtPacket:
2024
"""
@@ -49,10 +53,10 @@ def __init__(
4953
statement_id: int,
5054
column_count: int,
5155
parameter_count: int,
52-
warning_count: int = 0,
53-
sql: str = None
56+
warning_count: int,
57+
sql: Optional[str]
5458
):
55-
"""Initialize prepare statement packet with statement ID, column count, and parameter count"""
59+
"""Initialize prepare statement packet"""
5660
self.statement_id = statement_id
5761
self.column_count = column_count
5862
self.parameter_count = parameter_count
@@ -63,16 +67,10 @@ def __init__(
6367
self.closed = False
6468

6569
@staticmethod
66-
def decode(data: memoryview, context: Optional['Context'] = None, sql: str = None) -> 'PrepareStmtPacket':
67-
"""Decode COM_STMT_PREPARE response packet from bytearray with optional context"""
68-
parser = PayloadParser(data)
69-
70-
parser.read_byte() # Skip OK marker (0x00)
71-
statement_id = parser.read_uint32()
72-
column_count = parser.read_uint16()
73-
parameter_count = parser.read_uint16()
74-
parser.read_byte() # Skip reserved byte (0x00)
75-
warning_count = parser.read_uint16()
70+
def decode(data: memoryview, context: Optional['Context'] = None, sql: Optional[str] = None) -> 'PrepareStmtPacket':
71+
"""Decode COM_STMT_PREPARE response packet (optimized)"""
72+
# Unpack all fields in one operation using pre-compiled struct
73+
_, statement_id, column_count, parameter_count, _, warning_count = _STRUCT_PREPARE_RESPONSE.unpack_from(data, 0)
7674

7775
# Update context if provided
7876
if context:

0 commit comments

Comments
 (0)