From aee62296cc6b507c04045317c95dee7c34a9c587 Mon Sep 17 00:00:00 2001 From: Lawrin Novitsky Date: Mon, 3 Nov 2025 03:07:25 +0100 Subject: [PATCH 1/4] Constants and classes MariaDB COM_STMT_BULK_EXECUTE command Added StmtBulkExecuteParamsFlags bitflags for packet's bulk flags and MariadbBulkIndicator enum to define possible paramater value indicators. Added ComStmtBulkExecuteRequestBuilder class to build COM_STMT_BULK_EXECUTE packet representation in ComStmtBulkExecuteRequest class implementing packet data serialization. --- src/constants.rs | 32 +++++++- src/packets/mod.rs | 180 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 206 insertions(+), 6 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index 45e54bb..78cc9c1 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -366,7 +366,7 @@ my_bitflags! { UnknownMariadbCapabilityFlags, u32, - /// Mariadb client capability flags + /// MariaDB client capability flags #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] pub struct MariadbCapabilities: u32 { /// Permits feedback during long-running operations @@ -431,6 +431,20 @@ my_bitflags! { } } +my_bitflags! { + StmtBulkExecuteParamsFlags, + #[error("Unknown flags in the raw value of StmtBulkExecuteParamsFlags (raw={0:b})")] + UnknownStmtBulkExecuteParamsFlags, + u16, + + /// MySql stmt execute params flags. + #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] + pub struct StmtBulkExecuteParamsFlags: u16 { + const SEND_UNIT_RESULTS = 64_u16; + const SEND_TYPES_TO_SERVER = 128_u16; + } +} + my_bitflags! { ColumnFlags, #[error("Unknown flags in the raw value of ColumnFlags (raw={0:b})")] @@ -528,6 +542,22 @@ pub enum Command { COM_BINLOG_DUMP_GTID, COM_RESET_CONNECTION, COM_END, + COM_STMT_BULK_EXECUTE = 0xfa_u8, +} + +/// MariaDB bulk execute parameter value indicators +#[allow(non_camel_case_types)] +#[derive(Clone, Copy, Eq, PartialEq, Debug)] +#[repr(u8)] +pub enum MariadbBulkIndicator { + /// No special indicator, normal value + BULK_INDICATOR_NONE = 0x00_u8, + /// NULL value + BULK_INDICATOR_NULL = 0x01_u8, + /// For INSERT/UPDATE, value is default. Not used + BULK_INDICATOR_DEFAULT = 0x02_u8, + /// Value is default for insert, Is ignored for update. Not used. + BULK_INDICATOR_IGNORE = 0x03_u8, } /// Type of state change information (part of MySql's Ok packet). diff --git a/src/packets/mod.rs b/src/packets/mod.rs index 9d94c99..9ae5829 100644 --- a/src/packets/mod.rs +++ b/src/packets/mod.rs @@ -18,12 +18,13 @@ use std::{ }; use crate::collations::CollationId; +use crate::constants::StmtBulkExecuteParamsFlags; use crate::scramble::create_response_for_ed25519; use crate::{ constants::{ CapabilityFlags, ColumnFlags, ColumnType, Command, CursorType, MAX_PAYLOAD_LEN, - MariadbCapabilities, SessionStateType, StatusFlags, StmtExecuteParamFlags, - StmtExecuteParamsFlags, + MariadbBulkIndicator, MariadbCapabilities, SessionStateType, StatusFlags, + StmtExecuteParamFlags, StmtExecuteParamsFlags, }, io::{BufMutExt, ParseBuf}, misc::{ @@ -2762,6 +2763,171 @@ impl MySerialize for ComStmtClose { } } +/// Sends array of parameters to the server for the bulk execution of a prepared statement with +/// COM_STMT_BULK_EXECUTE command. +#[derive(Debug, Clone, PartialEq)] +pub struct ComStmtBulkExecuteRequestBuilder { + pub stmt_id: u32, + pub with_types: bool, + pub paramset: Vec>, + pub payload_len: usize, + pub max_payload_len: usize, /* max_allowed_packet(if known) - 4 */ +} + +impl ComStmtBulkExecuteRequestBuilder { + pub fn new(stmt_id: u32, max_payload: usize) -> Self { + Self { + stmt_id, + with_types: true, + paramset: Vec::new(), + payload_len: 0, + max_payload_len: max_payload, + } + } + + pub fn next(&mut self) -> () { + self.with_types = false; + self.paramset.clear(); + self.payload_len = 0; + } + pub fn add_row(&mut self, params: &[Value]) -> bool { + if self.with_types && self.payload_len == 0 { + self.payload_len = params.len() * 2; + } + let mut data_len = 0; + for p in params { + match p.bin_len() as usize { + 0 => data_len += 1, // NULLs take 1 byte for the indicator + x => data_len += x + 1, // non-NULLs take their length + 1 byte for the indicator + } + } + // It should be really total packet len(+7 + 4)compared against max allowed packet size, not MAX_PAYLOAD_LEN + if 7 + self.payload_len + data_len > self.max_payload_len { + return true; + } + self.paramset.push(params.to_vec()); + self.payload_len += data_len; + false + } + + pub fn has_rows(&self) -> bool { + !self.paramset.is_empty() + } + + pub fn build(&self) -> ComStmtBulkExecuteRequest { + ComStmtBulkExecuteRequest { + com_stmt_bulk_execute: ConstU8::new(), + stmt_id: RawInt::new(self.stmt_id), + bulk_flags: if self.with_types { + Const::new(StmtBulkExecuteParamsFlags::SEND_TYPES_TO_SERVER) + } else { + Const::new(StmtBulkExecuteParamsFlags::empty()) + }, + params: &self.paramset, + } + } +} + +define_header!( + ComStmtBulkExecuteHeader, + COM_STMT_BULK_EXECUTE, + InvalidComStmtBulkExecuteHeader +); + +#[derive(Debug, Clone, PartialEq)] +pub struct ComStmtBulkExecuteRequest<'a> { + com_stmt_bulk_execute: ComStmtBulkExecuteHeader, + stmt_id: RawInt, + bulk_flags: Const, + // max params / bits per byte = 8192 + params: &'a Vec>, +} + +impl<'a> ComStmtBulkExecuteRequest<'a> { + pub fn stmt_id(&self) -> u32 { + self.stmt_id.0 + } + + pub fn bulk_flags(&self) -> StmtBulkExecuteParamsFlags { + self.bulk_flags.0 + } + + pub fn params(&self) -> &[Vec] { + self.params.as_ref() + } +} + +impl MySerialize for ComStmtBulkExecuteRequest<'_> { + fn serialize(&self, buf: &mut Vec) { + self.com_stmt_bulk_execute.serialize(&mut *buf); + self.stmt_id.serialize(&mut *buf); + self.bulk_flags.serialize(&mut *buf); + + if self + .bulk_flags + .0 + .contains(StmtBulkExecuteParamsFlags::SEND_TYPES_TO_SERVER) + { + for param in &self.params[0] { + let (column_type, flags) = match param { + Value::NULL => (ColumnType::MYSQL_TYPE_NULL, StmtExecuteParamFlags::empty()), + Value::Bytes(_) => ( + ColumnType::MYSQL_TYPE_VAR_STRING, + StmtExecuteParamFlags::empty(), + ), + Value::Int(_) => ( + ColumnType::MYSQL_TYPE_LONGLONG, + StmtExecuteParamFlags::empty(), + ), + Value::UInt(_) => ( + ColumnType::MYSQL_TYPE_LONGLONG, + StmtExecuteParamFlags::UNSIGNED, + ), + Value::Float(_) => { + (ColumnType::MYSQL_TYPE_FLOAT, StmtExecuteParamFlags::empty()) + } + Value::Double(_) => ( + ColumnType::MYSQL_TYPE_DOUBLE, + StmtExecuteParamFlags::empty(), + ), + Value::Date(..) => ( + ColumnType::MYSQL_TYPE_DATETIME, + StmtExecuteParamFlags::empty(), + ), + Value::Time(..) => { + (ColumnType::MYSQL_TYPE_TIME, StmtExecuteParamFlags::empty()) + } + }; + buf.put_slice(&[column_type as u8, flags.bits()]); + } + } + + for row in self.params { + for param in row { + match param { + Value::Int(_) + | Value::UInt(_) + | Value::Float(_) + | Value::Double(_) + | Value::Date(..) + | Value::Time(..) => { + buf.put_u8(MariadbBulkIndicator::BULK_INDICATOR_NONE as u8); // not NULL + param.serialize(buf); + } + Value::Bytes(_) => { + buf.put_u8(MariadbBulkIndicator::BULK_INDICATOR_NONE as u8); // not NULL + param.serialize(buf); + } + Value::NULL => { + buf.put_u8(MariadbBulkIndicator::BULK_INDICATOR_NULL as u8); // NULL indicator + } + } + } + } + } +} +// ------------------------------------------------------------------------------ + define_header!( ComRegisterSlaveHeader, COM_REGISTER_SLAVE, @@ -4129,7 +4295,7 @@ mod test { fn should_parse_handshake_packet_with_mariadb_ext_capabilities() { const HSP: &[u8] = b"\x0a5.5.5-11.4.7-MariaDB-log\x00\x0b\x00\ \x00\x00\x64\x76\x48\x40\x49\x2d\x43\x4a\x00\xff\xf7\x08\x02\x00\ - \x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x2a\x34\x64\ + \x00\x00\x00\x00\x00\x00\x00\x00\x00\x14\x00\x00\x00\x2a\x34\x64\ \x7c\x63\x5a\x77\x6b\x34\x5e\x5d\x3a\x00"; let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP)).unwrap(); @@ -4150,6 +4316,7 @@ mod test { assert_eq!( hsp.mariadb_ext_capabilities(), MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA + | MariadbCapabilities::MARIADB_CLIENT_STMT_BULK_OPERATIONS ); let mut output = Vec::new(); hsp.serialize(&mut output); @@ -4169,7 +4336,10 @@ mod test { None, 1_u32.to_be(), ) - .with_mariadb_ext_capabilities(MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA); + .with_mariadb_ext_capabilities( + MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA + | MariadbCapabilities::MARIADB_CLIENT_STMT_BULK_OPERATIONS, + ); let mut actual = Vec::new(); response.serialize(&mut actual); @@ -4179,7 +4349,7 @@ mod test { 0x2d, // charset 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // reserved - 0x10, 0x00, 0x00, 0x00, // mariadb capabilities + 0x14, 0x00, 0x00, 0x00, // mariadb capabilities 0x72, 0x6f, 0x6f, 0x74, 0x00, // username=root 0x00, // blank scramble 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70, From 99331f105ce76084e1d8c2a0316136707e71072a Mon Sep 17 00:00:00 2001 From: Lawrin Novitsky Date: Mon, 3 Nov 2025 09:52:23 +0100 Subject: [PATCH 2/4] Correction of the elided lifetime --- src/packets/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/packets/mod.rs b/src/packets/mod.rs index 9ae5829..45304f4 100644 --- a/src/packets/mod.rs +++ b/src/packets/mod.rs @@ -2814,7 +2814,7 @@ impl ComStmtBulkExecuteRequestBuilder { !self.paramset.is_empty() } - pub fn build(&self) -> ComStmtBulkExecuteRequest { + pub fn build(&self) -> ComStmtBulkExecuteRequest<'_> { ComStmtBulkExecuteRequest { com_stmt_bulk_execute: ConstU8::new(), stmt_id: RawInt::new(self.stmt_id), From 8ad5e1e22d2561b275f0960d482db7c882fec98c Mon Sep 17 00:00:00 2001 From: Lawrin Novitsky Date: Mon, 17 Nov 2025 01:31:12 +0100 Subject: [PATCH 3/4] Some chances to address review questions. Added comments, changed add_row interface, added condition to ensure that packet's params vector is not empty and thus its 0-indexed element exist. Changed ComStmtBulkExecuteRequestBuilder::next() to except parameters vector as it only should be called if more parameter rows exist and add_row has to be called right after it. It calls add_row with given parameters vector now. --- src/packets/mod.rs | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/packets/mod.rs b/src/packets/mod.rs index 45304f4..0c9fe74 100644 --- a/src/packets/mod.rs +++ b/src/packets/mod.rs @@ -2764,7 +2764,8 @@ impl MySerialize for ComStmtClose { } /// Sends array of parameters to the server for the bulk execution of a prepared statement with -/// COM_STMT_BULK_EXECUTE command. +/// COM_STMT_BULK_EXECUTE command. This command is MariaDB only and may not be used for queries w/out +/// parameters and with empty parameter sets. #[derive(Debug, Clone, PartialEq)] pub struct ComStmtBulkExecuteRequestBuilder { pub stmt_id: u32, @@ -2785,24 +2786,33 @@ impl ComStmtBulkExecuteRequestBuilder { } } - pub fn next(&mut self) -> () { + // Resets the builder to start building a new bulk execute request. In particular - without types. + // If it's called - means that there is row to be added that did not fit previous packet. So, it should + // be always followed by add_row(). That is something it can do on its own. + pub fn next(&mut self, params: &Vec) -> () { self.with_types = false; self.paramset.clear(); self.payload_len = 0; + self.add_row(params); } - pub fn add_row(&mut self, params: &[Value]) -> bool { + + // Adds a new row of parameters to the bulk execute request. + // Returns true if adding this row would exceed the max allowed packet size. + pub fn add_row(&mut self, params: &Vec) -> bool { if self.with_types && self.payload_len == 0 { self.payload_len = params.len() * 2; } let mut data_len = 0; for p in params { + // bin_len() includes lenght encoding bytes match p.bin_len() as usize { 0 => data_len += 1, // NULLs take 1 byte for the indicator x => data_len += x + 1, // non-NULLs take their length + 1 byte for the indicator } } - // It should be really total packet len(+7 + 4)compared against max allowed packet size, not MAX_PAYLOAD_LEN - if 7 + self.payload_len + data_len > self.max_payload_len { + // 7 = 1(command id) + 4 (stmt_id) + 2 (flags). If it's 1st row - we take it to return error + // later(when the packet is sent). In this way we can avoid eternal loops of trying to add this row. + if 7 + self.payload_len + data_len > self.max_payload_len && !self.paramset.is_empty() { return true; } self.paramset.push(params.to_vec()); @@ -2867,6 +2877,7 @@ impl MySerialize for ComStmtBulkExecuteRequest<'_> { .bulk_flags .0 .contains(StmtBulkExecuteParamsFlags::SEND_TYPES_TO_SERVER) + && !self.params.is_empty() { for param in &self.params[0] { let (column_type, flags) = match param { From 8f2e84c1ae215740d67b52184119985caab0c8be Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Fri, 5 Dec 2025 12:09:10 +0300 Subject: [PATCH 4/4] Prepare COM_STMT_BULK_EXECUTE for merge * rename `StmtBulkExecuteParamsFlags` -> `StmtBulkExecuteFlags` * add `ComStmtBulkExecuteRequestBuilder::{add_params, build_iter, build_iter_params}` * redefine `ComStmtBulkExecuteRequest` and impl `MyDeserialize` * separate `BuildExecuteRequestError` and `BuildExecuteRequestBuilderError` --- src/constants.rs | 28 ++- src/misc/raw/seq.rs | 44 ++++ src/packets/mod.rs | 535 ++++++++++++++++++++++++++++++++++---------- src/params.rs | 99 ++++++-- 4 files changed, 568 insertions(+), 138 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index 78cc9c1..4d4d6b7 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -432,14 +432,14 @@ my_bitflags! { } my_bitflags! { - StmtBulkExecuteParamsFlags, + StmtBulkExecuteFlags, #[error("Unknown flags in the raw value of StmtBulkExecuteParamsFlags (raw={0:b})")] UnknownStmtBulkExecuteParamsFlags, u16, /// MySql stmt execute params flags. #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] - pub struct StmtBulkExecuteParamsFlags: u16 { + pub struct StmtBulkExecuteFlags: u16 { const SEND_UNIT_RESULTS = 64_u16; const SEND_TYPES_TO_SERVER = 128_u16; } @@ -545,6 +545,10 @@ pub enum Command { COM_STMT_BULK_EXECUTE = 0xfa_u8, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +#[error("Unknown MariaDB bulk execute parameter value indicator {}", _0)] +pub struct UnknownMariadbBulkIndicator(pub u8); + /// MariaDB bulk execute parameter value indicators #[allow(non_camel_case_types)] #[derive(Clone, Copy, Eq, PartialEq, Debug)] @@ -560,6 +564,26 @@ pub enum MariadbBulkIndicator { BULK_INDICATOR_IGNORE = 0x03_u8, } +impl From for u8 { + fn from(x: MariadbBulkIndicator) -> u8 { + x as u8 + } +} + +impl TryFrom for MariadbBulkIndicator { + type Error = UnknownMariadbBulkIndicator; + + fn try_from(value: u8) -> Result { + match value { + 0x00 => Ok(Self::BULK_INDICATOR_NONE), + 0x01 => Ok(Self::BULK_INDICATOR_NULL), + 0x02 => Ok(Self::BULK_INDICATOR_DEFAULT), + 0x03 => Ok(Self::BULK_INDICATOR_IGNORE), + x => Err(UnknownMariadbBulkIndicator(x)), + } + } +} + /// Type of state change information (part of MySql's Ok packet). #[allow(non_camel_case_types)] #[derive(Clone, Copy, Eq, PartialEq, Debug)] diff --git a/src/misc/raw/seq.rs b/src/misc/raw/seq.rs index ccb1825..5b40138 100644 --- a/src/misc/raw/seq.rs +++ b/src/misc/raw/seq.rs @@ -33,6 +33,12 @@ impl Deref for Seq<'_, T, U> { } } +impl Seq<'static, T, U> { + pub fn empty() -> Seq<'static, T, U> { + Self(Cow::Borrowed(&[]), PhantomData) + } +} + impl<'a, T: Clone, U> Seq<'a, T, U> { pub fn new(s: impl Into>) -> Self { Self(s.into(), PhantomData) @@ -111,6 +117,44 @@ pub trait SeqRepr { T: MyDeserialize<'de, Ctx = ()>; } +#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct Unknown; + +/// Unknown number of elements. +impl SeqRepr for Unknown { + const MAX_LEN: usize = usize::MAX; + const SIZE: Option = None; + type Ctx = usize; + + fn serialize(seq: &[T], buf: &mut Vec) { + for x in seq.iter() { + x.serialize(&mut *buf); + } + } + + fn deserialize<'de, T>(len: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result> + where + T: Clone, + T: MyDeserialize<'de, Ctx = ()>, + { + let mut seq = Vec::with_capacity(len); + match T::SIZE { + Some(count) => { + let mut buf: ParseBuf<'_> = buf.parse(count * len)?; + for _ in 0..len { + seq.push(buf.parse(())?); + } + } + None => { + for _ in 0..len { + seq.push(buf.parse(())?); + } + } + } + Ok(Cow::Owned(seq)) + } +} + macro_rules! impl_seq_repr { ($t:ty, $name:ident) => { impl SeqRepr for $name { diff --git a/src/packets/mod.rs b/src/packets/mod.rs index 0c9fe74..9b9f771 100644 --- a/src/packets/mod.rs +++ b/src/packets/mod.rs @@ -11,20 +11,24 @@ use bytes::BufMut; use regex::bytes::Regex; use uuid::Uuid; -use std::str::FromStr; -use std::sync::{Arc, LazyLock}; use std::{ - borrow::Cow, cmp::max, collections::HashMap, convert::TryFrom, fmt, io, marker::PhantomData, + borrow::Cow, + cmp::max, + collections::HashMap, + convert::TryFrom, + fmt, io, + marker::PhantomData, + mem, + str::FromStr, + sync::{Arc, LazyLock}, }; -use crate::collations::CollationId; -use crate::constants::StmtBulkExecuteParamsFlags; -use crate::scramble::create_response_for_ed25519; use crate::{ + collations::CollationId, constants::{ CapabilityFlags, ColumnFlags, ColumnType, Command, CursorType, MAX_PAYLOAD_LEN, MariadbBulkIndicator, MariadbCapabilities, SessionStateType, StatusFlags, - StmtExecuteParamFlags, StmtExecuteParamsFlags, + StmtBulkExecuteFlags, StmtExecuteParamFlags, StmtExecuteParamsFlags, }, io::{BufMutExt, ParseBuf}, misc::{ @@ -36,12 +40,14 @@ use crate::{ U32Bytes, }, int::{ConstU8, ConstU32, LeU16, LeU24, LeU32, LeU32LowerHalf, LeU32UpperHalf, LeU64}, - seq::Seq, + seq::{Seq, Unknown}, }, unexpected_buf_eof, }, + params::{Params, ParamsError}, proto::{MyDeserialize, MySerialize}, - value::{ClientSide, SerializationSide, Value}, + scramble::create_response_for_ed25519, + value::{BinValue, ClientSide, SerializationSide, Value, ValueDeserializer}, }; use self::session_state_change::SessionStateChange; @@ -2767,74 +2773,199 @@ impl MySerialize for ComStmtClose { /// COM_STMT_BULK_EXECUTE command. This command is MariaDB only and may not be used for queries w/out /// parameters and with empty parameter sets. #[derive(Debug, Clone, PartialEq)] -pub struct ComStmtBulkExecuteRequestBuilder { - pub stmt_id: u32, - pub with_types: bool, - pub paramset: Vec>, - pub payload_len: usize, - pub max_payload_len: usize, /* max_allowed_packet(if known) - 4 */ +pub struct ComStmtBulkExecuteRequestBuilder<'a> { + stmt_id: u32, + with_types: bool, + params_set: Vec>, + payload_len: usize, + arity: Option, + named_params: Option<&'a [Vec]>, + max_payload_len: usize, /* max_allowed_packet(if known) - 4 */ } -impl ComStmtBulkExecuteRequestBuilder { - pub fn new(stmt_id: u32, max_payload: usize) -> Self { - Self { +impl ComStmtBulkExecuteRequestBuilder<'_> { + pub fn new( + stmt_id: u32, + max_allowed_packet: usize, + ) -> ComStmtBulkExecuteRequestBuilder<'static> { + ComStmtBulkExecuteRequestBuilder { stmt_id, with_types: true, - paramset: Vec::new(), + params_set: Vec::new(), payload_len: 0, - max_payload_len: max_payload, + arity: None, + named_params: None, + max_payload_len: max_allowed_packet - 4, } } - // Resets the builder to start building a new bulk execute request. In particular - without types. - // If it's called - means that there is row to be added that did not fit previous packet. So, it should - // be always followed by add_row(). That is something it can do on its own. - pub fn next(&mut self, params: &Vec) -> () { - self.with_types = false; - self.paramset.clear(); - self.payload_len = 0; - self.add_row(params); + /// Use named parameters in the same order they were given in the SQL statement. + pub fn with_named_params<'a>( + self, + named_params: Option<&'a [Vec]>, + ) -> ComStmtBulkExecuteRequestBuilder<'a> { + ComStmtBulkExecuteRequestBuilder { + stmt_id: self.stmt_id, + with_types: self.with_types, + params_set: self.params_set, + payload_len: self.payload_len, + arity: self.arity, + named_params, + max_payload_len: self.max_payload_len, + } + } + + /// See [`ComStmtBulkExecuteRequestBuilder::add_row`]. + pub fn add_params( + &mut self, + params: impl Into, + ) -> Result>, BulkExecuteRequestBuilderError> { + self._add_params(params.into()) + } + + fn _add_params( + &mut self, + params: Params, + ) -> Result>, BulkExecuteRequestBuilderError> { + self.add_row(params.into_values(self.named_params)?) } - // Adds a new row of parameters to the bulk execute request. - // Returns true if adding this row would exceed the max allowed packet size. - pub fn add_row(&mut self, params: &Vec) -> bool { + /// Adds a new row of parameters to the bulk execute request. + /// + /// Returns row back if adding it would exceed the max allowed packet size — practically + /// this means that you should call [`ComStmtBulkExecuteRequestBuilder::build`] to consume + /// all the rows added so far and then continue adding more rows to the next bulk request + /// starting from this returned row. + /// + /// # Error + /// + /// This function will emit an error if params' arity differs from previous + /// rows added to this builder or if row is larger than the max payload length + /// (max allowed packet - 4) + pub fn add_row( + &mut self, + params: Vec, + ) -> Result>, BulkExecuteRequestBuilderError> { + let arity = self.arity.get_or_insert(params.len()); + if params.len() != *arity { + return Err(BulkExecuteRequestError::MixedArity.into()); + } + if self.with_types && self.payload_len == 0 { self.payload_len = params.len() * 2; } + let mut data_len = 0; - for p in params { - // bin_len() includes lenght encoding bytes + + for p in ¶ms { + // bin_len() includes length encoding bytes match p.bin_len() as usize { 0 => data_len += 1, // NULLs take 1 byte for the indicator x => data_len += x + 1, // non-NULLs take their length + 1 byte for the indicator } } + // 7 = 1(command id) + 4 (stmt_id) + 2 (flags). If it's 1st row - we take it to return error // later(when the packet is sent). In this way we can avoid eternal loops of trying to add this row. - if 7 + self.payload_len + data_len > self.max_payload_len && !self.paramset.is_empty() { - return true; + if 7 + self.payload_len + data_len > self.max_payload_len { + if self.params_set.is_empty() { + return Err(BulkExecuteRequestError::RowTooLarge(params).into()); + } + return Ok(Some(params)); } - self.paramset.push(params.to_vec()); + + self.params_set.push(params); self.payload_len += data_len; - false + + Ok(None) } pub fn has_rows(&self) -> bool { - !self.paramset.is_empty() + !self.params_set.is_empty() } - pub fn build(&self) -> ComStmtBulkExecuteRequest<'_> { - ComStmtBulkExecuteRequest { - com_stmt_bulk_execute: ConstU8::new(), - stmt_id: RawInt::new(self.stmt_id), - bulk_flags: if self.with_types { - Const::new(StmtBulkExecuteParamsFlags::SEND_TYPES_TO_SERVER) - } else { - Const::new(StmtBulkExecuteParamsFlags::empty()) - }, - params: &self.paramset, + /// Builds `COM_STMT_BULK_EXECUTE` consuming rows added so far. + /// + /// After the call you can continue using [`ComStmtBulkExecuteRequestBuilder::add_row`] + /// to build next bulk request for this statement. + /// + /// # Error + /// + /// This will error if no rows was added to the builder. + pub fn build(&mut self) -> Result, BulkExecuteRequestError> { + let bulk_flags = if self.with_types { + StmtBulkExecuteFlags::SEND_TYPES_TO_SERVER + } else { + StmtBulkExecuteFlags::empty() + }; + self.with_types = false; + self.payload_len = 0; + ComStmtBulkExecuteRequest::new(self.stmt_id, bulk_flags, mem::take(&mut self.params_set)) + } + + /// See [`ComStmtBulkExecuteRequestBuilder::build_iter`]. + pub fn build_params_iter( + &mut self, + input: impl IntoIterator>, + ) -> impl Iterator, BulkExecuteRequestBuilderError>> + { + let mut done = false; + + macro_rules! transpose { + ($e:expr) => { + match $e { + Ok(x) => x, + Err(e) => { + done = true; + return Some(Err(e.into())); + } + } + }; } + + let mut input = input.into_iter().map(Into::into); + let mut stack = None; + std::iter::from_fn(move || { + if done { + return None; + } + + let mut params_iter = stack + .take() + .map(Params::Positional) + .into_iter() + .chain(input.by_ref()); + + while let Some(params) = params_iter.next() { + if let Some(params) = transpose!(self.add_params(params)) { + stack = Some(params); + return Some(Ok(transpose!(self.build()))); + } + } + + if self.has_rows() { + return Some(Ok(transpose!(self.build()))); + } + + done = true; + None + }) + } + + /// It's a convenient wrapper over [`ComStmtBulkExecuteRequestBuilder::add_row`] + /// and [`ComStmtBulkExecuteRequestBuilder::build`] that converts a rows iterator + /// to a bulk requests iterator. + /// + /// # Error + /// + /// This won't error if the input iterator contains no rows — it'll just emit no bulk requests + /// but the iterator will emit an error if it encounters a row with different arity. + pub fn build_iter( + &mut self, + input: impl IntoIterator>, + ) -> impl Iterator, BulkExecuteRequestBuilderError>> + { + self.build_params_iter(input.into_iter().map(|x| Params::Positional(x))) } } @@ -2844,100 +2975,262 @@ define_header!( InvalidComStmtBulkExecuteHeader ); +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StmtBulkExecuteParamType { + r#type: Const, + flags: Const, +} + +impl StmtBulkExecuteParamType { + pub fn new(r#type: ColumnType, flags: StmtExecuteParamFlags) -> Self { + Self { + r#type: Const::new(r#type), + flags: Const::new(flags), + } + } + + pub fn from_value(value: &Value) -> Self { + let (r#type, flags) = match value { + Value::NULL => (ColumnType::MYSQL_TYPE_NULL, StmtExecuteParamFlags::empty()), + Value::Bytes(_) => ( + ColumnType::MYSQL_TYPE_VAR_STRING, + StmtExecuteParamFlags::empty(), + ), + Value::Int(_) => ( + ColumnType::MYSQL_TYPE_LONGLONG, + StmtExecuteParamFlags::empty(), + ), + Value::UInt(_) => ( + ColumnType::MYSQL_TYPE_LONGLONG, + StmtExecuteParamFlags::UNSIGNED, + ), + Value::Float(_) => (ColumnType::MYSQL_TYPE_FLOAT, StmtExecuteParamFlags::empty()), + Value::Double(_) => ( + ColumnType::MYSQL_TYPE_DOUBLE, + StmtExecuteParamFlags::empty(), + ), + Value::Date(..) => ( + ColumnType::MYSQL_TYPE_DATETIME, + StmtExecuteParamFlags::empty(), + ), + Value::Time(..) => (ColumnType::MYSQL_TYPE_TIME, StmtExecuteParamFlags::empty()), + }; + + Self { + r#type: Const::new(r#type), + flags: Const::new(flags), + } + } +} + +impl<'de> MyDeserialize<'de> for StmtBulkExecuteParamType { + const SIZE: Option = Some(3); + type Ctx = (); + + fn deserialize(_ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result { + Ok(Self { + r#type: buf.parse(())?, + flags: buf.parse(())?, + }) + } +} + +impl MySerialize for StmtBulkExecuteParamType { + fn serialize(&self, buf: &mut Vec) { + self.r#type.serialize(&mut *buf); + self.flags.serialize(&mut *buf); + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct StmtBulkExecuteParamValues { + values: Vec, +} + +impl StmtBulkExecuteParamValues { + pub fn new(values: impl IntoIterator) -> Self { + let values = values + .into_iter() + .map(StmtBulkExecuteParamValue::new) + .collect(); + Self { values } + } +} + +impl<'de> MyDeserialize<'de> for StmtBulkExecuteParamValues { + const SIZE: Option = None; + type Ctx = Vec<(ColumnType, ColumnFlags)>; + + fn deserialize(params: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result { + let mut values = Vec::with_capacity(params.len()); + for param in params { + values.push(StmtBulkExecuteParamValue::deserialize(param, &mut *buf)?); + } + Ok(Self { values }) + } +} + +impl MySerialize for StmtBulkExecuteParamValues { + fn serialize(&self, buf: &mut Vec) { + for value in &self.values { + value.serialize(&mut *buf); + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct StmtBulkExecuteParamValue { + indicator: Const, + value: Value, +} + +impl StmtBulkExecuteParamValue { + pub fn new(value: Value) -> Self { + let indicator = if matches!(value, Value::NULL) { + MariadbBulkIndicator::BULK_INDICATOR_NULL + } else { + MariadbBulkIndicator::BULK_INDICATOR_NONE + }; + Self { + indicator: Const::new(indicator), + value, + } + } +} + +impl<'de> MyDeserialize<'de> for StmtBulkExecuteParamValue { + const SIZE: Option = None; + type Ctx = (ColumnType, ColumnFlags); + + fn deserialize(ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result { + let indicator = buf.parse::>(())?; + let value = if *indicator == MariadbBulkIndicator::BULK_INDICATOR_NULL { + Value::NULL + } else { + ValueDeserializer::::deserialize(ctx, buf)?.0 + }; + + Ok(Self { indicator, value }) + } +} + +impl MySerialize for StmtBulkExecuteParamValue { + fn serialize(&self, buf: &mut Vec) { + self.indicator.serialize(&mut *buf); + self.value.serialize(&mut *buf); + } +} + +#[derive(Debug, Clone, PartialEq, thiserror::Error)] +pub enum BulkExecuteRequestError { + #[error("No parameters values given for the bulk operation")] + NoParams, + #[error("Mixed statement arity")] + MixedArity, + #[error("Got row bigger than max payload length")] + RowTooLarge(Vec), +} + +#[derive(Debug, Clone, PartialEq, thiserror::Error)] +pub enum BulkExecuteRequestBuilderError { + #[error(transparent)] + Request(#[from] BulkExecuteRequestError), + #[error(transparent)] + Params(#[from] ParamsError), +} + #[derive(Debug, Clone, PartialEq)] pub struct ComStmtBulkExecuteRequest<'a> { - com_stmt_bulk_execute: ComStmtBulkExecuteHeader, + header: ComStmtBulkExecuteHeader, stmt_id: RawInt, - bulk_flags: Const, - // max params / bits per byte = 8192 - params: &'a Vec>, + bulk_flags: Const, + types: Seq<'a, StmtBulkExecuteParamType, Unknown>, + values: StmtBulkExecuteParamValues, } impl<'a> ComStmtBulkExecuteRequest<'a> { - pub fn stmt_id(&self) -> u32 { - self.stmt_id.0 - } + pub fn new( + stmt_id: u32, + bulk_flags: StmtBulkExecuteFlags, + values: Vec>, + ) -> Result { + let first = values.first().ok_or(BulkExecuteRequestError::NoParams)?; + let arity = first.len(); + + let types = if bulk_flags.contains(StmtBulkExecuteFlags::SEND_TYPES_TO_SERVER) { + Seq::new( + first + .iter() + .map(StmtBulkExecuteParamType::from_value) + .collect::>(), + ) + } else { + Seq::empty() + }; - pub fn bulk_flags(&self) -> StmtBulkExecuteParamsFlags { - self.bulk_flags.0 + for values in &values { + if values.len() != arity { + return Err(BulkExecuteRequestError::MixedArity); + } + } + + Ok(Self { + header: ConstU8::new(), + stmt_id: RawInt::new(stmt_id), + bulk_flags: Const::new(bulk_flags), + types, + values: StmtBulkExecuteParamValues::new(values.into_iter().flatten()), + }) } +} - pub fn params(&self) -> &[Vec] { - self.params.as_ref() +impl<'de> MyDeserialize<'de> for ComStmtBulkExecuteRequest<'de> { + const SIZE: Option = None; + type Ctx = Vec<(ColumnType, ColumnFlags)>; + + fn deserialize(mut params: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result { + let header = buf.parse(())?; + let stmt_id = buf.parse(())?; + let bulk_flags: Const = buf.parse(())?; + let types = if bulk_flags.contains(StmtBulkExecuteFlags::SEND_TYPES_TO_SERVER) { + let types: Seq<'_, StmtBulkExecuteParamType, Unknown> = buf.parse(params.len())?; + for (s_type, r#type) in types.iter().zip(params.iter_mut()) { + r#type.0 = *s_type.r#type; + r#type.1.set( + ColumnFlags::UNSIGNED_FLAG, + s_type.flags.contains(StmtExecuteParamFlags::UNSIGNED), + ); + } + types + } else { + Seq::empty() + }; + let values = buf.parse(params)?; + + Ok(Self { + header, + stmt_id, + bulk_flags, + types, + values, + }) } } -impl MySerialize for ComStmtBulkExecuteRequest<'_> { +impl<'a> MySerialize for ComStmtBulkExecuteRequest<'a> { fn serialize(&self, buf: &mut Vec) { - self.com_stmt_bulk_execute.serialize(&mut *buf); + self.header.serialize(&mut *buf); self.stmt_id.serialize(&mut *buf); self.bulk_flags.serialize(&mut *buf); - if self .bulk_flags - .0 - .contains(StmtBulkExecuteParamsFlags::SEND_TYPES_TO_SERVER) - && !self.params.is_empty() + .contains(StmtBulkExecuteFlags::SEND_TYPES_TO_SERVER) { - for param in &self.params[0] { - let (column_type, flags) = match param { - Value::NULL => (ColumnType::MYSQL_TYPE_NULL, StmtExecuteParamFlags::empty()), - Value::Bytes(_) => ( - ColumnType::MYSQL_TYPE_VAR_STRING, - StmtExecuteParamFlags::empty(), - ), - Value::Int(_) => ( - ColumnType::MYSQL_TYPE_LONGLONG, - StmtExecuteParamFlags::empty(), - ), - Value::UInt(_) => ( - ColumnType::MYSQL_TYPE_LONGLONG, - StmtExecuteParamFlags::UNSIGNED, - ), - Value::Float(_) => { - (ColumnType::MYSQL_TYPE_FLOAT, StmtExecuteParamFlags::empty()) - } - Value::Double(_) => ( - ColumnType::MYSQL_TYPE_DOUBLE, - StmtExecuteParamFlags::empty(), - ), - Value::Date(..) => ( - ColumnType::MYSQL_TYPE_DATETIME, - StmtExecuteParamFlags::empty(), - ), - Value::Time(..) => { - (ColumnType::MYSQL_TYPE_TIME, StmtExecuteParamFlags::empty()) - } - }; - buf.put_slice(&[column_type as u8, flags.bits()]); - } - } - - for row in self.params { - for param in row { - match param { - Value::Int(_) - | Value::UInt(_) - | Value::Float(_) - | Value::Double(_) - | Value::Date(..) - | Value::Time(..) => { - buf.put_u8(MariadbBulkIndicator::BULK_INDICATOR_NONE as u8); // not NULL - param.serialize(buf); - } - Value::Bytes(_) => { - buf.put_u8(MariadbBulkIndicator::BULK_INDICATOR_NONE as u8); // not NULL - param.serialize(buf); - } - Value::NULL => { - buf.put_u8(MariadbBulkIndicator::BULK_INDICATOR_NULL as u8); // NULL indicator - } - } - } + self.types.serialize(&mut *buf); } + self.values.serialize(&mut *buf); } } -// ------------------------------------------------------------------------------ define_header!( ComRegisterSlaveHeader, diff --git a/src/params.rs b/src/params.rs index 24cbf2c..70773e1 100644 --- a/src/params.rs +++ b/src/params.rs @@ -11,30 +11,30 @@ use std::{ HashMap, hash_map::{Entry, Entry::Occupied}, }, - error::Error, fmt, }; use crate::value::{Value, convert::ToValue}; -/// `FromValue` conversion error. -#[derive(Debug, Eq, PartialEq, Clone)] +/// Missing named parameter for a statement +#[derive(Debug, Eq, PartialEq, Clone, thiserror::Error)] +#[error("Missing named parameter `{}` for statement", String::from_utf8_lossy(&_0))] pub struct MissingNamedParameterError(pub Vec); -impl fmt::Display for MissingNamedParameterError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "Missing named parameter `{}` for statement", - String::from_utf8_lossy(&self.0) - ) - } +#[derive(Debug, PartialEq, Eq, Clone, Copy, thiserror::Error, Hash)] +pub enum ParamsConfusionError { + #[error("Named params given where positional params are expected")] + NamedParamsForPositionalQuery, + #[error("Positional params given where named params are expected")] + PositionalParamsForNamedQuery, } -impl Error for MissingNamedParameterError { - fn description(&self) -> &str { - "Missing named parameter for statement" - } +#[derive(Debug, PartialEq, Eq, Clone, thiserror::Error)] +pub enum ParamsError { + #[error(transparent)] + Missing(#[from] MissingNamedParameterError), + #[error(transparent)] + Confusion(#[from] ParamsConfusionError), } /// Representations of parameters of a prepared statement. @@ -62,8 +62,77 @@ impl fmt::Debug for Params { } impl Params { + /// Converts [`Params`] into a vector of values given the named parameters. + /// + /// `named_params` (if any) must follow the order they were given in the corresponding SQL + /// statement. + pub fn into_values(self, named_params: Option<&[Vec]>) -> Result, ParamsError> { + match self { + Params::Empty => match named_params { + Some(params) => { + if let Some(first) = params.first() { + Err(MissingNamedParameterError(first.clone()).into()) + } else { + Ok(vec![]) + } + } + None => Ok(vec![]), + }, + Params::Positional(values) => match named_params { + Some(named_params) if !named_params.is_empty() => { + Err(ParamsConfusionError::PositionalParamsForNamedQuery.into()) + } + _ => Ok(values), + }, + Params::Named(map) => match named_params { + Some(named_params) if !named_params.is_empty() => { + let mut values = vec![Value::NULL; named_params.len()]; + let mut indexes = Vec::with_capacity(named_params.len()); + for (name, value) in map { + let mut first = None; + for (i, _) in named_params.iter().enumerate().filter(|(_, x)| **x == name) { + indexes.push(i); + if first.is_none() { + first = Some(i); + } else { + values[i] = value.clone(); + } + } + if let Some(first) = first { + values[first] = value; + } + } + if indexes.len() != named_params.len() { + indexes.sort_unstable(); + match indexes.into_iter().enumerate().find(|x| x.0 != x.1) { + Some((missing, _)) => { + Err(MissingNamedParameterError(named_params[missing].clone()) + .into()) + } + None => { + match named_params.last() { + Some(last) => { + Err(MissingNamedParameterError(last.clone()).into()) + } + None => { + // unreachable + Ok(values) + } + } + } + } + } else { + Ok(values) + } + } + _ => Err(ParamsConfusionError::NamedParamsForPositionalQuery.into()), + }, + } + } + /// Will convert named parameters into positional assuming order passed in `named_params` /// attribute. + #[deprecated = "use `into_values` instead"] pub fn into_positional( self, named_params: &[Vec],