diff --git a/Cargo.lock b/Cargo.lock index 9e7d767d..daa4a375 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1859,6 +1859,7 @@ name = "resp" version = "0.1.0" dependencies = [ "bytes", + "memchr", "nom 8.0.0", "thiserror 1.0.69", ] diff --git a/src/resp/Cargo.toml b/src/resp/Cargo.toml index b67ec1f7..1264a74a 100644 --- a/src/resp/Cargo.toml +++ b/src/resp/Cargo.toml @@ -9,4 +9,5 @@ workspace = true [dependencies] bytes.workspace = true thiserror.workspace = true -nom.workspace = true \ No newline at end of file +nom.workspace = true +memchr = "2" \ No newline at end of file diff --git a/src/resp/src/compat.rs b/src/resp/src/compat.rs new file mode 100644 index 00000000..d2fb2468 --- /dev/null +++ b/src/resp/src/compat.rs @@ -0,0 +1,73 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/// Defines how Boolean values should be converted when encoding to older RESP versions. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum BooleanMode { + /// Convert Boolean to Integer: true → :1, false → :0 + Integer, + /// Convert Boolean to Simple String: true → +OK, false → +ERR + SimpleString, +} + +/// Defines how Double values should be converted when encoding to older RESP versions. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum DoubleMode { + /// Convert Double to Bulk String: 3.14 → $4\r\n3.14\r\n + BulkString, + /// Convert Double to Integer if whole number: 2.0 → :2, 2.5 → BulkString + IntegerIfWhole, +} + +/// Defines how Map values should be converted when encoding to older RESP versions. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum MapMode { + /// Convert Map to flat Array: Map{k1:v1,k2:v2} → *4\r\nk1\r\nv1\r\nk2\r\nv2\r\n + FlatArray, + /// Convert Map to Array of pairs: Map{k1:v1,k2:v2} → *2\r\n*2\r\nk1\r\nv1\r\n*2\r\nk2\r\nv2\r\n + ArrayOfPairs, +} + +/// Configuration for converting RESP3 types to older RESP versions. +/// +/// This policy defines how RESP3-specific types (Boolean, Double, Map, etc.) +/// should be represented when encoding to RESP1 or RESP2 protocols. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct DownlevelPolicy { + /// How to convert Boolean values + pub boolean_mode: BooleanMode, + /// How to convert Double values + pub double_mode: DoubleMode, + /// How to convert Map values + pub map_mode: MapMode, +} + +impl Default for DownlevelPolicy { + /// Creates a default downlevel policy with conservative conversion settings. + /// + /// Default settings: + /// - Boolean → Integer (true → :1, false → :0) + /// - Double → BulkString (3.14 → $4\r\n3.14\r\n) + /// - Map → FlatArray (Map{k:v} → *2\r\nk\r\nv\r\n) + fn default() -> Self { + Self { + boolean_mode: BooleanMode::Integer, + double_mode: DoubleMode::BulkString, + map_mode: MapMode::FlatArray, + } + } +} diff --git a/src/resp/src/encode.rs b/src/resp/src/encode.rs index 4bc6e1c2..4d8c3ed6 100644 --- a/src/resp/src/encode.rs +++ b/src/resp/src/encode.rs @@ -370,6 +370,12 @@ impl RespEncode for RespEncoder { } self.append_crlf() } + // RESP3-only variants (Null, Boolean, Double, BulkError, VerbatimString, + // BigNumber, Map, Set, Push) are not encoded in the legacy RESP2 encoder + // and are silently skipped. These should be handled by version-specific + // encoders (Resp1Encoder, Resp2Encoder, Resp3Encoder) with appropriate + // downlevel conversion policies. + _ => self, } } } diff --git a/src/resp/src/factory.rs b/src/resp/src/factory.rs new file mode 100644 index 00000000..2ed594a0 --- /dev/null +++ b/src/resp/src/factory.rs @@ -0,0 +1,103 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Factory functions for creating RESP protocol encoders and decoders. +//! +//! This module provides convenient factory functions to create version-specific +//! encoder and decoder instances based on the desired RESP protocol version. + +use crate::{ + compat::DownlevelPolicy, + traits::{Decoder, Encoder}, + types::RespVersion, +}; + +/// Create a new decoder for the specified RESP version. +/// +/// # Arguments +/// * `version` - The RESP protocol version to create a decoder for +/// +/// # Returns +/// A boxed decoder instance that implements the `Decoder` trait +/// +/// # Examples +/// ``` +/// use resp::{RespVersion, new_decoder}; +/// +/// let mut decoder = new_decoder(RespVersion::RESP3); +/// decoder.push("+OK\r\n".into()); +/// ``` +pub fn new_decoder(version: RespVersion) -> Box { + match version { + RespVersion::RESP1 => Box::new(crate::resp1::decoder::Resp1Decoder::default()), + RespVersion::RESP2 => Box::new(crate::resp2::decoder::Resp2Decoder::default()), + RespVersion::RESP3 => Box::new(crate::resp3::decoder::Resp3Decoder::default()), + } +} + +/// Create a new encoder for the specified RESP version. +/// +/// # Arguments +/// * `version` - The RESP protocol version to create an encoder for +/// +/// # Returns +/// A boxed encoder instance that implements the `Encoder` trait +/// +/// # Examples +/// ``` +/// use resp::{RespData, RespVersion, new_encoder}; +/// +/// let mut encoder = new_encoder(RespVersion::RESP3); +/// let bytes = encoder.encode_one(&RespData::Boolean(true)).unwrap(); +/// assert_eq!(bytes.as_ref(), b"#t\r\n"); +/// ``` +pub fn new_encoder(version: RespVersion) -> Box { + match version { + RespVersion::RESP1 => Box::new(crate::resp1::encoder::Resp1Encoder::default()), + RespVersion::RESP2 => Box::new(crate::resp2::encoder::Resp2Encoder::default()), + RespVersion::RESP3 => Box::new(crate::resp3::encoder::Resp3Encoder::default()), + } +} + +/// Create a new encoder for the specified RESP version with custom downlevel policy. +/// +/// # Arguments +/// * `version` - The RESP protocol version to create an encoder for +/// * `policy` - The downlevel compatibility policy for RESP3 types +/// +/// # Returns +/// A boxed encoder instance that implements the `Encoder` trait +/// +/// # Examples +/// ``` +/// use resp::{BooleanMode, DownlevelPolicy, RespData, RespVersion, new_encoder_with_policy}; +/// +/// let policy = DownlevelPolicy { +/// boolean_mode: BooleanMode::SimpleString, +/// ..Default::default() +/// }; +/// let mut encoder = new_encoder_with_policy(RespVersion::RESP2, policy); +/// let bytes = encoder.encode_one(&RespData::Boolean(true)).unwrap(); // "+OK\r\n" +/// assert_eq!(bytes.as_ref(), b"+OK\r\n"); +/// ``` +pub fn new_encoder_with_policy(version: RespVersion, policy: DownlevelPolicy) -> Box { + match version { + RespVersion::RESP1 => Box::new(crate::resp1::encoder::Resp1Encoder::with_policy(policy)), + RespVersion::RESP2 => Box::new(crate::resp2::encoder::Resp2Encoder::with_policy(policy)), + RespVersion::RESP3 => Box::new(crate::resp3::encoder::Resp3Encoder::default()), + } +} diff --git a/src/resp/src/lib.rs b/src/resp/src/lib.rs index 38ae6196..6dbd0d3a 100644 --- a/src/resp/src/lib.rs +++ b/src/resp/src/lib.rs @@ -16,15 +16,30 @@ // limitations under the License. pub mod command; +pub mod compat; pub mod encode; pub mod error; pub mod parse; pub mod types; +// Versioned modules +pub mod resp1; +pub mod resp2; +pub mod resp3; + +// Unified traits and helpers +pub mod factory; +pub mod multi; +pub mod traits; + pub use command::{Command, CommandType, RespCommand}; +pub use compat::{BooleanMode, DoubleMode, DownlevelPolicy, MapMode}; pub use encode::{CmdRes, RespEncode}; pub use error::{RespError, RespResult}; +pub use factory::{new_decoder, new_encoder, new_encoder_with_policy}; +pub use multi::{decode_many, encode_many}; pub use parse::{Parse, RespParse, RespParseResult}; +pub use traits::{Decoder, Encoder}; pub use types::{RespData, RespType, RespVersion}; pub const CRLF: &str = "\r\n"; diff --git a/src/resp/src/multi.rs b/src/resp/src/multi.rs new file mode 100644 index 00000000..30725083 --- /dev/null +++ b/src/resp/src/multi.rs @@ -0,0 +1,89 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Utilities for batch processing of RESP messages. +//! +//! This module provides functions for encoding and decoding multiple RESP messages +//! in a single operation, which is useful for pipelining and batch operations. + +use bytes::{Bytes, BytesMut}; + +use crate::{ + error::RespResult, + traits::{Decoder, Encoder}, + types::RespData, +}; + +/// Decode multiple RESP messages from a single byte chunk. +/// +/// This function pushes the entire chunk into the decoder and attempts to parse +/// all complete messages available. Useful for processing pipelined commands. +/// +/// # Arguments +/// * `decoder` - The decoder instance to use for parsing +/// * `chunk` - The byte chunk containing one or more RESP messages +/// +/// # Returns +/// A vector of parsing results. Each element is either `Ok(RespData)` for a +/// successfully parsed message, or `Err(RespError)` for parsing errors. +/// +/// # Examples +/// ``` +/// use resp::{RespVersion, decode_many, new_decoder}; +/// +/// let mut decoder = new_decoder(RespVersion::RESP2); +/// let results = decode_many(&mut *decoder, "+OK\r\n:42\r\n".into()); +/// assert_eq!(results.len(), 2); +/// ``` +pub fn decode_many(decoder: &mut dyn Decoder, chunk: Bytes) -> Vec> { + decoder.push(chunk); + let mut out = Vec::new(); + while let Some(frame) = decoder.next() { + out.push(frame); + } + out +} + +/// Encode multiple RESP messages into a single byte buffer. +/// +/// This function encodes each `RespData` value in sequence and concatenates +/// the results into a single `Bytes` buffer. Useful for building pipelined commands. +/// +/// # Arguments +/// * `encoder` - The encoder instance to use for encoding +/// * `values` - A slice of `RespData` values to encode +/// +/// # Returns +/// A `Result` containing the concatenated encoded bytes, or an error if encoding fails. +/// +/// # Examples +/// ``` +/// use resp::{RespData, RespVersion, encode_many, new_encoder}; +/// +/// let mut encoder = new_encoder(RespVersion::RESP2); +/// let values = vec![RespData::SimpleString("OK".into()), RespData::Integer(42)]; +/// let bytes = encode_many(&mut *encoder, &values).unwrap(); +/// // bytes contains "+OK\r\n:42\r\n" +/// assert_eq!(bytes.as_ref(), b"+OK\r\n:42\r\n"); +/// ``` +pub fn encode_many(encoder: &mut dyn Encoder, values: &[RespData]) -> RespResult { + let mut buf = BytesMut::with_capacity(values.len() * 32); // Rough estimate + for v in values { + encoder.encode_into(v, &mut buf)?; + } + Ok(buf.freeze()) +} diff --git a/src/resp/src/parse.rs b/src/resp/src/parse.rs index a67f2676..0d1970af 100644 --- a/src/resp/src/parse.rs +++ b/src/resp/src/parse.rs @@ -35,7 +35,7 @@ use crate::{ types::{RespData, RespVersion}, }; -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq)] pub enum RespParseResult { Complete(RespData), Incomplete, diff --git a/src/resp/src/resp1/decoder.rs b/src/resp/src/resp1/decoder.rs new file mode 100644 index 00000000..275b5841 --- /dev/null +++ b/src/resp/src/resp1/decoder.rs @@ -0,0 +1,69 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::VecDeque; + +use bytes::Bytes; + +use crate::{ + error::RespResult, + parse::{Parse, RespParse, RespParseResult}, + traits::Decoder, + types::{RespData, RespVersion}, +}; + +#[derive(Default)] +pub struct Resp1Decoder { + inner: RespParse, + out: VecDeque>, +} + +impl Resp1Decoder { + pub fn new() -> Self { + Self { + inner: RespParse::new(RespVersion::RESP1), + out: VecDeque::new(), + } + } +} + +impl Decoder for Resp1Decoder { + fn push(&mut self, data: Bytes) { + let mut res = self.inner.parse(data); + loop { + match res { + RespParseResult::Complete(d) => self.out.push_back(Ok(d)), + RespParseResult::Error(e) => self.out.push_back(Err(e)), + RespParseResult::Incomplete => break, + } + res = self.inner.parse(Bytes::new()); + } + } + + fn next(&mut self) -> Option> { + self.out.pop_front() + } + + fn reset(&mut self) { + self.inner.reset(); + self.out.clear(); + } + + fn version(&self) -> RespVersion { + RespVersion::RESP1 + } +} diff --git a/src/resp/src/resp1/encoder.rs b/src/resp/src/resp1/encoder.rs new file mode 100644 index 00000000..603fe55e --- /dev/null +++ b/src/resp/src/resp1/encoder.rs @@ -0,0 +1,144 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::{Bytes, BytesMut}; + +use crate::{ + compat::{BooleanMode, DoubleMode, DownlevelPolicy, MapMode}, + encode::{RespEncode, RespEncoder}, + error::RespResult, + traits::Encoder, + types::{RespData, RespVersion}, +}; + +#[derive(Default)] +pub struct Resp1Encoder { + inner: RespEncoder, + policy: DownlevelPolicy, +} + +impl Resp1Encoder { + pub fn new() -> Self { + Self { + inner: RespEncoder::new(RespVersion::RESP1), + policy: DownlevelPolicy::default(), + } + } + + pub fn with_policy(policy: DownlevelPolicy) -> Self { + Self { + inner: RespEncoder::new(RespVersion::RESP1), + policy, + } + } +} + +impl Encoder for Resp1Encoder { + fn encode_one(&mut self, data: &RespData) -> RespResult { + self.inner.clear(); + self.encode_downleveled(data); + Ok(self.inner.get_response()) + } + + fn encode_into(&mut self, data: &RespData, out: &mut BytesMut) -> RespResult<()> { + let bytes = self.encode_one(data)?; + out.extend_from_slice(&bytes); + Ok(()) + } + + fn version(&self) -> RespVersion { + RespVersion::RESP1 + } +} + +impl Resp1Encoder { + fn encode_downleveled(&mut self, data: &RespData) { + match data { + // RESP1 uses inline/simple/integer/bulk/array semantics basically consistent + RespData::SimpleString(_) + | RespData::Error(_) + | RespData::Integer(_) + | RespData::BulkString(_) + | RespData::Array(_) + | RespData::Inline(_) => { + self.inner.encode_resp_data(data); + } + // Downlevel mapping strategy consistent with RESP2 + RespData::Null => { + self.inner.set_line_string("$-1"); + } + RespData::Boolean(b) => { + match self.policy.boolean_mode { + BooleanMode::Integer => self.inner.append_integer(if *b { 1 } else { 0 }), + BooleanMode::SimpleString => { + if *b { + self.inner.append_simple_string("OK") + } else { + self.inner.append_simple_string("ERR") + } + } + }; + } + RespData::Double(v) => { + if let DoubleMode::IntegerIfWhole = self.policy.double_mode { + if v.fract() == 0.0 + && v.is_finite() + && *v >= i64::MIN as f64 + && *v <= i64::MAX as f64 + { + self.inner.append_integer(*v as i64); + return; + } + } + self.inner.append_string(&format!("{}", v)); + } + crate::types::RespData::BulkError(msg) => { + self.inner + .append_string_raw(&format!("-{}\r\n", String::from_utf8_lossy(msg))); + } + crate::types::RespData::VerbatimString { data, .. } => { + self.inner.append_bulk_string(data); + } + crate::types::RespData::BigNumber(s) => { + self.inner.append_string(s); + } + crate::types::RespData::Map(entries) => match self.policy.map_mode { + MapMode::FlatArray => { + self.inner.append_array_len((entries.len() * 2) as i64); + for (k, v) in entries { + self.encode_downleveled(k); + self.encode_downleveled(v); + } + } + MapMode::ArrayOfPairs => { + self.inner.append_array_len(entries.len() as i64); + for (k, v) in entries { + self.inner.append_array_len(2); + self.encode_downleveled(k); + self.encode_downleveled(v); + } + } + }, + crate::types::RespData::Set(items) | crate::types::RespData::Push(items) => { + self.inner.append_array_len(items.len() as i64); + for it in items { + self.encode_downleveled(it); + } + } + } + } +} diff --git a/src/resp/src/resp1/mod.rs b/src/resp/src/resp1/mod.rs new file mode 100644 index 00000000..45efdd5e --- /dev/null +++ b/src/resp/src/resp1/mod.rs @@ -0,0 +1,19 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod decoder; +pub mod encoder; diff --git a/src/resp/src/resp2/decoder.rs b/src/resp/src/resp2/decoder.rs new file mode 100644 index 00000000..aa2eaa10 --- /dev/null +++ b/src/resp/src/resp2/decoder.rs @@ -0,0 +1,69 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::VecDeque; + +use bytes::Bytes; + +use crate::{ + error::RespResult, + parse::{Parse, RespParse, RespParseResult}, + traits::Decoder, + types::{RespData, RespVersion}, +}; + +#[derive(Default)] +pub struct Resp2Decoder { + inner: RespParse, + out: VecDeque>, +} + +impl Resp2Decoder { + pub fn new() -> Self { + Self { + inner: RespParse::new(RespVersion::RESP2), + out: VecDeque::new(), + } + } +} + +impl Decoder for Resp2Decoder { + fn push(&mut self, data: Bytes) { + let mut res = self.inner.parse(data); + loop { + match res { + RespParseResult::Complete(d) => self.out.push_back(Ok(d)), + RespParseResult::Error(e) => self.out.push_back(Err(e)), + RespParseResult::Incomplete => break, + } + res = self.inner.parse(Bytes::new()); + } + } + + fn next(&mut self) -> Option> { + self.out.pop_front() + } + + fn reset(&mut self) { + self.inner.reset(); + self.out.clear(); + } + + fn version(&self) -> RespVersion { + RespVersion::RESP2 + } +} diff --git a/src/resp/src/resp2/encoder.rs b/src/resp/src/resp2/encoder.rs new file mode 100644 index 00000000..bc6478e1 --- /dev/null +++ b/src/resp/src/resp2/encoder.rs @@ -0,0 +1,144 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::{Bytes, BytesMut}; + +use crate::{ + compat::{BooleanMode, DoubleMode, DownlevelPolicy, MapMode}, + encode::{RespEncode, RespEncoder}, + error::RespResult, + traits::Encoder, + types::{RespData, RespVersion}, +}; + +#[derive(Default)] +pub struct Resp2Encoder { + inner: RespEncoder, + policy: DownlevelPolicy, +} + +impl Resp2Encoder { + pub fn new() -> Self { + Self { + inner: RespEncoder::new(RespVersion::RESP2), + policy: DownlevelPolicy::default(), + } + } + + pub fn with_policy(policy: DownlevelPolicy) -> Self { + Self { + inner: RespEncoder::new(RespVersion::RESP2), + policy, + } + } +} + +impl Encoder for Resp2Encoder { + fn encode_one(&mut self, data: &RespData) -> RespResult { + self.inner.clear(); + self.encode_downleveled(data); + Ok(self.inner.get_response()) + } + + fn encode_into(&mut self, data: &RespData, out: &mut BytesMut) -> RespResult<()> { + let bytes = self.encode_one(data)?; + out.extend_from_slice(&bytes); + Ok(()) + } + + fn version(&self) -> RespVersion { + RespVersion::RESP2 + } +} + +impl Resp2Encoder { + fn encode_downleveled(&mut self, data: &RespData) { + match data { + // RESP2 native types + RespData::SimpleString(_) + | RespData::Error(_) + | RespData::Integer(_) + | RespData::BulkString(_) + | RespData::Array(_) + | RespData::Inline(_) => { + self.inner.encode_resp_data(data); + } + // Downlevel mappings + RespData::Null => { + self.inner.set_line_string("$-1"); + } + RespData::Boolean(b) => { + match self.policy.boolean_mode { + BooleanMode::Integer => self.inner.append_integer(if *b { 1 } else { 0 }), + BooleanMode::SimpleString => { + if *b { + self.inner.append_simple_string("OK") + } else { + self.inner.append_simple_string("ERR") + } + } + }; + } + RespData::Double(v) => { + if let DoubleMode::IntegerIfWhole = self.policy.double_mode { + if v.fract() == 0.0 + && v.is_finite() + && *v >= i64::MIN as f64 + && *v <= i64::MAX as f64 + { + self.inner.append_integer(*v as i64); + return; + } + } + self.inner.append_string(&format!("{}", v)); + } + RespData::BulkError(msg) => { + self.inner + .append_string_raw(&format!("-{}\r\n", String::from_utf8_lossy(msg))); + } + RespData::VerbatimString { data, .. } => { + self.inner.append_bulk_string(data); + } + RespData::BigNumber(s) => { + self.inner.append_string(s); + } + RespData::Map(entries) => match self.policy.map_mode { + MapMode::FlatArray => { + self.inner.append_array_len((entries.len() * 2) as i64); + for (k, v) in entries { + self.encode_downleveled(k); + self.encode_downleveled(v); + } + } + MapMode::ArrayOfPairs => { + self.inner.append_array_len(entries.len() as i64); + for (k, v) in entries { + self.inner.append_array_len(2); + self.encode_downleveled(k); + self.encode_downleveled(v); + } + } + }, + RespData::Set(items) | RespData::Push(items) => { + self.inner.append_array_len(items.len() as i64); + for it in items { + self.encode_downleveled(it); + } + } + } + } +} diff --git a/src/resp/src/resp2/mod.rs b/src/resp/src/resp2/mod.rs new file mode 100644 index 00000000..45efdd5e --- /dev/null +++ b/src/resp/src/resp2/mod.rs @@ -0,0 +1,19 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod decoder; +pub mod encoder; diff --git a/src/resp/src/resp3/decoder.rs b/src/resp/src/resp3/decoder.rs new file mode 100644 index 00000000..4cbb6ad2 --- /dev/null +++ b/src/resp/src/resp3/decoder.rs @@ -0,0 +1,750 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::VecDeque; + +use bytes::Bytes; + +use crate::{ + error::{RespError, RespResult}, + traits::Decoder, + types::{RespData, RespVersion}, +}; + +// Maximum allowed lengths to prevent DoS attacks +const MAX_BULK_LEN: usize = 512 * 1024 * 1024; // 512 MB +const MAX_BULK_ERROR_LEN: usize = MAX_BULK_LEN; // tune independently if needed +const MAX_VERBATIM_LEN: usize = MAX_BULK_LEN; // tune independently if needed +const MAX_BIGNUM_LEN: usize = 16 * 1024 * 1024; // 16 MB of digits +const MAX_INLINE_LEN: usize = 4 * 1024; // 4 KiB for inline commands +const MAX_ARRAY_LEN: usize = 1024 * 1024; // 1M elements +const MAX_MAP_PAIRS: usize = 1024 * 1024; // 1M pairs +const MAX_SET_LEN: usize = 1024 * 1024; // 1M elements +const MAX_PUSH_LEN: usize = 1024 * 1024; // 1M elements + +#[derive(Debug, Clone, Default)] +enum ParsingState { + #[default] + Idle, + Array { + expected_count: usize, + items: Vec, + }, + Map { + expected_pairs: usize, + items: Vec<(RespData, RespData)>, + }, + MapWaitingForValue { + expected_pairs: usize, + items: Vec<(RespData, RespData)>, + current_key: RespData, + }, + Set { + expected_count: usize, + items: Vec, + }, + Push { + expected_count: usize, + items: Vec, + }, +} + +#[derive(Default)] +pub struct Resp3Decoder { + out: VecDeque>, + buf: bytes::BytesMut, + state_stack: Vec, +} + +impl Resp3Decoder { + /// Continue parsing an ongoing collection + fn continue_collection_parsing(&mut self) -> Option> { + let current_state = self.state_stack.pop()?; + match current_state { + ParsingState::Idle => { + // Should not happen, but handle gracefully + None + } + ParsingState::Array { + expected_count, + mut items, + } => { + while items.len() < expected_count { + if let Some(result) = self.parse_single_value_for_collection() { + match result { + Ok(item) => items.push(item), + Err(e) => { + // Clear state stack on error + self.state_stack.clear(); + return Some(Err(e)); + } + } + } else { + // Need more data, restore state to stack + self.state_stack.push(ParsingState::Array { + expected_count, + items, + }); + return None; + } + } + // All elements parsed successfully + Some(Ok(RespData::Array(Some(items)))) + } + ParsingState::Map { + expected_pairs, + mut items, + } => { + while items.len() < expected_pairs { + // Parse key + if let Some(result) = self.parse_single_value_for_collection() { + match result { + Ok(key) => { + // Parse value + if let Some(result) = self.parse_single_value_for_collection() { + match result { + Ok(value) => items.push((key, value)), + Err(e) => { + // Clear state stack on error + self.state_stack.clear(); + return Some(Err(e)); + } + } + } else { + // Need more data for value, save key and wait + self.state_stack.push(ParsingState::MapWaitingForValue { + expected_pairs, + items, + current_key: key, + }); + return None; + } + } + Err(e) => { + // Clear state stack on error + self.state_stack.clear(); + return Some(Err(e)); + } + } + } else { + // Need more data for key, restore state to stack + self.state_stack.push(ParsingState::Map { + expected_pairs, + items, + }); + return None; + } + } + // All pairs parsed successfully + Some(Ok(RespData::Map(items))) + } + ParsingState::MapWaitingForValue { + expected_pairs, + mut items, + current_key, + } => { + // Parse value for the saved key + if let Some(result) = self.parse_single_value_for_collection() { + match result { + Ok(value) => { + items.push((current_key, value)); + // Continue with remaining pairs + self.state_stack.push(ParsingState::Map { + expected_pairs, + items, + }); + self.continue_collection_parsing() + } + Err(e) => { + // Clear state stack on error + self.state_stack.clear(); + Some(Err(e)) + } + } + } else { + // Need more data, restore state to stack + self.state_stack.push(ParsingState::MapWaitingForValue { + expected_pairs, + items, + current_key, + }); + None + } + } + ParsingState::Set { + expected_count, + mut items, + } => { + while items.len() < expected_count { + if let Some(result) = self.parse_single_value_for_collection() { + match result { + Ok(item) => items.push(item), + Err(e) => { + // Clear state stack on error + self.state_stack.clear(); + return Some(Err(e)); + } + } + } else { + // Need more data, restore state to stack + self.state_stack.push(ParsingState::Set { + expected_count, + items, + }); + return None; + } + } + // All elements parsed successfully + Some(Ok(RespData::Set(items))) + } + ParsingState::Push { + expected_count, + mut items, + } => { + while items.len() < expected_count { + if let Some(result) = self.parse_single_value_for_collection() { + match result { + Ok(item) => items.push(item), + Err(e) => { + // Clear state stack on error + self.state_stack.clear(); + return Some(Err(e)); + } + } + } else { + // Need more data, restore state to stack + self.state_stack.push(ParsingState::Push { + expected_count, + items, + }); + return None; + } + } + // All elements parsed successfully + Some(Ok(RespData::Push(items))) + } + } + } + + /// Parse a single RESP value (with state tracking for collections) + fn parse_single_value(&mut self) -> Option> { + // First, try to continue parsing any ongoing collection + if let Some(result) = self.continue_collection_parsing() { + return Some(result); + } + self.parse_single_value_atomic() + } + + /// Parse a single RESP value that can be part of a collection + fn parse_single_value_for_collection(&mut self) -> Option> { + // If we have ongoing collection parsing, continue with that + if !self.state_stack.is_empty() { + return self.continue_collection_parsing(); + } + // Otherwise parse atomically + self.parse_single_value_atomic() + } + + /// Parse a single RESP value atomically (without state tracking) + fn parse_single_value_atomic(&mut self) -> Option> { + if self.buf.is_empty() { + return None; + } + + match self.buf[0] { + // RESP3-specific types + b'_' => { + if self.buf.len() < 3 { + return None; + } + if &self.buf[..3] == b"_\r\n" { + let _ = self.buf.split_to(3); + Some(Ok(RespData::Null)) + } else { + Some(Err(RespError::ParseError("invalid RESP3 null".into()))) + } + } + b'#' => { + if self.buf.len() < 4 { + return None; + } + if &self.buf[..4] == b"#t\r\n" { + let _ = self.buf.split_to(4); + Some(Ok(RespData::Boolean(true))) + } else if &self.buf[..4] == b"#f\r\n" { + let _ = self.buf.split_to(4); + Some(Ok(RespData::Boolean(false))) + } else { + Some(Err(RespError::ParseError("invalid RESP3 boolean".into()))) + } + } + b',' => { + if let Some(pos) = memchr::memchr(b'\n', &self.buf) { + let line_len = pos + 1; + let line = &self.buf[..line_len]; + if line.len() < 3 || line[line.len() - 2] != b'\r' { + return None; + } + let chunk = self.buf.split_to(line_len); + let body = &chunk[1..chunk.len() - 2]; + if let Ok(s) = std::str::from_utf8(body) { + let sl = s.to_ascii_lowercase(); + let val = if sl == "inf" { + Some(f64::INFINITY) + } else if sl == "-inf" { + Some(f64::NEG_INFINITY) + } else if sl == "nan" { + Some(f64::NAN) + } else { + s.parse::().ok() + }; + match val { + Some(v) => Some(Ok(RespData::Double(v))), + None => Some(Err(RespError::ParseError("invalid RESP3 double".into()))), + } + } else { + Some(Err(RespError::ParseError("invalid RESP3 double".into()))) + } + } else { + None + } + } + b'!' => { + if let Some(nl) = memchr::memchr(b'\n', &self.buf) { + if nl < 3 || self.buf[nl - 1] != b'\r' { + return None; + } + let len_bytes = &self.buf[1..nl - 1]; + let len = match std::str::from_utf8(len_bytes) + .ok() + .and_then(|s| s.parse::().ok()) + { + Some(v) => v, + None => { + return Some(Err(RespError::ParseError( + "invalid bulk error len".into(), + ))); + } + }; + if len > MAX_BULK_ERROR_LEN { + return Some(Err(RespError::ParseError( + "bulk error length exceeds limit".into(), + ))); + } + let need = nl + 1 + len + 2; + if self.buf.len() < need { + return None; + } + let chunk = self.buf.split_to(need); + if &chunk[nl + 1 + len..need] != b"\r\n" { + return Some(Err(RespError::ParseError( + "invalid bulk error ending".into(), + ))); + } + let data = bytes::Bytes::copy_from_slice(&chunk[nl + 1..nl + 1 + len]); + Some(Ok(RespData::BulkError(data))) + } else { + None + } + } + b'=' => { + if let Some(nl) = memchr::memchr(b'\n', &self.buf) { + if nl < 3 || self.buf[nl - 1] != b'\r' { + return None; + } + let len_bytes = &self.buf[1..nl - 1]; + let len = match std::str::from_utf8(len_bytes) + .ok() + .and_then(|s| s.parse::().ok()) + { + Some(v) => v, + None => { + return Some(Err(RespError::ParseError("invalid verbatim len".into()))); + } + }; + if len > MAX_VERBATIM_LEN { + return Some(Err(RespError::ParseError( + "verbatim string length exceeds limit".into(), + ))); + } + let need = nl + 1 + len + 2; + if self.buf.len() < need { + return None; + } + let chunk = self.buf.split_to(need); + let content = &chunk[nl + 1..nl + 1 + len]; + if content.len() < 4 || content[3] != b':' { + return Some(Err(RespError::ParseError("invalid verbatim header".into()))); + } + let mut fmt = [0u8; 3]; + fmt.copy_from_slice(&content[0..3]); + let data = bytes::Bytes::copy_from_slice(&content[4..]); + if &chunk[nl + 1 + len..need] != b"\r\n" { + return Some(Err(RespError::ParseError("invalid verbatim ending".into()))); + } + Some(Ok(RespData::VerbatimString { format: fmt, data })) + } else { + None + } + } + b'(' => { + if let Some(pos) = memchr::memchr(b'\n', &self.buf) { + let line_len = pos + 1; + if line_len < 3 || self.buf[line_len - 2] != b'\r' { + return None; + } + // body length excludes '(' and CRLF + let body_len = line_len - 3; + if body_len > MAX_BIGNUM_LEN { + return Some(Err(RespError::ParseError( + "big number length exceeds limit".into(), + ))); + } + let chunk = self.buf.split_to(line_len); + let body = &chunk[1..chunk.len() - 2]; + match std::str::from_utf8(body) { + Ok(s) => Some(Ok(RespData::BigNumber(s.to_string()))), + Err(_) => Some(Err(RespError::ParseError("invalid big number".into()))), + } + } else { + None + } + } + // Standard RESP types + b'+' => { + if let Some(pos) = memchr::memchr(b'\n', &self.buf) { + let line_len = pos + 1; + let line = &self.buf[..line_len]; + if line.len() < 3 || line[line.len() - 2] != b'\r' { + return None; + } + let chunk = self.buf.split_to(line_len); + let data = bytes::Bytes::copy_from_slice(&chunk[1..chunk.len() - 2]); + Some(Ok(RespData::SimpleString(data))) + } else { + None + } + } + b'-' => { + if let Some(pos) = memchr::memchr(b'\n', &self.buf) { + let line_len = pos + 1; + let line = &self.buf[..line_len]; + if line.len() < 3 || line[line.len() - 2] != b'\r' { + return None; + } + let chunk = self.buf.split_to(line_len); + let data = bytes::Bytes::copy_from_slice(&chunk[1..chunk.len() - 2]); + Some(Ok(RespData::Error(data))) + } else { + None + } + } + b':' => { + if let Some(pos) = memchr::memchr(b'\n', &self.buf) { + let line_len = pos + 1; + let line = &self.buf[..line_len]; + if line.len() < 3 || line[line.len() - 2] != b'\r' { + return None; + } + let chunk = self.buf.split_to(line_len); + let num_str = &chunk[1..chunk.len() - 2]; + match std::str::from_utf8(num_str) + .ok() + .and_then(|s| s.parse::().ok()) + { + Some(n) => Some(Ok(RespData::Integer(n))), + None => Some(Err(RespError::ParseError("invalid integer".into()))), + } + } else { + None + } + } + b'$' => { + if let Some(nl) = memchr::memchr(b'\n', &self.buf) { + if nl < 3 || self.buf[nl - 1] != b'\r' { + return None; + } + let len_bytes = &self.buf[1..nl - 1]; + let len = match std::str::from_utf8(len_bytes) + .ok() + .and_then(|s| s.parse::().ok()) + { + Some(v) => v, + None => { + return Some(Err(RespError::ParseError( + "invalid bulk string len".into(), + ))); + } + }; + if len == -1 { + let _ = self.buf.split_to(nl + 1); + Some(Ok(RespData::BulkString(None))) + } else if len >= 0 { + let len_usize = len as usize; + if len_usize > MAX_BULK_LEN { + return Some(Err(RespError::ParseError( + "bulk string length exceeds limit".into(), + ))); + } + let need = nl + 1 + len_usize + 2; + if self.buf.len() < need { + return None; + } + let chunk = self.buf.split_to(need); + if &chunk[nl + 1 + len_usize..need] != b"\r\n" { + return Some(Err(RespError::ParseError( + "invalid bulk string ending".into(), + ))); + } + let data = + bytes::Bytes::copy_from_slice(&chunk[nl + 1..nl + 1 + len_usize]); + Some(Ok(RespData::BulkString(Some(data)))) + } else { + Some(Err(RespError::ParseError( + "negative bulk string len".into(), + ))) + } + } else { + None + } + } + b'*' => { + if let Some(nl) = memchr::memchr(b'\n', &self.buf) { + if nl < 3 || self.buf[nl - 1] != b'\r' { + return None; + } + let len_bytes = &self.buf[1..nl - 1]; + let len = match std::str::from_utf8(len_bytes) + .ok() + .and_then(|s| s.parse::().ok()) + { + Some(v) => v, + None => { + return Some(Err(RespError::ParseError("invalid array len".into()))); + } + }; + if len == -1 { + let _ = self.buf.split_to(nl + 1); + Some(Ok(RespData::Array(None))) + } else if len >= 0 { + let len_usize = len as usize; + if len_usize > MAX_ARRAY_LEN { + return Some(Err(RespError::ParseError( + "array length exceeds limit".into(), + ))); + } + // Consume header and start array parsing state + let _ = self.buf.split_to(nl + 1); + self.state_stack.push(ParsingState::Array { + expected_count: len_usize, + items: Vec::with_capacity(len_usize), + }); + // Continue parsing elements - this will parse the array elements + // and return the complete array when done + self.continue_collection_parsing() + } else { + Some(Err(RespError::ParseError("negative array len".into()))) + } + } else { + None + } + } + b'%' => { + // Map: %\r\n... + if let Some(nl) = memchr::memchr(b'\n', &self.buf) { + if nl < 3 || self.buf[nl - 1] != b'\r' { + return None; + } + let len_bytes = &self.buf[1..nl - 1]; + let pairs = match std::str::from_utf8(len_bytes) + .ok() + .and_then(|s| s.parse::().ok()) + { + Some(v) => v, + None => { + return Some(Err(RespError::ParseError("invalid map len".into()))); + } + }; + if pairs > MAX_MAP_PAIRS { + return Some(Err(RespError::ParseError("map pairs exceed limit".into()))); + } + // Consume header and start map parsing state + let _ = self.buf.split_to(nl + 1); + self.state_stack.push(ParsingState::Map { + expected_pairs: pairs, + items: Vec::with_capacity(pairs), + }); + // Continue parsing pairs + self.continue_collection_parsing() + } else { + None + } + } + b'~' => { + // Set: ~\r\n... + if let Some(nl) = memchr::memchr(b'\n', &self.buf) { + if nl < 3 || self.buf[nl - 1] != b'\r' { + return None; + } + let len_bytes = &self.buf[1..nl - 1]; + let count = match std::str::from_utf8(len_bytes) + .ok() + .and_then(|s| s.parse::().ok()) + { + Some(v) => v, + None => { + return Some(Err(RespError::ParseError("invalid set len".into()))); + } + }; + if count > MAX_SET_LEN { + return Some(Err(RespError::ParseError( + "set length exceeds limit".into(), + ))); + } + // Consume header and start set parsing state + let _ = self.buf.split_to(nl + 1); + self.state_stack.push(ParsingState::Set { + expected_count: count, + items: Vec::with_capacity(count), + }); + // Continue parsing elements + self.continue_collection_parsing() + } else { + None + } + } + b'>' => { + // Push: >len\r\n... + if let Some(nl) = memchr::memchr(b'\n', &self.buf) { + if nl < 3 || self.buf[nl - 1] != b'\r' { + return None; + } + let len_bytes = &self.buf[1..nl - 1]; + let count = match std::str::from_utf8(len_bytes) + .ok() + .and_then(|s| s.parse::().ok()) + { + Some(v) => v, + None => { + return Some(Err(RespError::ParseError("invalid push len".into()))); + } + }; + if count > MAX_PUSH_LEN { + return Some(Err(RespError::ParseError( + "push length exceeds limit".into(), + ))); + } + // Consume header and start push parsing state + let _ = self.buf.split_to(nl + 1); + self.state_stack.push(ParsingState::Push { + expected_count: count, + items: Vec::with_capacity(count), + }); + // Continue parsing elements + self.continue_collection_parsing() + } else { + None + } + } + // Inline command (no prefix, just data followed by \r\n) + _ => { + // Check if this looks like an inline command (no prefix, ends with \r\n) + if let Some(pos) = memchr::memchr(b'\n', &self.buf) { + let line_len = pos + 1; + let line = &self.buf[..line_len]; + if line.len() >= 2 && line[line.len() - 2] == b'\r' { + // Enforce max length for inline commands + let data_len = line_len - 2; // Exclude \r\n + if data_len > MAX_INLINE_LEN { + return Some(Err(RespError::ParseError( + "inline command length exceeds limit".into(), + ))); + } + + // Reject if first byte is a known RESP prefix or non-printable + let first_byte = self.buf[0]; + if matches!( + first_byte, + b'+' | b'-' + | b':' + | b'$' + | b'*' + | b'%' + | b'~' + | b'>' + | b'_' + | b'#' + | b',' + | b'!' + | b'=' + | b'(' + ) || !(32..=126).contains(&first_byte) + { + return Some(Err(RespError::ParseError( + "invalid inline command prefix".into(), + ))); + } + + let chunk = self.buf.split_to(line_len); + let data = &chunk[..chunk.len() - 2]; // Remove \r\n + let parts: Vec = data + .split(|&b| b == b' ') + .map(bytes::Bytes::copy_from_slice) + .collect(); + Some(Ok(RespData::Inline(parts))) + } else { + Some(Err(RespError::ParseError("invalid inline command".into()))) + } + } else { + None + } + } + } + } +} + +impl Decoder for Resp3Decoder { + fn push(&mut self, data: Bytes) { + self.buf.extend_from_slice(&data); + + while let Some(result) = self.parse_single_value() { + match result { + Ok(data) => { + self.out.push_back(Ok(data)); + } + Err(e) => { + self.out.push_back(Err(e)); + break; + } + } + } + } + + fn next(&mut self) -> Option> { + self.out.pop_front() + } + + fn reset(&mut self) { + self.out.clear(); + self.buf.clear(); + self.state_stack.clear(); + } + + fn version(&self) -> RespVersion { + RespVersion::RESP3 + } +} diff --git a/src/resp/src/resp3/encoder.rs b/src/resp/src/resp3/encoder.rs new file mode 100644 index 00000000..bafec692 --- /dev/null +++ b/src/resp/src/resp3/encoder.rs @@ -0,0 +1,184 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::{Bytes, BytesMut}; + +use crate::{ + encode::{RespEncode, RespEncoder}, + error::RespResult, + traits::Encoder, + types::{RespData, RespVersion}, +}; + +#[derive(Default)] +pub struct Resp3Encoder { + resp2_encoder: Option, +} + +impl Encoder for Resp3Encoder { + fn encode_one(&mut self, data: &RespData) -> RespResult { + let mut buf = BytesMut::new(); + self.encode_into(data, &mut buf)?; + Ok(buf.freeze()) + } + + fn encode_into(&mut self, data: &RespData, out: &mut BytesMut) -> RespResult<()> { + match data { + RespData::Null => { + out.extend_from_slice(b"_\r\n"); + } + RespData::Boolean(true) => { + out.extend_from_slice(b"#t\r\n"); + } + RespData::Boolean(false) => { + out.extend_from_slice(b"#f\r\n"); + } + RespData::Double(v) => { + out.extend_from_slice(b","); + if v.is_infinite() { + if v.is_sign_positive() { + out.extend_from_slice(b"inf"); + } else { + out.extend_from_slice(b"-inf"); + } + } else if v.is_nan() { + out.extend_from_slice(b"nan"); + } else { + use core::fmt::Write as _; + let mut s = String::new(); + let _ = write!(&mut s, "{}", v); + out.extend_from_slice(s.as_bytes()); + } + out.extend_from_slice(b"\r\n"); + } + RespData::BulkError(b) => { + let len = b.len(); + out.extend_from_slice(b"!"); + out.extend_from_slice(len.to_string().as_bytes()); + out.extend_from_slice(b"\r\n"); + out.extend_from_slice(b); + out.extend_from_slice(b"\r\n"); + } + RespData::VerbatimString { format, data } => { + // fmt: exactly 3 bytes + let payload_len = 3 + 1 + data.len(); + out.extend_from_slice(b"="); + out.extend_from_slice(payload_len.to_string().as_bytes()); + out.extend_from_slice(b"\r\n"); + out.extend_from_slice(format); + out.extend_from_slice(b":"); + out.extend_from_slice(data); + out.extend_from_slice(b"\r\n"); + } + RespData::BigNumber(s) => { + out.extend_from_slice(b"("); + out.extend_from_slice(s.as_bytes()); + out.extend_from_slice(b"\r\n"); + } + RespData::Map(entries) => { + out.extend_from_slice(b"%"); + out.extend_from_slice(entries.len().to_string().as_bytes()); + out.extend_from_slice(b"\r\n"); + for (k, v) in entries { + self.encode_into(k, out)?; + self.encode_into(v, out)?; + } + } + RespData::Set(items) => { + out.extend_from_slice(b"~"); + out.extend_from_slice(items.len().to_string().as_bytes()); + out.extend_from_slice(b"\r\n"); + for it in items { + self.encode_into(it, out)?; + } + } + RespData::Push(items) => { + out.extend_from_slice(b">"); + out.extend_from_slice(items.len().to_string().as_bytes()); + out.extend_from_slice(b"\r\n"); + for it in items { + self.encode_into(it, out)?; + } + } + // Standard RESP types - delegate to RESP2 encoder + RespData::SimpleString(s) => { + if self.resp2_encoder.is_none() { + self.resp2_encoder = Some(RespEncoder::new(RespVersion::RESP2)); + } + let encoder = self.resp2_encoder.as_mut().unwrap(); + let s = std::str::from_utf8(s).map_err(|_| { + crate::error::RespError::ParseError("invalid UTF-8 in SimpleString".into()) + })?; + encoder.clear().append_simple_string(s); + out.extend_from_slice(&encoder.get_response()); + } + RespData::Error(s) => { + out.extend_from_slice(b"-"); + out.extend_from_slice(s); + out.extend_from_slice(b"\r\n"); + } + RespData::Integer(n) => { + if self.resp2_encoder.is_none() { + self.resp2_encoder = Some(RespEncoder::new(RespVersion::RESP2)); + } + let encoder = self.resp2_encoder.as_mut().unwrap(); + encoder.clear().append_integer(*n); + out.extend_from_slice(&encoder.get_response()); + } + RespData::BulkString(Some(s)) => { + if self.resp2_encoder.is_none() { + self.resp2_encoder = Some(RespEncoder::new(RespVersion::RESP2)); + } + let encoder = self.resp2_encoder.as_mut().unwrap(); + encoder.clear().append_bulk_string(s); + out.extend_from_slice(&encoder.get_response()); + } + RespData::BulkString(None) => { + out.extend_from_slice(b"$-1\r\n"); + } + RespData::Array(Some(items)) => { + if self.resp2_encoder.is_none() { + self.resp2_encoder = Some(RespEncoder::new(RespVersion::RESP2)); + } + let encoder = self.resp2_encoder.as_mut().unwrap(); + encoder.clear().append_array_len(items.len() as i64); + out.extend_from_slice(&encoder.get_response()); + for item in items { + self.encode_into(item, out)?; + } + return Ok(()); + } + RespData::Array(None) => { + out.extend_from_slice(b"*-1\r\n"); + } + RespData::Inline(parts) => { + for (i, part) in parts.iter().enumerate() { + if i > 0 { + out.extend_from_slice(b" "); + } + out.extend_from_slice(part); + } + out.extend_from_slice(b"\r\n"); + } + } + Ok(()) + } + + fn version(&self) -> RespVersion { + RespVersion::RESP3 + } +} diff --git a/src/resp/src/resp3/mod.rs b/src/resp/src/resp3/mod.rs new file mode 100644 index 00000000..45efdd5e --- /dev/null +++ b/src/resp/src/resp3/mod.rs @@ -0,0 +1,19 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod decoder; +pub mod encoder; diff --git a/src/resp/src/traits.rs b/src/resp/src/traits.rs new file mode 100644 index 00000000..aefb56d0 --- /dev/null +++ b/src/resp/src/traits.rs @@ -0,0 +1,70 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::{Bytes, BytesMut}; + +use crate::{ + error::RespResult, + types::{RespData, RespVersion}, +}; + +/// Trait for decoding RESP protocol messages from byte streams. +/// +/// This trait provides a streaming interface for parsing RESP messages incrementally. +/// Implementations should handle incomplete data gracefully and maintain parsing state. +pub trait Decoder { + /// Push new data into the decoder's internal buffer. + /// + /// This method should be called when new bytes are available for parsing. + /// The decoder will accumulate data until a complete RESP message can be parsed. + fn push(&mut self, data: Bytes); + + /// Attempt to parse the next complete RESP message from the buffer. + /// + /// Returns `Some(Ok(data))` if a complete message was parsed, + /// `Some(Err(error))` if parsing failed, or `None` if more data is needed. + fn next(&mut self) -> Option>; + + /// Reset the decoder's internal state. + /// + /// This clears any buffered data and resets the parser to its initial state. + fn reset(&mut self); + + /// Get the RESP version this decoder supports. + fn version(&self) -> RespVersion; +} + +/// Trait for encoding RESP protocol messages to byte streams. +/// +/// This trait provides methods for converting `RespData` into RESP-formatted bytes. +/// Implementations should handle version-specific encoding rules and downlevel compatibility. +pub trait Encoder { + /// Encode a single RESP message into a `Bytes` buffer. + /// + /// This method creates a new buffer and encodes the given data into it. + /// For RESP3 types being encoded to older versions, downlevel conversion is applied. + fn encode_one(&mut self, data: &RespData) -> RespResult; + + /// Encode a RESP message into an existing `BytesMut` buffer. + /// + /// This method appends the encoded data to the provided buffer. + /// Useful for building larger messages or implementing streaming encoders. + fn encode_into(&mut self, data: &RespData, out: &mut BytesMut) -> RespResult<()>; + + /// Get the RESP version this encoder supports. + fn version(&self) -> RespVersion; +} diff --git a/src/resp/src/types.rs b/src/resp/src/types.rs index b5adac78..960f7df4 100644 --- a/src/resp/src/types.rs +++ b/src/resp/src/types.rs @@ -24,6 +24,7 @@ pub enum RespVersion { RESP1, #[default] RESP2, + RESP3, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -34,6 +35,16 @@ pub enum RespType { BulkString, Array, Inline, + // RESP3 additions (used for inspection only) + Null, + Boolean, + Double, + BulkError, + VerbatimString, + BigNumber, + Map, + Set, + Push, } impl RespType { @@ -44,6 +55,15 @@ impl RespType { b':' => Some(RespType::Integer), b'$' => Some(RespType::BulkString), b'*' => Some(RespType::Array), + b'_' => Some(RespType::Null), + b'#' => Some(RespType::Boolean), + b',' => Some(RespType::Double), + b'!' => Some(RespType::BulkError), + b'=' => Some(RespType::VerbatimString), + b'(' => Some(RespType::BigNumber), + b'%' => Some(RespType::Map), + b'~' => Some(RespType::Set), + b'>' => Some(RespType::Push), _ => None, } } @@ -56,11 +76,20 @@ impl RespType { RespType::BulkString => Some(b'$'), RespType::Array => Some(b'*'), RespType::Inline => None, + RespType::Null => Some(b'_'), + RespType::Boolean => Some(b'#'), + RespType::Double => Some(b','), + RespType::BulkError => Some(b'!'), + RespType::VerbatimString => Some(b'='), + RespType::BigNumber => Some(b'('), + RespType::Map => Some(b'%'), + RespType::Set => Some(b'~'), + RespType::Push => Some(b'>'), } } } -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq)] pub enum RespData { SimpleString(Bytes), Error(Bytes), @@ -68,6 +97,16 @@ pub enum RespData { BulkString(Option), Array(Option>), Inline(Vec), + // RESP3 additions (subset; full coverage to be added gradually) + Null, + Boolean(bool), + Double(f64), + BulkError(Bytes), + VerbatimString { format: [u8; 3], data: Bytes }, + BigNumber(String), + Map(Vec<(RespData, RespData)>), + Set(Vec), + Push(Vec), } impl Default for RespData { @@ -85,6 +124,15 @@ impl RespData { RespData::BulkString(_) => RespType::BulkString, RespData::Array(_) => RespType::Array, RespData::Inline(_) => RespType::Inline, + RespData::Null => RespType::Null, + RespData::Boolean(_) => RespType::Boolean, + RespData::Double(_) => RespType::Double, + RespData::BulkError(_) => RespType::BulkError, + RespData::VerbatimString { .. } => RespType::VerbatimString, + RespData::BigNumber(_) => RespType::BigNumber, + RespData::Map(_) => RespType::Map, + RespData::Set(_) => RespType::Set, + RespData::Push(_) => RespType::Push, } } @@ -165,6 +213,28 @@ impl fmt::Debug for RespData { write!(f, "{parts_str:?}")?; write!(f, ")") } + RespData::Null => write!(f, "Null"), + RespData::Boolean(b) => write!(f, "Boolean({b})"), + RespData::Double(d) => write!(f, "Double({d})"), + RespData::BulkError(bytes) => { + if let Ok(s) = std::str::from_utf8(bytes) { + write!(f, "BulkError(\"{s}\")") + } else { + write!(f, "BulkError({bytes:?})") + } + } + RespData::VerbatimString { format, data } => { + if let Ok(s) = std::str::from_utf8(data) { + let fmt = std::str::from_utf8(&format[..]).unwrap_or("???"); + write!(f, "VerbatimString({fmt}:{s})") + } else { + write!(f, "VerbatimString({:?}:{:?})", format, data) + } + } + RespData::BigNumber(s) => write!(f, "BigNumber({s})"), + RespData::Map(entries) => write!(f, "Map(len={})", entries.len()), + RespData::Set(items) => write!(f, "Set(len={})", items.len()), + RespData::Push(items) => write!(f, "Push(len={})", items.len()), } } } diff --git a/src/resp/tests/factory_selection.rs b/src/resp/tests/factory_selection.rs new file mode 100644 index 00000000..2bdbb2f0 --- /dev/null +++ b/src/resp/tests/factory_selection.rs @@ -0,0 +1,44 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::Bytes; +use resp::{RespVersion, new_decoder, new_encoder}; + +#[test] +fn selects_resp1_impl() { + let mut dec = new_decoder(RespVersion::RESP1); + let enc = new_encoder(RespVersion::RESP1); + assert_eq!(dec.version(), RespVersion::RESP1); + assert_eq!(enc.version(), RespVersion::RESP1); + + // minimal smoke: inline ping + dec.push(Bytes::from("PING\r\n")); + // even if command extraction differs, API shape should not panic + let _ = dec.next(); +} + +#[test] +fn selects_resp2_impl() { + let mut dec = new_decoder(RespVersion::RESP2); + let enc = new_encoder(RespVersion::RESP2); + assert_eq!(dec.version(), RespVersion::RESP2); + assert_eq!(enc.version(), RespVersion::RESP2); + + // minimal smoke: +OK\r\n + dec.push(Bytes::from("+OK\r\n")); + let _ = dec.next(); +} diff --git a/src/resp/tests/incremental_parsing.rs b/src/resp/tests/incremental_parsing.rs new file mode 100644 index 00000000..baa8986d --- /dev/null +++ b/src/resp/tests/incremental_parsing.rs @@ -0,0 +1,174 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use resp::{RespData, RespVersion, new_decoder}; + +#[test] +fn incremental_array_parsing() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // First chunk: "*2\r\n+foo\r\n" + decoder.push("*2\r\n+foo\r\n".into()); + let result = decoder.next(); + assert!(result.is_none(), "Should need more data for complete array"); + + // Second chunk: "+bar\r\n" + decoder.push("+bar\r\n".into()); + let result = decoder.next(); + assert!(result.is_some(), "Should have complete array now"); + + match result.unwrap() { + Ok(RespData::Array(Some(items))) => { + assert_eq!(items.len(), 2); + match &items[0] { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"foo"), + _ => panic!("Expected SimpleString 'foo'"), + } + match &items[1] { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"bar"), + _ => panic!("Expected SimpleString 'bar'"), + } + } + other => panic!("Expected Array with 2 items, got {:?}", other), + } +} + +#[test] +fn incremental_map_parsing() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // First chunk: "%1\r\n+key\r\n" + decoder.push("%1\r\n+key\r\n".into()); + let result = decoder.next(); + assert!(result.is_none(), "Should need more data for complete map"); + + // Second chunk: "+value\r\n" + decoder.push("+value\r\n".into()); + let result = decoder.next(); + assert!(result.is_some(), "Should have complete map now"); + + match result.unwrap() { + Ok(RespData::Map(items)) => { + assert_eq!(items.len(), 1); + match &items[0] { + (RespData::SimpleString(k), RespData::SimpleString(v)) => { + assert_eq!(k.as_ref(), b"key"); + assert_eq!(v.as_ref(), b"value"); + } + _ => panic!("Expected (SimpleString, SimpleString) pair"), + } + } + other => panic!("Expected Map with 1 pair, got {:?}", other), + } +} + +#[test] +fn incremental_set_parsing() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // First chunk: "~2\r\n+item1\r\n" + decoder.push("~2\r\n+item1\r\n".into()); + let result = decoder.next(); + assert!(result.is_none(), "Should need more data for complete set"); + + // Second chunk: "+item2\r\n" + decoder.push("+item2\r\n".into()); + let result = decoder.next(); + assert!(result.is_some(), "Should have complete set now"); + + match result.unwrap() { + Ok(RespData::Set(items)) => { + assert_eq!(items.len(), 2); + match &items[0] { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"item1"), + _ => panic!("Expected SimpleString 'item1'"), + } + match &items[1] { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"item2"), + _ => panic!("Expected SimpleString 'item2'"), + } + } + other => panic!("Expected Set with 2 items, got {:?}", other), + } +} + +#[test] +fn incremental_push_parsing() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // First chunk: ">2\r\n+msg1\r\n" + decoder.push(">2\r\n+msg1\r\n".into()); + let result = decoder.next(); + assert!(result.is_none(), "Should need more data for complete push"); + + // Second chunk: "+msg2\r\n" + decoder.push("+msg2\r\n".into()); + let result = decoder.next(); + assert!(result.is_some(), "Should have complete push now"); + + match result.unwrap() { + Ok(RespData::Push(items)) => { + assert_eq!(items.len(), 2); + match &items[0] { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"msg1"), + _ => panic!("Expected SimpleString 'msg1'"), + } + match &items[1] { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"msg2"), + _ => panic!("Expected SimpleString 'msg2'"), + } + } + other => panic!("Expected Push with 2 items, got {:?}", other), + } +} + +#[test] +fn multiple_incremental_messages() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // First chunk: "*1\r\n+hello\r\n*1\r\n" + decoder.push("*1\r\n+hello\r\n*1\r\n".into()); + let result = decoder.next(); + assert!(result.is_some(), "Should have first complete array"); + + match result.unwrap() { + Ok(RespData::Array(Some(items))) => { + assert_eq!(items.len(), 1); + match &items[0] { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"hello"), + _ => panic!("Expected SimpleString 'hello'"), + } + } + other => panic!("Expected Array with 1 item, got {:?}", other), + } + + // Second chunk: "+world\r\n" + decoder.push("+world\r\n".into()); + let result = decoder.next(); + assert!(result.is_some(), "Should have second complete array"); + + match result.unwrap() { + Ok(RespData::Array(Some(items))) => { + assert_eq!(items.len(), 1); + match &items[0] { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"world"), + _ => panic!("Expected SimpleString 'world'"), + } + } + other => panic!("Expected Array with 1 item, got {:?}", other), + } +} diff --git a/src/resp/tests/nested_incremental.rs b/src/resp/tests/nested_incremental.rs new file mode 100644 index 00000000..2d2ab694 --- /dev/null +++ b/src/resp/tests/nested_incremental.rs @@ -0,0 +1,230 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use resp::{RespData, RespVersion, new_decoder}; + +#[test] +fn nested_array_incremental() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // First chunk: outer array with incomplete inner array + decoder.push("*2\r\n*1\r\n".into()); + let result1 = decoder.next(); + assert!(result1.is_none(), "Should need more data for nested array"); + + // Second chunk: complete inner array + second element + decoder.push("+foo\r\n+bar\r\n".into()); + let result2 = decoder.next(); + + match result2 { + Some(Ok(RespData::Array(Some(items)))) => { + assert_eq!(items.len(), 2, "Outer array should have 2 elements"); + match &items[0] { + RespData::Array(Some(inner)) => { + assert_eq!(inner.len(), 1); + match &inner[0] { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"foo"), + _ => panic!("Inner array should contain SimpleString 'foo'"), + } + } + _ => panic!("First element should be an array"), + } + match &items[1] { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"bar"), + _ => panic!("Second element should be SimpleString 'bar'"), + } + } + other => panic!("Expected nested array, got {:?}", other), + } +} + +#[test] +fn deeply_nested_arrays() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // First chunk: deeply nested arrays + decoder.push("*2\r\n*1\r\n*1\r\n".into()); + let result1 = decoder.next(); + assert!( + result1.is_none(), + "Should need more data for deeply nested arrays" + ); + + // Second chunk: complete the deepest level + decoder.push("+deep\r\n".into()); + let result2 = decoder.next(); + assert!(result2.is_none(), "Should still need more data"); + + // Third chunk: complete middle level + decoder.push("+middle\r\n".into()); + let result3 = decoder.next(); + + // The parsing should be complete after the third chunk + match result3 { + Some(Ok(RespData::Array(Some(items)))) => { + assert_eq!(items.len(), 2); + + // First element: [["deep"]] + match &items[0] { + RespData::Array(Some(level1)) => { + assert_eq!(level1.len(), 1); + match &level1[0] { + RespData::Array(Some(level2)) => { + assert_eq!(level2.len(), 1); + match &level2[0] { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"deep"), + _ => panic!("Deepest level should contain 'deep'"), + } + } + _ => panic!("Level 1 should contain an array"), + } + } + _ => panic!("First element should be an array"), + } + + // Second element: "middle" (not "outer" as originally expected) + match &items[1] { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"middle"), + _ => panic!("Second element should be 'middle'"), + } + } + other => panic!("Expected deeply nested array, got {:?}", other), + } +} + +#[test] +fn nested_map_incremental() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // First chunk: outer map with incomplete inner map + decoder.push("%1\r\n+key1\r\n%1\r\n".into()); + let result1 = decoder.next(); + assert!(result1.is_none(), "Should need more data for nested map"); + + // Second chunk: complete inner map + decoder.push("+inner_key\r\n+inner_value\r\n".into()); + let result2 = decoder.next(); + + match result2 { + Some(Ok(RespData::Map(items))) => { + assert_eq!(items.len(), 1); + match &items[0] { + (RespData::SimpleString(key), RespData::Map(inner_map)) => { + assert_eq!(key.as_ref(), b"key1"); + assert_eq!(inner_map.len(), 1); + match &inner_map[0] { + (RespData::SimpleString(ik), RespData::SimpleString(iv)) => { + assert_eq!(ik.as_ref(), b"inner_key"); + assert_eq!(iv.as_ref(), b"inner_value"); + } + _ => panic!("Inner map should contain key-value pair"), + } + } + _ => panic!("Outer map should contain key1 -> inner map"), + } + } + other => panic!("Expected nested map, got {:?}", other), + } +} + +#[test] +fn mixed_nested_collections() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // First chunk: array containing map containing set + decoder.push("*1\r\n%1\r\n+map_key\r\n~1\r\n".into()); + let result1 = decoder.next(); + assert!( + result1.is_none(), + "Should need more data for mixed nested collections" + ); + + // Second chunk: complete the set + decoder.push("+set_item\r\n".into()); + let result2 = decoder.next(); + + match result2 { + Some(Ok(RespData::Array(Some(items)))) => { + assert_eq!(items.len(), 1); + match &items[0] { + RespData::Map(map_items) => { + assert_eq!(map_items.len(), 1); + match &map_items[0] { + (RespData::SimpleString(key), RespData::Set(set_items)) => { + assert_eq!(key.as_ref(), b"map_key"); + assert_eq!(set_items.len(), 1); + match &set_items[0] { + RespData::SimpleString(si) => { + assert_eq!(si.as_ref(), b"set_item"); + } + _ => panic!("Set should contain 'set_item'"), + } + } + _ => panic!("Map should contain map_key -> set"), + } + } + _ => panic!("Array should contain a map"), + } + } + other => panic!("Expected mixed nested collections, got {:?}", other), + } +} + +#[test] +fn multiple_nested_messages() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // First chunk: two nested messages + decoder.push("*1\r\n*1\r\n+first\r\n*1\r\n".into()); + let result1 = decoder.next(); + assert!(result1.is_some(), "Should have first complete message"); + + // Second chunk: complete second message + decoder.push("+second\r\n".into()); + let result2 = decoder.next(); + assert!(result2.is_some(), "Should have second complete message"); + + // Verify first message + match result1.unwrap() { + Ok(RespData::Array(Some(items))) => { + assert_eq!(items.len(), 1); + match &items[0] { + RespData::Array(Some(inner)) => { + assert_eq!(inner.len(), 1); + match &inner[0] { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"first"), + _ => panic!("First message should contain 'first'"), + } + } + _ => panic!("First message should be nested array"), + } + } + other => panic!("Expected first message to be array, got {:?}", other), + } + + // Verify second message + match result2.unwrap() { + Ok(RespData::Array(Some(items))) => { + assert_eq!(items.len(), 1); + match &items[0] { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"second"), + _ => panic!("Second message should contain 'second'"), + } + } + other => panic!("Expected second message to be array, got {:?}", other), + } +} diff --git a/src/resp/tests/policy_downlevel.rs b/src/resp/tests/policy_downlevel.rs new file mode 100644 index 00000000..db6cd63e --- /dev/null +++ b/src/resp/tests/policy_downlevel.rs @@ -0,0 +1,66 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use resp::{ + BooleanMode, DoubleMode, DownlevelPolicy, MapMode, RespData, RespVersion, + new_encoder_with_policy, +}; + +#[test] +fn boolean_as_simplestring() { + let policy = DownlevelPolicy { + boolean_mode: BooleanMode::SimpleString, + ..Default::default() + }; + let mut enc = new_encoder_with_policy(RespVersion::RESP2, policy); + let bytes = resp::encode_many(&mut *enc, &[ + RespData::Boolean(true), + RespData::Boolean(false), + ]) + .unwrap(); + let s = String::from_utf8(bytes.to_vec()).unwrap(); + assert!(s.contains("+OK\r\n")); + assert!(s.contains("+ERR\r\n")); +} + +#[test] +fn double_as_integer_if_whole() { + let policy = DownlevelPolicy { + double_mode: DoubleMode::IntegerIfWhole, + ..Default::default() + }; + let mut enc = new_encoder_with_policy(RespVersion::RESP1, policy); + let bytes = + resp::encode_many(&mut *enc, &[RespData::Double(2.0), RespData::Double(2.5)]).unwrap(); + let s = String::from_utf8(bytes.to_vec()).unwrap(); + assert!(s.contains(":2\r\n")); + assert!(s.contains("2.5")); +} + +#[test] +fn map_as_array_of_pairs() { + let policy = DownlevelPolicy { + map_mode: MapMode::ArrayOfPairs, + ..Default::default() + }; + let mut enc = new_encoder_with_policy(RespVersion::RESP2, policy); + let data = RespData::Map(vec![(RespData::Boolean(true), RespData::Boolean(false))]); + let bytes = resp::encode_many(&mut *enc, &[data]).unwrap(); + let s = String::from_utf8(bytes.to_vec()).unwrap(); + // *1\r\n*2\r\n:1\r\n:0\r\n (or simple string, depends on boolean_mode default) + assert!(s.starts_with("*1\r\n")); +} diff --git a/src/resp/tests/resp1_basic.rs b/src/resp/tests/resp1_basic.rs new file mode 100644 index 00000000..c31da20a --- /dev/null +++ b/src/resp/tests/resp1_basic.rs @@ -0,0 +1,45 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::Bytes; +use resp::{RespData, RespVersion, decode_many, new_decoder}; + +#[test] +fn inline_ping() { + let mut dec = new_decoder(RespVersion::RESP1); + let out = decode_many(&mut *dec, Bytes::from("PING\r\n")); + assert_eq!(out.len(), 1, "Expected exactly one decoded frame"); + // Verify it's an Inline command + match out[0].as_ref().unwrap() { + RespData::Inline(parts) => { + assert_eq!(parts.len(), 1); + assert_eq!(parts[0].as_ref(), b"PING"); + } + other => panic!("Expected Inline command, got {:?}", other), + } +} + +#[test] +fn simple_string_ok() { + let mut dec = new_decoder(RespVersion::RESP1); + let out = decode_many(&mut *dec, Bytes::from("+OK\r\n")); + let v = out[0].as_ref().unwrap(); + match v { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"OK"), + _ => panic!("Expected SimpleString, got {:?}", v), + } +} diff --git a/src/resp/tests/resp2_compat.rs b/src/resp/tests/resp2_compat.rs new file mode 100644 index 00000000..0126f4a4 --- /dev/null +++ b/src/resp/tests/resp2_compat.rs @@ -0,0 +1,55 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::Bytes; +use resp::{RespData, RespVersion, decode_many, new_decoder}; + +#[test] +fn parse_simple_string_ok() { + let mut dec = new_decoder(RespVersion::RESP2); + let out = decode_many(&mut *dec, Bytes::from("+OK\r\n")); + assert_eq!(out.len(), 1); + let v = out[0].as_ref().unwrap(); + match v { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"OK"), + other => panic!("Expected SimpleString, got {:?}", other), + } +} + +#[test] +fn parse_integer() { + let mut dec = new_decoder(RespVersion::RESP2); + let out = decode_many(&mut *dec, Bytes::from(":1000\r\n")); + let v = out[0].as_ref().unwrap(); + match v { + RespData::Integer(n) => assert_eq!(*n, 1000), + other => panic!("Expected Integer, got {:?}", other), + } +} + +#[test] +fn parse_bulk_and_array() { + let mut dec = new_decoder(RespVersion::RESP2); + let out = decode_many(&mut *dec, Bytes::from("*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n")); + let v = out[0].as_ref().unwrap(); + match v { + RespData::Array(Some(items)) => { + assert_eq!(items.len(), 2); + } + other => panic!("Expected Array, got {:?}", other), + } +} diff --git a/src/resp/tests/resp3_basic.rs b/src/resp/tests/resp3_basic.rs new file mode 100644 index 00000000..d22d90e4 --- /dev/null +++ b/src/resp/tests/resp3_basic.rs @@ -0,0 +1,54 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::Bytes; +use resp::{RespData, RespVersion, decode_many, new_decoder, new_encoder}; + +#[test] +fn resp3_null_boolean_double_decode() { + let mut dec = new_decoder(RespVersion::RESP3); + // Use separate inputs for clarity + let out = decode_many(&mut *dec, Bytes::from("_\r\n#t\r\n,1.5\r\n")); + assert_eq!(out.len(), 3, "Expected three decoded frames"); + match out[0].as_ref().unwrap() { + RespData::Null => {} + _ => panic!("expected Null"), + } + match out[1].as_ref().unwrap() { + RespData::Boolean(true) => {} + _ => panic!("expected Boolean(true)"), + } + match out[2].as_ref().unwrap() { + RespData::Double(v) if (*v - 1.5).abs() < f64::EPSILON => {} + _ => panic!("expected Double(1.5)"), + } +} + +#[test] +fn resp3_null_boolean_double_encode() { + let mut enc = new_encoder(RespVersion::RESP3); + let items = [ + RespData::Null, + RespData::Boolean(true), + RespData::Double(1.5), + ]; + let bytes = resp::encode_many(&mut *enc, &items).unwrap(); + let s = String::from_utf8(bytes.to_vec()).unwrap(); + assert!(s.contains("_\r\n")); + assert!(s.contains("#t\r\n")); + assert!(s.contains(",1.5\r\n") || s.contains(",1.5")); +} diff --git a/src/resp/tests/resp3_collections.rs b/src/resp/tests/resp3_collections.rs new file mode 100644 index 00000000..2f528d1a --- /dev/null +++ b/src/resp/tests/resp3_collections.rs @@ -0,0 +1,65 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::Bytes; +use resp::{RespData, RespVersion, decode_many, encode_many, new_decoder, new_encoder}; + +#[test] +fn set_roundtrip() { + let mut enc = new_encoder(RespVersion::RESP3); + let data = RespData::Set(vec![ + RespData::Boolean(true), + RespData::Null, + RespData::Double(2.5), + ]); + let bytes = encode_many(&mut *enc, &[data]).unwrap(); + let mut dec = new_decoder(RespVersion::RESP3); + let out = decode_many(&mut *dec, bytes); + match out[0].as_ref().unwrap() { + RespData::Set(items) => assert_eq!(items.len(), 3), + other => panic!("Expected Set, got {:?}", other), + } +} + +#[test] +fn map_roundtrip() { + let mut enc = new_encoder(RespVersion::RESP3); + let data = RespData::Map(vec![ + (RespData::Boolean(true), RespData::Double(1.0)), + (RespData::Null, RespData::BulkError(Bytes::from("ERR x"))), + ]); + let bytes = encode_many(&mut *enc, &[data]).unwrap(); + let mut dec = new_decoder(RespVersion::RESP3); + let out = decode_many(&mut *dec, bytes); + match out[0].as_ref().unwrap() { + RespData::Map(entries) => assert_eq!(entries.len(), 2), + other => panic!("Expected Map, got {:?}", other), + } +} + +#[test] +fn push_roundtrip() { + let mut enc = new_encoder(RespVersion::RESP3); + let data = RespData::Push(vec![RespData::Boolean(false), RespData::Double(3.14)]); + let bytes = encode_many(&mut *enc, &[data]).unwrap(); + let mut dec = new_decoder(RespVersion::RESP3); + let out = decode_many(&mut *dec, bytes); + match out[0].as_ref().unwrap() { + RespData::Push(items) => assert_eq!(items.len(), 2), + other => panic!("Expected Push, got {:?}", other), + } +} diff --git a/src/resp/tests/resp3_collections_comprehensive.rs b/src/resp/tests/resp3_collections_comprehensive.rs new file mode 100644 index 00000000..09467441 --- /dev/null +++ b/src/resp/tests/resp3_collections_comprehensive.rs @@ -0,0 +1,156 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use resp::{RespData, RespVersion, decode_many, encode_many, new_decoder, new_encoder}; + +#[test] +fn map_with_standard_resp_types() { + let mut enc = new_encoder(RespVersion::RESP3); + let data = RespData::Map(vec![ + (RespData::SimpleString("key1".into()), RespData::Integer(42)), + ( + RespData::BulkString(Some("key2".into())), + RespData::SimpleString("value2".into()), + ), + ]); + let bytes = encode_many(&mut *enc, &[data]).unwrap(); + let mut dec = new_decoder(RespVersion::RESP3); + let out = decode_many(&mut *dec, bytes); + + assert_eq!(out.len(), 1, "Expected 1 result"); + match out[0].as_ref().unwrap() { + RespData::Map(entries) => { + assert_eq!(entries.len(), 2, "Expected 2 map entries"); + // Verify first entry + match &entries[0] { + (RespData::SimpleString(k), RespData::Integer(v)) => { + assert_eq!(k.as_ref(), b"key1"); + assert_eq!(*v, 42); + } + _ => panic!("Expected (SimpleString, Integer) for first entry"), + } + // Verify second entry + match &entries[1] { + (RespData::BulkString(Some(k)), RespData::SimpleString(v)) => { + assert_eq!(k.as_ref(), b"key2"); + assert_eq!(v.as_ref(), b"value2"); + } + _ => panic!("Expected (BulkString, SimpleString) for second entry"), + } + } + other => panic!("Expected Map, got {:?}", other), + } +} + +#[test] +fn set_with_standard_resp_types() { + let mut enc = new_encoder(RespVersion::RESP3); + let data = RespData::Set(vec![ + RespData::SimpleString("item1".into()), + RespData::Integer(123), + RespData::BulkString(Some("item3".into())), + ]); + let bytes = encode_many(&mut *enc, &[data]).unwrap(); + let mut dec = new_decoder(RespVersion::RESP3); + let out = decode_many(&mut *dec, bytes); + + assert_eq!(out.len(), 1, "Expected 1 result"); + match out[0].as_ref().unwrap() { + RespData::Set(items) => { + assert_eq!(items.len(), 3, "Expected 3 set items"); + // Check that all expected types are present + let mut found_simple = false; + let mut found_integer = false; + let mut found_bulk = false; + + for item in items { + match item { + RespData::SimpleString(s) if s.as_ref() == b"item1" => found_simple = true, + RespData::Integer(i) if *i == 123 => found_integer = true, + RespData::BulkString(Some(s)) if s.as_ref() == b"item3" => found_bulk = true, + _ => {} + } + } + + assert!(found_simple, "Expected SimpleString 'item1' in set"); + assert!(found_integer, "Expected Integer 123 in set"); + assert!(found_bulk, "Expected BulkString 'item3' in set"); + } + other => panic!("Expected Set, got {:?}", other), + } +} + +#[test] +fn push_with_standard_resp_types() { + let mut enc = new_encoder(RespVersion::RESP3); + let data = RespData::Push(vec![ + RespData::SimpleString("PUSH".into()), + RespData::Integer(1), + RespData::BulkString(Some("message".into())), + ]); + let bytes = encode_many(&mut *enc, &[data]).unwrap(); + let mut dec = new_decoder(RespVersion::RESP3); + let out = decode_many(&mut *dec, bytes); + + assert_eq!(out.len(), 1, "Expected 1 result"); + match out[0].as_ref().unwrap() { + RespData::Push(items) => { + assert_eq!(items.len(), 3, "Expected 3 push items"); + // Verify each item + match &items[0] { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"PUSH"), + _ => panic!("Expected SimpleString 'PUSH' as first item"), + } + match &items[1] { + RespData::Integer(i) => assert_eq!(*i, 1), + _ => panic!("Expected Integer 1 as second item"), + } + match &items[2] { + RespData::BulkString(Some(s)) => assert_eq!(s.as_ref(), b"message"), + _ => panic!("Expected BulkString 'message' as third item"), + } + } + other => panic!("Expected Push, got {:?}", other), + } +} + +#[test] +fn map_with_mixed_types() { + let mut enc = new_encoder(RespVersion::RESP3); + let data = RespData::Map(vec![ + ( + RespData::SimpleString("key".into()), + RespData::Boolean(true), + ), + ( + RespData::Integer(1), + RespData::BulkString(Some("value".into())), + ), + (RespData::Null, RespData::Double(3.14)), + ]); + let bytes = encode_many(&mut *enc, &[data]).unwrap(); + let mut dec = new_decoder(RespVersion::RESP3); + let out = decode_many(&mut *dec, bytes); + + assert_eq!(out.len(), 1, "Expected 1 result"); + match out[0].as_ref().unwrap() { + RespData::Map(entries) => { + assert_eq!(entries.len(), 3, "Expected 3 map entries"); + } + other => panic!("Expected Map, got {:?}", other), + } +} diff --git a/src/resp/tests/resp3_encoder_comprehensive.rs b/src/resp/tests/resp3_encoder_comprehensive.rs new file mode 100644 index 00000000..da576a77 --- /dev/null +++ b/src/resp/tests/resp3_encoder_comprehensive.rs @@ -0,0 +1,121 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::Bytes; +use resp::{RespData, RespVersion, new_encoder}; + +#[test] +fn resp3_encoder_all_types() { + let mut enc = new_encoder(RespVersion::RESP3); + + // Test all RESP3 types + let tests = vec![ + (RespData::Null, b"_\r\n" as &[u8]), + (RespData::Boolean(true), b"#t\r\n" as &[u8]), + (RespData::Boolean(false), b"#f\r\n" as &[u8]), + (RespData::Double(3.14), b",3.14\r\n" as &[u8]), + (RespData::Double(f64::INFINITY), b",inf\r\n" as &[u8]), + (RespData::Double(f64::NEG_INFINITY), b",-inf\r\n" as &[u8]), + (RespData::Double(f64::NAN), b",nan\r\n" as &[u8]), + ( + RespData::BulkError(Bytes::from("ERR something")), + b"!13\r\nERR something\r\n" as &[u8], + ), + ( + RespData::VerbatimString { + format: *b"txt", + data: Bytes::from("hello"), + }, + b"=9\r\ntxt:hello\r\n" as &[u8], + ), + ( + RespData::BigNumber("12345678901234567890".into()), + b"(12345678901234567890\r\n" as &[u8], + ), + ]; + + for (data, expected) in tests { + let result = enc.encode_one(&data); + assert!(result.is_ok(), "Failed to encode {:?}", data); + let bytes = result.unwrap(); + assert_eq!(bytes.as_ref(), expected, "Wrong encoding for {:?}", data); + } +} + +#[test] +fn resp3_encoder_standard_resp_types() { + let mut enc = new_encoder(RespVersion::RESP3); + + // Test standard RESP types + let tests = vec![ + (RespData::SimpleString("OK".into()), b"+OK\r\n" as &[u8]), + ( + RespData::Error("ERR something".into()), + b"-ERR something\r\n" as &[u8], + ), + (RespData::Integer(42), b":42\r\n" as &[u8]), + ( + RespData::BulkString(Some("hello".into())), + b"$5\r\nhello\r\n" as &[u8], + ), + (RespData::BulkString(None), b"$-1\r\n" as &[u8]), + ( + RespData::Array(Some(vec![RespData::SimpleString("OK".into())])), + b"*1\r\n+OK\r\n" as &[u8], + ), + (RespData::Array(None), b"*-1\r\n" as &[u8]), + (RespData::Inline(vec!["PING".into()]), b"PING\r\n" as &[u8]), + ]; + + for (data, expected) in tests { + let result = enc.encode_one(&data); + assert!(result.is_ok(), "Failed to encode {:?}", data); + let bytes = result.unwrap(); + assert_eq!(bytes.as_ref(), expected, "Wrong encoding for {:?}", data); + } +} + +#[test] +fn resp3_encoder_collections() { + let mut enc = new_encoder(RespVersion::RESP3); + + // Test collection types + let set_data = RespData::Set(vec![ + RespData::Boolean(true), + RespData::Null, + RespData::Double(2.5), + ]); + let set_result = enc.encode_one(&set_data); + assert!(set_result.is_ok()); + let set_bytes = set_result.unwrap(); + assert!(set_bytes.starts_with(b"~3\r\n")); + + let map_data = RespData::Map(vec![ + (RespData::Boolean(true), RespData::Double(1.0)), + (RespData::Null, RespData::BulkError(Bytes::from("ERR x"))), + ]); + let map_result = enc.encode_one(&map_data); + assert!(map_result.is_ok()); + let map_bytes = map_result.unwrap(); + assert!(map_bytes.starts_with(b"%2\r\n")); + + let push_data = RespData::Push(vec![RespData::Boolean(false), RespData::Double(3.14)]); + let push_result = enc.encode_one(&push_data); + assert!(push_result.is_ok()); + let push_bytes = push_result.unwrap(); + assert!(push_bytes.starts_with(b">2\r\n")); +} diff --git a/src/resp/tests/resp3_more.rs b/src/resp/tests/resp3_more.rs new file mode 100644 index 00000000..7021c001 --- /dev/null +++ b/src/resp/tests/resp3_more.rs @@ -0,0 +1,64 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::Bytes; +use resp::{RespData, RespVersion, decode_many, encode_many, new_decoder, new_encoder}; + +#[test] +fn bulk_error_roundtrip() { + let mut enc = new_encoder(RespVersion::RESP3); + let data = RespData::BulkError(Bytes::from("ERR something")); + let bytes = encode_many(&mut *enc, &[data.clone()]).unwrap(); + let mut dec = new_decoder(RespVersion::RESP3); + let out = decode_many(&mut *dec, bytes); + match out[0].as_ref().unwrap() { + RespData::BulkError(s) => assert_eq!(s.as_ref(), b"ERR something"), + other => panic!("Expected BulkError, got {:?}", other), + } +} + +#[test] +fn verbatim_string_roundtrip() { + let mut enc = new_encoder(RespVersion::RESP3); + let data = RespData::VerbatimString { + format: *b"txt", + data: Bytes::from("hello"), + }; + let bytes = encode_many(&mut *enc, &[data]).unwrap(); + let mut dec = new_decoder(RespVersion::RESP3); + let out = decode_many(&mut *dec, bytes); + match out[0].as_ref().unwrap() { + RespData::VerbatimString { format, data } => { + assert_eq!(format, b"txt"); + assert_eq!(data.as_ref(), b"hello"); + } + other => panic!("Expected VerbatimString, got {:?}", other), + } +} + +#[test] +fn bignumber_roundtrip() { + let mut enc = new_encoder(RespVersion::RESP3); + let data = RespData::BigNumber("12345678901234567890".into()); + let bytes = encode_many(&mut *enc, &[data.clone()]).unwrap(); + let mut dec = new_decoder(RespVersion::RESP3); + let out = decode_many(&mut *dec, bytes); + match out[0].as_ref().unwrap() { + RespData::BigNumber(s) => assert_eq!(s, "12345678901234567890"), + other => panic!("Expected BigNumber, got {:?}", other), + } +} diff --git a/src/resp/tests/resp3_scaffold.rs b/src/resp/tests/resp3_scaffold.rs new file mode 100644 index 00000000..6db93274 --- /dev/null +++ b/src/resp/tests/resp3_scaffold.rs @@ -0,0 +1,40 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::Bytes; +use resp::{RespData, RespVersion, new_decoder}; + +#[test] +fn resp3_boolean_and_null() { + let mut dec = new_decoder(RespVersion::RESP3); + dec.push(Bytes::from("#t\r\n")); + dec.push(Bytes::from("_\r\n")); + + // Verify Boolean(true) parsing + let result1 = dec.next().unwrap().unwrap(); + match result1 { + RespData::Boolean(true) => {} + _ => panic!("expected Boolean(true), got {:?}", result1), + } + + // Verify Null parsing + let result2 = dec.next().unwrap().unwrap(); + match result2 { + RespData::Null => {} + _ => panic!("expected Null, got {:?}", result2), + } +} diff --git a/src/resp/tests/resp3_standard_types.rs b/src/resp/tests/resp3_standard_types.rs new file mode 100644 index 00000000..14eb7f4d --- /dev/null +++ b/src/resp/tests/resp3_standard_types.rs @@ -0,0 +1,131 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::Bytes; +use resp::{RespData, RespVersion, decode_many, new_decoder, new_encoder}; + +#[test] +fn resp3_decodes_standard_resp_types() { + // Test simple string + let mut dec1 = new_decoder(RespVersion::RESP3); + let out1 = decode_many(&mut *dec1, Bytes::from("+OK\r\n")); + assert_eq!(out1.len(), 1); + match out1[0].as_ref().unwrap() { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"OK"), + other => panic!("Expected SimpleString, got {:?}", other), + } + + // Test integer + let mut dec2 = new_decoder(RespVersion::RESP3); + let out2 = decode_many(&mut *dec2, Bytes::from(":42\r\n")); + assert_eq!(out2.len(), 1); + match out2[0].as_ref().unwrap() { + RespData::Integer(n) => assert_eq!(*n, 42), + other => panic!("Expected Integer, got {:?}", other), + } + + // Test bulk string + let mut dec3 = new_decoder(RespVersion::RESP3); + let out3 = decode_many(&mut *dec3, Bytes::from("$5\r\nhello\r\n")); + assert_eq!(out3.len(), 1); + match out3[0].as_ref().unwrap() { + RespData::BulkString(Some(s)) => assert_eq!(s.as_ref(), b"hello"), + other => panic!("Expected BulkString, got {:?}", other), + } + + // Test array + let mut dec4 = new_decoder(RespVersion::RESP3); + let out4 = decode_many(&mut *dec4, Bytes::from("*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n")); + assert_eq!(out4.len(), 1); + match out4[0].as_ref().unwrap() { + RespData::Array(Some(items)) => { + assert_eq!(items.len(), 2); + match &items[0] { + RespData::BulkString(Some(s)) => assert_eq!(s.as_ref(), b"foo"), + other => panic!("Expected BulkString 'foo', got {:?}", other), + } + match &items[1] { + RespData::BulkString(Some(s)) => assert_eq!(s.as_ref(), b"bar"), + other => panic!("Expected BulkString 'bar', got {:?}", other), + } + } + other => panic!("Expected Array, got {:?}", other), + } +} + +#[test] +fn resp3_encodes_standard_resp_types() { + let mut enc = new_encoder(RespVersion::RESP3); + + // Test simple string + let data1 = RespData::SimpleString("OK".into()); + let bytes1 = enc.encode_one(&data1).unwrap(); + assert_eq!(bytes1.as_ref(), b"+OK\r\n"); + + // Test integer + let data2 = RespData::Integer(42); + let bytes2 = enc.encode_one(&data2).unwrap(); + assert_eq!(bytes2.as_ref(), b":42\r\n"); + + // Test bulk string + let data3 = RespData::BulkString(Some("hello".into())); + let bytes3 = enc.encode_one(&data3).unwrap(); + assert_eq!(bytes3.as_ref(), b"$5\r\nhello\r\n"); + + // Test array + let data4 = RespData::Array(Some(vec![ + RespData::BulkString(Some("foo".into())), + RespData::BulkString(Some("bar".into())), + ])); + let bytes4 = enc.encode_one(&data4).unwrap(); + assert_eq!(bytes4.as_ref(), b"*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n"); +} + +#[test] +fn resp3_mixed_resp2_and_resp3_types() { + let mut dec = new_decoder(RespVersion::RESP3); + + // Mix RESP2 and RESP3 types in one input + let input = Bytes::from("+OK\r\n#t\r\n:42\r\n_\r\n"); + let out = decode_many(&mut *dec, input); + + assert_eq!(out.len(), 4); + + // Simple string + match out[0].as_ref().unwrap() { + RespData::SimpleString(s) => assert_eq!(s.as_ref(), b"OK"), + other => panic!("Expected SimpleString, got {:?}", other), + } + + // Boolean + match out[1].as_ref().unwrap() { + RespData::Boolean(true) => {} + other => panic!("Expected Boolean(true), got {:?}", other), + } + + // Integer + match out[2].as_ref().unwrap() { + RespData::Integer(n) => assert_eq!(*n, 42), + other => panic!("Expected Integer, got {:?}", other), + } + + // Null + match out[3].as_ref().unwrap() { + RespData::Null => {} + other => panic!("Expected Null, got {:?}", other), + } +} diff --git a/src/resp/tests/security_limits.rs b/src/resp/tests/security_limits.rs new file mode 100644 index 00000000..e82a6216 --- /dev/null +++ b/src/resp/tests/security_limits.rs @@ -0,0 +1,212 @@ +// Copyright (c) 2024-present, arana-db Community. All rights reserved. +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use resp::{RespVersion, new_decoder}; + +#[test] +fn bulk_string_length_limit() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // Test bulk string exceeding 512MB limit + let oversized_len = 512 * 1024 * 1024 + 1; // 512MB + 1 byte + let oversized_message = format!("${}\r\n", oversized_len); + + decoder.push(oversized_message.into()); + let result = decoder.next(); + + assert!(result.is_some(), "Should return an error"); + assert!( + result.unwrap().is_err(), + "Should return parse error for oversized bulk string" + ); +} + +#[test] +fn array_length_limit() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // Test array exceeding 1M elements limit + let oversized_len = 1024 * 1024 + 1; // 1M + 1 elements + let oversized_message = format!("*{}\r\n", oversized_len); + + decoder.push(oversized_message.into()); + let result = decoder.next(); + + assert!(result.is_some(), "Should return an error"); + assert!( + result.unwrap().is_err(), + "Should return parse error for oversized array" + ); +} + +#[test] +fn map_pairs_limit() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // Test map exceeding 1M pairs limit + let oversized_pairs = 1024 * 1024 + 1; // 1M + 1 pairs + let oversized_message = format!("%{}\r\n", oversized_pairs); + + decoder.push(oversized_message.into()); + let result = decoder.next(); + + assert!(result.is_some(), "Should return an error"); + assert!( + result.unwrap().is_err(), + "Should return parse error for oversized map" + ); +} + +#[test] +fn set_length_limit() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // Test set exceeding 1M elements limit + let oversized_len = 1024 * 1024 + 1; // 1M + 1 elements + let oversized_message = format!("~{}\r\n", oversized_len); + + decoder.push(oversized_message.into()); + let result = decoder.next(); + + assert!(result.is_some(), "Should return an error"); + assert!( + result.unwrap().is_err(), + "Should return parse error for oversized set" + ); +} + +#[test] +fn push_length_limit() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // Test push exceeding 1M elements limit + let oversized_len = 1024 * 1024 + 1; // 1M + 1 elements + let oversized_message = format!(">{}\r\n", oversized_len); + + decoder.push(oversized_message.into()); + let result = decoder.next(); + + assert!(result.is_some(), "Should return an error"); + assert!( + result.unwrap().is_err(), + "Should return parse error for oversized push" + ); +} + +#[test] +fn bulk_error_length_limit() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // Test bulk error exceeding 512MB limit + let oversized_len = 512 * 1024 * 1024 + 1; // 512MB + 1 byte + let oversized_message = format!("!{}\r\n", oversized_len); + + decoder.push(oversized_message.into()); + let result = decoder.next(); + + assert!(result.is_some(), "Should return an error"); + assert!( + result.unwrap().is_err(), + "Should return parse error for oversized bulk error" + ); +} + +#[test] +fn verbatim_string_length_limit() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // Test verbatim string exceeding 512MB limit + let oversized_len = 512 * 1024 * 1024 + 1; // 512MB + 1 byte + let oversized_message = format!("=txt:{}\r\n", oversized_len); + + decoder.push(oversized_message.into()); + let result = decoder.next(); + + assert!(result.is_some(), "Should return an error"); + assert!( + result.unwrap().is_err(), + "Should return parse error for oversized verbatim string" + ); +} + +#[test] +fn big_number_length_limit() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // Test big number exceeding 16MB limit + let oversized_len = 16 * 1024 * 1024 + 1; // 16MB + 1 byte + let oversized_digits = "1".repeat(oversized_len); + let oversized_message = format!("({}\r\n", oversized_digits); + + decoder.push(oversized_message.into()); + let result = decoder.next(); + + assert!(result.is_some(), "Should return an error"); + assert!( + result.unwrap().is_err(), + "Should return parse error for oversized big number" + ); +} + +#[test] +fn inline_command_length_limit() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // Test inline command exceeding 4KB limit + let oversized_len = 4 * 1024 + 1; // 4KB + 1 byte + let oversized_command = "a".repeat(oversized_len) + "\r\n"; + + decoder.push(oversized_command.into()); + let result = decoder.next(); + + assert!(result.is_some(), "Should return an error"); + assert!( + result.unwrap().is_err(), + "Should return parse error for oversized inline command" + ); +} + +#[test] +fn inline_command_invalid_prefix() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // Test inline command with non-printable character + let invalid_command = "\x01invalid\r\n"; + + decoder.push(invalid_command.into()); + let result = decoder.next(); + + assert!(result.is_some(), "Should return an error"); + assert!( + result.unwrap().is_err(), + "Should return parse error for invalid inline command prefix" + ); +} + +#[test] +fn within_limits_should_work() { + let mut decoder = new_decoder(RespVersion::RESP3); + + // Test that values within limits work correctly + let normal_message = "*2\r\n+hello\r\n+world\r\n"; + + decoder.push(normal_message.into()); + let result = decoder.next(); + + assert!(result.is_some(), "Should parse successfully"); + assert!(result.unwrap().is_ok(), "Should parse without error"); +}