diff --git a/src/constants.rs b/src/constants.rs index 45e54bb..4d4d6b7 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -366,7 +366,7 @@ my_bitflags! { UnknownMariadbCapabilityFlags, u32, - /// Mariadb client capability flags + /// MariaDB client capability flags #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] pub struct MariadbCapabilities: u32 { /// Permits feedback during long-running operations @@ -431,6 +431,20 @@ my_bitflags! { } } +my_bitflags! { + StmtBulkExecuteFlags, + #[error("Unknown flags in the raw value of StmtBulkExecuteParamsFlags (raw={0:b})")] + UnknownStmtBulkExecuteParamsFlags, + u16, + + /// MySql stmt execute params flags. + #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] + pub struct StmtBulkExecuteFlags: u16 { + const SEND_UNIT_RESULTS = 64_u16; + const SEND_TYPES_TO_SERVER = 128_u16; + } +} + my_bitflags! { ColumnFlags, #[error("Unknown flags in the raw value of ColumnFlags (raw={0:b})")] @@ -528,6 +542,46 @@ pub enum Command { COM_BINLOG_DUMP_GTID, COM_RESET_CONNECTION, COM_END, + COM_STMT_BULK_EXECUTE = 0xfa_u8, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +#[error("Unknown MariaDB bulk execute parameter value indicator {}", _0)] +pub struct UnknownMariadbBulkIndicator(pub u8); + +/// MariaDB bulk execute parameter value indicators +#[allow(non_camel_case_types)] +#[derive(Clone, Copy, Eq, PartialEq, Debug)] +#[repr(u8)] +pub enum MariadbBulkIndicator { + /// No special indicator, normal value + BULK_INDICATOR_NONE = 0x00_u8, + /// NULL value + BULK_INDICATOR_NULL = 0x01_u8, + /// For INSERT/UPDATE, value is default. Not used + BULK_INDICATOR_DEFAULT = 0x02_u8, + /// Value is default for insert, Is ignored for update. Not used. + BULK_INDICATOR_IGNORE = 0x03_u8, +} + +impl From for u8 { + fn from(x: MariadbBulkIndicator) -> u8 { + x as u8 + } +} + +impl TryFrom for MariadbBulkIndicator { + type Error = UnknownMariadbBulkIndicator; + + fn try_from(value: u8) -> Result { + match value { + 0x00 => Ok(Self::BULK_INDICATOR_NONE), + 0x01 => Ok(Self::BULK_INDICATOR_NULL), + 0x02 => Ok(Self::BULK_INDICATOR_DEFAULT), + 0x03 => Ok(Self::BULK_INDICATOR_IGNORE), + x => Err(UnknownMariadbBulkIndicator(x)), + } + } } /// Type of state change information (part of MySql's Ok packet). diff --git a/src/misc/raw/seq.rs b/src/misc/raw/seq.rs index ccb1825..5b40138 100644 --- a/src/misc/raw/seq.rs +++ b/src/misc/raw/seq.rs @@ -33,6 +33,12 @@ impl Deref for Seq<'_, T, U> { } } +impl Seq<'static, T, U> { + pub fn empty() -> Seq<'static, T, U> { + Self(Cow::Borrowed(&[]), PhantomData) + } +} + impl<'a, T: Clone, U> Seq<'a, T, U> { pub fn new(s: impl Into>) -> Self { Self(s.into(), PhantomData) @@ -111,6 +117,44 @@ pub trait SeqRepr { T: MyDeserialize<'de, Ctx = ()>; } +#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct Unknown; + +/// Unknown number of elements. +impl SeqRepr for Unknown { + const MAX_LEN: usize = usize::MAX; + const SIZE: Option = None; + type Ctx = usize; + + fn serialize(seq: &[T], buf: &mut Vec) { + for x in seq.iter() { + x.serialize(&mut *buf); + } + } + + fn deserialize<'de, T>(len: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result> + where + T: Clone, + T: MyDeserialize<'de, Ctx = ()>, + { + let mut seq = Vec::with_capacity(len); + match T::SIZE { + Some(count) => { + let mut buf: ParseBuf<'_> = buf.parse(count * len)?; + for _ in 0..len { + seq.push(buf.parse(())?); + } + } + None => { + for _ in 0..len { + seq.push(buf.parse(())?); + } + } + } + Ok(Cow::Owned(seq)) + } +} + macro_rules! impl_seq_repr { ($t:ty, $name:ident) => { impl SeqRepr for $name { diff --git a/src/packets/mod.rs b/src/packets/mod.rs index 9d94c99..9b9f771 100644 --- a/src/packets/mod.rs +++ b/src/packets/mod.rs @@ -11,19 +11,24 @@ use bytes::BufMut; use regex::bytes::Regex; use uuid::Uuid; -use std::str::FromStr; -use std::sync::{Arc, LazyLock}; use std::{ - borrow::Cow, cmp::max, collections::HashMap, convert::TryFrom, fmt, io, marker::PhantomData, + borrow::Cow, + cmp::max, + collections::HashMap, + convert::TryFrom, + fmt, io, + marker::PhantomData, + mem, + str::FromStr, + sync::{Arc, LazyLock}, }; -use crate::collations::CollationId; -use crate::scramble::create_response_for_ed25519; use crate::{ + collations::CollationId, constants::{ CapabilityFlags, ColumnFlags, ColumnType, Command, CursorType, MAX_PAYLOAD_LEN, - MariadbCapabilities, SessionStateType, StatusFlags, StmtExecuteParamFlags, - StmtExecuteParamsFlags, + MariadbBulkIndicator, MariadbCapabilities, SessionStateType, StatusFlags, + StmtBulkExecuteFlags, StmtExecuteParamFlags, StmtExecuteParamsFlags, }, io::{BufMutExt, ParseBuf}, misc::{ @@ -35,12 +40,14 @@ use crate::{ U32Bytes, }, int::{ConstU8, ConstU32, LeU16, LeU24, LeU32, LeU32LowerHalf, LeU32UpperHalf, LeU64}, - seq::Seq, + seq::{Seq, Unknown}, }, unexpected_buf_eof, }, + params::{Params, ParamsError}, proto::{MyDeserialize, MySerialize}, - value::{ClientSide, SerializationSide, Value}, + scramble::create_response_for_ed25519, + value::{BinValue, ClientSide, SerializationSide, Value, ValueDeserializer}, }; use self::session_state_change::SessionStateChange; @@ -2762,6 +2769,469 @@ impl MySerialize for ComStmtClose { } } +/// Sends array of parameters to the server for the bulk execution of a prepared statement with +/// COM_STMT_BULK_EXECUTE command. This command is MariaDB only and may not be used for queries w/out +/// parameters and with empty parameter sets. +#[derive(Debug, Clone, PartialEq)] +pub struct ComStmtBulkExecuteRequestBuilder<'a> { + stmt_id: u32, + with_types: bool, + params_set: Vec>, + payload_len: usize, + arity: Option, + named_params: Option<&'a [Vec]>, + max_payload_len: usize, /* max_allowed_packet(if known) - 4 */ +} + +impl ComStmtBulkExecuteRequestBuilder<'_> { + pub fn new( + stmt_id: u32, + max_allowed_packet: usize, + ) -> ComStmtBulkExecuteRequestBuilder<'static> { + ComStmtBulkExecuteRequestBuilder { + stmt_id, + with_types: true, + params_set: Vec::new(), + payload_len: 0, + arity: None, + named_params: None, + max_payload_len: max_allowed_packet - 4, + } + } + + /// Use named parameters in the same order they were given in the SQL statement. + pub fn with_named_params<'a>( + self, + named_params: Option<&'a [Vec]>, + ) -> ComStmtBulkExecuteRequestBuilder<'a> { + ComStmtBulkExecuteRequestBuilder { + stmt_id: self.stmt_id, + with_types: self.with_types, + params_set: self.params_set, + payload_len: self.payload_len, + arity: self.arity, + named_params, + max_payload_len: self.max_payload_len, + } + } + + /// See [`ComStmtBulkExecuteRequestBuilder::add_row`]. + pub fn add_params( + &mut self, + params: impl Into, + ) -> Result>, BulkExecuteRequestBuilderError> { + self._add_params(params.into()) + } + + fn _add_params( + &mut self, + params: Params, + ) -> Result>, BulkExecuteRequestBuilderError> { + self.add_row(params.into_values(self.named_params)?) + } + + /// Adds a new row of parameters to the bulk execute request. + /// + /// Returns row back if adding it would exceed the max allowed packet size — practically + /// this means that you should call [`ComStmtBulkExecuteRequestBuilder::build`] to consume + /// all the rows added so far and then continue adding more rows to the next bulk request + /// starting from this returned row. + /// + /// # Error + /// + /// This function will emit an error if params' arity differs from previous + /// rows added to this builder or if row is larger than the max payload length + /// (max allowed packet - 4) + pub fn add_row( + &mut self, + params: Vec, + ) -> Result>, BulkExecuteRequestBuilderError> { + let arity = self.arity.get_or_insert(params.len()); + if params.len() != *arity { + return Err(BulkExecuteRequestError::MixedArity.into()); + } + + if self.with_types && self.payload_len == 0 { + self.payload_len = params.len() * 2; + } + + let mut data_len = 0; + + for p in ¶ms { + // bin_len() includes length encoding bytes + match p.bin_len() as usize { + 0 => data_len += 1, // NULLs take 1 byte for the indicator + x => data_len += x + 1, // non-NULLs take their length + 1 byte for the indicator + } + } + + // 7 = 1(command id) + 4 (stmt_id) + 2 (flags). If it's 1st row - we take it to return error + // later(when the packet is sent). In this way we can avoid eternal loops of trying to add this row. + if 7 + self.payload_len + data_len > self.max_payload_len { + if self.params_set.is_empty() { + return Err(BulkExecuteRequestError::RowTooLarge(params).into()); + } + return Ok(Some(params)); + } + + self.params_set.push(params); + self.payload_len += data_len; + + Ok(None) + } + + pub fn has_rows(&self) -> bool { + !self.params_set.is_empty() + } + + /// Builds `COM_STMT_BULK_EXECUTE` consuming rows added so far. + /// + /// After the call you can continue using [`ComStmtBulkExecuteRequestBuilder::add_row`] + /// to build next bulk request for this statement. + /// + /// # Error + /// + /// This will error if no rows was added to the builder. + pub fn build(&mut self) -> Result, BulkExecuteRequestError> { + let bulk_flags = if self.with_types { + StmtBulkExecuteFlags::SEND_TYPES_TO_SERVER + } else { + StmtBulkExecuteFlags::empty() + }; + self.with_types = false; + self.payload_len = 0; + ComStmtBulkExecuteRequest::new(self.stmt_id, bulk_flags, mem::take(&mut self.params_set)) + } + + /// See [`ComStmtBulkExecuteRequestBuilder::build_iter`]. + pub fn build_params_iter( + &mut self, + input: impl IntoIterator>, + ) -> impl Iterator, BulkExecuteRequestBuilderError>> + { + let mut done = false; + + macro_rules! transpose { + ($e:expr) => { + match $e { + Ok(x) => x, + Err(e) => { + done = true; + return Some(Err(e.into())); + } + } + }; + } + + let mut input = input.into_iter().map(Into::into); + let mut stack = None; + std::iter::from_fn(move || { + if done { + return None; + } + + let mut params_iter = stack + .take() + .map(Params::Positional) + .into_iter() + .chain(input.by_ref()); + + while let Some(params) = params_iter.next() { + if let Some(params) = transpose!(self.add_params(params)) { + stack = Some(params); + return Some(Ok(transpose!(self.build()))); + } + } + + if self.has_rows() { + return Some(Ok(transpose!(self.build()))); + } + + done = true; + None + }) + } + + /// It's a convenient wrapper over [`ComStmtBulkExecuteRequestBuilder::add_row`] + /// and [`ComStmtBulkExecuteRequestBuilder::build`] that converts a rows iterator + /// to a bulk requests iterator. + /// + /// # Error + /// + /// This won't error if the input iterator contains no rows — it'll just emit no bulk requests + /// but the iterator will emit an error if it encounters a row with different arity. + pub fn build_iter( + &mut self, + input: impl IntoIterator>, + ) -> impl Iterator, BulkExecuteRequestBuilderError>> + { + self.build_params_iter(input.into_iter().map(|x| Params::Positional(x))) + } +} + +define_header!( + ComStmtBulkExecuteHeader, + COM_STMT_BULK_EXECUTE, + InvalidComStmtBulkExecuteHeader +); + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StmtBulkExecuteParamType { + r#type: Const, + flags: Const, +} + +impl StmtBulkExecuteParamType { + pub fn new(r#type: ColumnType, flags: StmtExecuteParamFlags) -> Self { + Self { + r#type: Const::new(r#type), + flags: Const::new(flags), + } + } + + pub fn from_value(value: &Value) -> Self { + let (r#type, flags) = match value { + Value::NULL => (ColumnType::MYSQL_TYPE_NULL, StmtExecuteParamFlags::empty()), + Value::Bytes(_) => ( + ColumnType::MYSQL_TYPE_VAR_STRING, + StmtExecuteParamFlags::empty(), + ), + Value::Int(_) => ( + ColumnType::MYSQL_TYPE_LONGLONG, + StmtExecuteParamFlags::empty(), + ), + Value::UInt(_) => ( + ColumnType::MYSQL_TYPE_LONGLONG, + StmtExecuteParamFlags::UNSIGNED, + ), + Value::Float(_) => (ColumnType::MYSQL_TYPE_FLOAT, StmtExecuteParamFlags::empty()), + Value::Double(_) => ( + ColumnType::MYSQL_TYPE_DOUBLE, + StmtExecuteParamFlags::empty(), + ), + Value::Date(..) => ( + ColumnType::MYSQL_TYPE_DATETIME, + StmtExecuteParamFlags::empty(), + ), + Value::Time(..) => (ColumnType::MYSQL_TYPE_TIME, StmtExecuteParamFlags::empty()), + }; + + Self { + r#type: Const::new(r#type), + flags: Const::new(flags), + } + } +} + +impl<'de> MyDeserialize<'de> for StmtBulkExecuteParamType { + const SIZE: Option = Some(3); + type Ctx = (); + + fn deserialize(_ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result { + Ok(Self { + r#type: buf.parse(())?, + flags: buf.parse(())?, + }) + } +} + +impl MySerialize for StmtBulkExecuteParamType { + fn serialize(&self, buf: &mut Vec) { + self.r#type.serialize(&mut *buf); + self.flags.serialize(&mut *buf); + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct StmtBulkExecuteParamValues { + values: Vec, +} + +impl StmtBulkExecuteParamValues { + pub fn new(values: impl IntoIterator) -> Self { + let values = values + .into_iter() + .map(StmtBulkExecuteParamValue::new) + .collect(); + Self { values } + } +} + +impl<'de> MyDeserialize<'de> for StmtBulkExecuteParamValues { + const SIZE: Option = None; + type Ctx = Vec<(ColumnType, ColumnFlags)>; + + fn deserialize(params: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result { + let mut values = Vec::with_capacity(params.len()); + for param in params { + values.push(StmtBulkExecuteParamValue::deserialize(param, &mut *buf)?); + } + Ok(Self { values }) + } +} + +impl MySerialize for StmtBulkExecuteParamValues { + fn serialize(&self, buf: &mut Vec) { + for value in &self.values { + value.serialize(&mut *buf); + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct StmtBulkExecuteParamValue { + indicator: Const, + value: Value, +} + +impl StmtBulkExecuteParamValue { + pub fn new(value: Value) -> Self { + let indicator = if matches!(value, Value::NULL) { + MariadbBulkIndicator::BULK_INDICATOR_NULL + } else { + MariadbBulkIndicator::BULK_INDICATOR_NONE + }; + Self { + indicator: Const::new(indicator), + value, + } + } +} + +impl<'de> MyDeserialize<'de> for StmtBulkExecuteParamValue { + const SIZE: Option = None; + type Ctx = (ColumnType, ColumnFlags); + + fn deserialize(ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result { + let indicator = buf.parse::>(())?; + let value = if *indicator == MariadbBulkIndicator::BULK_INDICATOR_NULL { + Value::NULL + } else { + ValueDeserializer::::deserialize(ctx, buf)?.0 + }; + + Ok(Self { indicator, value }) + } +} + +impl MySerialize for StmtBulkExecuteParamValue { + fn serialize(&self, buf: &mut Vec) { + self.indicator.serialize(&mut *buf); + self.value.serialize(&mut *buf); + } +} + +#[derive(Debug, Clone, PartialEq, thiserror::Error)] +pub enum BulkExecuteRequestError { + #[error("No parameters values given for the bulk operation")] + NoParams, + #[error("Mixed statement arity")] + MixedArity, + #[error("Got row bigger than max payload length")] + RowTooLarge(Vec), +} + +#[derive(Debug, Clone, PartialEq, thiserror::Error)] +pub enum BulkExecuteRequestBuilderError { + #[error(transparent)] + Request(#[from] BulkExecuteRequestError), + #[error(transparent)] + Params(#[from] ParamsError), +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ComStmtBulkExecuteRequest<'a> { + header: ComStmtBulkExecuteHeader, + stmt_id: RawInt, + bulk_flags: Const, + types: Seq<'a, StmtBulkExecuteParamType, Unknown>, + values: StmtBulkExecuteParamValues, +} + +impl<'a> ComStmtBulkExecuteRequest<'a> { + pub fn new( + stmt_id: u32, + bulk_flags: StmtBulkExecuteFlags, + values: Vec>, + ) -> Result { + let first = values.first().ok_or(BulkExecuteRequestError::NoParams)?; + let arity = first.len(); + + let types = if bulk_flags.contains(StmtBulkExecuteFlags::SEND_TYPES_TO_SERVER) { + Seq::new( + first + .iter() + .map(StmtBulkExecuteParamType::from_value) + .collect::>(), + ) + } else { + Seq::empty() + }; + + for values in &values { + if values.len() != arity { + return Err(BulkExecuteRequestError::MixedArity); + } + } + + Ok(Self { + header: ConstU8::new(), + stmt_id: RawInt::new(stmt_id), + bulk_flags: Const::new(bulk_flags), + types, + values: StmtBulkExecuteParamValues::new(values.into_iter().flatten()), + }) + } +} + +impl<'de> MyDeserialize<'de> for ComStmtBulkExecuteRequest<'de> { + const SIZE: Option = None; + type Ctx = Vec<(ColumnType, ColumnFlags)>; + + fn deserialize(mut params: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result { + let header = buf.parse(())?; + let stmt_id = buf.parse(())?; + let bulk_flags: Const = buf.parse(())?; + let types = if bulk_flags.contains(StmtBulkExecuteFlags::SEND_TYPES_TO_SERVER) { + let types: Seq<'_, StmtBulkExecuteParamType, Unknown> = buf.parse(params.len())?; + for (s_type, r#type) in types.iter().zip(params.iter_mut()) { + r#type.0 = *s_type.r#type; + r#type.1.set( + ColumnFlags::UNSIGNED_FLAG, + s_type.flags.contains(StmtExecuteParamFlags::UNSIGNED), + ); + } + types + } else { + Seq::empty() + }; + let values = buf.parse(params)?; + + Ok(Self { + header, + stmt_id, + bulk_flags, + types, + values, + }) + } +} + +impl<'a> MySerialize for ComStmtBulkExecuteRequest<'a> { + fn serialize(&self, buf: &mut Vec) { + self.header.serialize(&mut *buf); + self.stmt_id.serialize(&mut *buf); + self.bulk_flags.serialize(&mut *buf); + if self + .bulk_flags + .contains(StmtBulkExecuteFlags::SEND_TYPES_TO_SERVER) + { + self.types.serialize(&mut *buf); + } + self.values.serialize(&mut *buf); + } +} + define_header!( ComRegisterSlaveHeader, COM_REGISTER_SLAVE, @@ -4129,7 +4599,7 @@ mod test { fn should_parse_handshake_packet_with_mariadb_ext_capabilities() { const HSP: &[u8] = b"\x0a5.5.5-11.4.7-MariaDB-log\x00\x0b\x00\ \x00\x00\x64\x76\x48\x40\x49\x2d\x43\x4a\x00\xff\xf7\x08\x02\x00\ - \x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x2a\x34\x64\ + \x00\x00\x00\x00\x00\x00\x00\x00\x00\x14\x00\x00\x00\x2a\x34\x64\ \x7c\x63\x5a\x77\x6b\x34\x5e\x5d\x3a\x00"; let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP)).unwrap(); @@ -4150,6 +4620,7 @@ mod test { assert_eq!( hsp.mariadb_ext_capabilities(), MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA + | MariadbCapabilities::MARIADB_CLIENT_STMT_BULK_OPERATIONS ); let mut output = Vec::new(); hsp.serialize(&mut output); @@ -4169,7 +4640,10 @@ mod test { None, 1_u32.to_be(), ) - .with_mariadb_ext_capabilities(MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA); + .with_mariadb_ext_capabilities( + MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA + | MariadbCapabilities::MARIADB_CLIENT_STMT_BULK_OPERATIONS, + ); let mut actual = Vec::new(); response.serialize(&mut actual); @@ -4179,7 +4653,7 @@ mod test { 0x2d, // charset 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // reserved - 0x10, 0x00, 0x00, 0x00, // mariadb capabilities + 0x14, 0x00, 0x00, 0x00, // mariadb capabilities 0x72, 0x6f, 0x6f, 0x74, 0x00, // username=root 0x00, // blank scramble 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70, diff --git a/src/params.rs b/src/params.rs index 24cbf2c..70773e1 100644 --- a/src/params.rs +++ b/src/params.rs @@ -11,30 +11,30 @@ use std::{ HashMap, hash_map::{Entry, Entry::Occupied}, }, - error::Error, fmt, }; use crate::value::{Value, convert::ToValue}; -/// `FromValue` conversion error. -#[derive(Debug, Eq, PartialEq, Clone)] +/// Missing named parameter for a statement +#[derive(Debug, Eq, PartialEq, Clone, thiserror::Error)] +#[error("Missing named parameter `{}` for statement", String::from_utf8_lossy(&_0))] pub struct MissingNamedParameterError(pub Vec); -impl fmt::Display for MissingNamedParameterError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "Missing named parameter `{}` for statement", - String::from_utf8_lossy(&self.0) - ) - } +#[derive(Debug, PartialEq, Eq, Clone, Copy, thiserror::Error, Hash)] +pub enum ParamsConfusionError { + #[error("Named params given where positional params are expected")] + NamedParamsForPositionalQuery, + #[error("Positional params given where named params are expected")] + PositionalParamsForNamedQuery, } -impl Error for MissingNamedParameterError { - fn description(&self) -> &str { - "Missing named parameter for statement" - } +#[derive(Debug, PartialEq, Eq, Clone, thiserror::Error)] +pub enum ParamsError { + #[error(transparent)] + Missing(#[from] MissingNamedParameterError), + #[error(transparent)] + Confusion(#[from] ParamsConfusionError), } /// Representations of parameters of a prepared statement. @@ -62,8 +62,77 @@ impl fmt::Debug for Params { } impl Params { + /// Converts [`Params`] into a vector of values given the named parameters. + /// + /// `named_params` (if any) must follow the order they were given in the corresponding SQL + /// statement. + pub fn into_values(self, named_params: Option<&[Vec]>) -> Result, ParamsError> { + match self { + Params::Empty => match named_params { + Some(params) => { + if let Some(first) = params.first() { + Err(MissingNamedParameterError(first.clone()).into()) + } else { + Ok(vec![]) + } + } + None => Ok(vec![]), + }, + Params::Positional(values) => match named_params { + Some(named_params) if !named_params.is_empty() => { + Err(ParamsConfusionError::PositionalParamsForNamedQuery.into()) + } + _ => Ok(values), + }, + Params::Named(map) => match named_params { + Some(named_params) if !named_params.is_empty() => { + let mut values = vec![Value::NULL; named_params.len()]; + let mut indexes = Vec::with_capacity(named_params.len()); + for (name, value) in map { + let mut first = None; + for (i, _) in named_params.iter().enumerate().filter(|(_, x)| **x == name) { + indexes.push(i); + if first.is_none() { + first = Some(i); + } else { + values[i] = value.clone(); + } + } + if let Some(first) = first { + values[first] = value; + } + } + if indexes.len() != named_params.len() { + indexes.sort_unstable(); + match indexes.into_iter().enumerate().find(|x| x.0 != x.1) { + Some((missing, _)) => { + Err(MissingNamedParameterError(named_params[missing].clone()) + .into()) + } + None => { + match named_params.last() { + Some(last) => { + Err(MissingNamedParameterError(last.clone()).into()) + } + None => { + // unreachable + Ok(values) + } + } + } + } + } else { + Ok(values) + } + } + _ => Err(ParamsConfusionError::NamedParamsForPositionalQuery.into()), + }, + } + } + /// Will convert named parameters into positional assuming order passed in `named_params` /// attribute. + #[deprecated = "use `into_values` instead"] pub fn into_positional( self, named_params: &[Vec],