From 9846c69f91e732b493dcccf92f4c163bfc711b7d Mon Sep 17 00:00:00 2001 From: boxdot Date: Thu, 13 Feb 2025 11:05:18 +0100 Subject: [PATCH 1/5] tls_codec: add variable-length integer type TlsVarInt As defined in #[rfc9000]. Also use this type (with an internal thin wrapper `ContentLength`) when encoding/deconding the content length of vectors. [rfc9000]: https://www.rfc-editor.org/rfc/rfc9000#name-variable-length-integer-enc --- tls_codec/README.md | 2 +- tls_codec/src/lib.rs | 3 + tls_codec/src/quic_vec.rs | 247 ++++++++++------------------------ tls_codec/src/varint.rs | 274 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 352 insertions(+), 174 deletions(-) create mode 100644 tls_codec/src/varint.rs diff --git a/tls_codec/README.md b/tls_codec/README.md index be0b94f89..d8af87e0f 100644 --- a/tls_codec/README.md +++ b/tls_codec/README.md @@ -19,7 +19,7 @@ derived. The crate also provides the following data structures that implement TLS serialization/deserialization -- `u8`, `u16`, `u32`, `u64` +- `u8`, `u16`, `u32`, `u64`, `TlsVarInt` - `TlsVecU8`, `TlsVecU16`, `TlsVecU32` - `SecretTlsVecU8`, `SecretTlsVecU16`, `SecretTlsVecU32` The same as the `TlsVec*` versions but it implements zeroize, requiring diff --git a/tls_codec/src/lib.rs b/tls_codec/src/lib.rs index e662691d7..51016812f 100644 --- a/tls_codec/src/lib.rs +++ b/tls_codec/src/lib.rs @@ -39,6 +39,7 @@ mod arrays; mod primitives; mod quic_vec; mod tls_vec; +mod varint; pub use tls_vec::{ SecretTlsVecU16, SecretTlsVecU24, SecretTlsVecU32, SecretTlsVecU8, TlsByteSliceU16, @@ -59,6 +60,8 @@ pub use tls_codec_derive::{ #[cfg(feature = "conditional_deserialization")] pub use tls_codec_derive::conditionally_deserializable; +pub use varint::TlsVarInt; + /// Errors that are thrown by this crate. #[derive(Debug, Eq, PartialEq, Clone)] pub enum Error { diff --git a/tls_codec/src/quic_vec.rs b/tls_codec/src/quic_vec.rs index eb77fb479..14e26f3f4 100644 --- a/tls_codec/src/quic_vec.rs +++ b/tls_codec/src/quic_vec.rs @@ -24,115 +24,47 @@ use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize}; use crate::{DeserializeBytes, Error, SerializeBytes, Size}; -#[cfg(not(feature = "mls"))] -const MAX_LEN: u64 = (1 << 62) - 1; -#[cfg(not(feature = "mls"))] -const MAX_LEN_LEN_LOG: usize = 3; #[cfg(feature = "mls")] -const MAX_LEN: u64 = (1 << 30) - 1; -#[cfg(feature = "mls")] -const MAX_LEN_LEN_LOG: usize = 2; - -#[inline(always)] -fn check_min_length(length: usize, len_len: usize) -> Result<(), Error> { - if cfg!(feature = "mls") { - // ensure that len_len is minimal for the given length - let min_len_len = length_encoding_bytes(length as u64)?; - if min_len_len != len_len { - return Err(Error::InvalidVectorLength); - } - }; - Ok(()) -} +const MAX_MLS_LEN: u64 = (1 << 30) - 1; -#[inline(always)] -fn calculate_length(len_len_byte: u8) -> Result<(usize, usize), Error> { - let length: usize = (len_len_byte & 0x3F).into(); - let len_len_log = (len_len_byte >> 6).into(); - if !cfg!(fuzzing) { - debug_assert!(len_len_log <= MAX_LEN_LEN_LOG); - } - if len_len_log > MAX_LEN_LEN_LOG { - return Err(Error::InvalidVectorLength); - } - let len_len = match len_len_log { - 0 => 1, - 1 => 2, - 2 => 4, - 3 => 8, - _ => unreachable!(), - }; - Ok((length, len_len)) -} - -#[inline(always)] -fn read_variable_length_bytes(bytes: &[u8]) -> Result<((usize, usize), &[u8]), Error> { - // The length is encoded in the first two bits of the first byte. +/// Thin wrapper around [`TlsVarInt`] representing the length of encoded vector content in bytes. +/// +/// When `mls` feature is enabled, the maximum length is limited to 30-bit. Otherwise, this type is +/// no-op. +struct ContentLength(super::TlsVarInt); - let (len_len_byte, mut remainder) = u8::tls_deserialize_bytes(bytes)?; +impl ContentLength { + #[cfg(not(feature = "mls"))] + #[allow(dead_code)] // used in arbitrary + const MAX: u64 = crate::TlsVarInt::MAX; - let (mut length, len_len) = calculate_length(len_len_byte)?; + #[cfg(feature = "mls")] + const MAX: u64 = MAX_MLS_LEN; - for _ in 1..len_len { - let (next, next_remainder) = u8::tls_deserialize_bytes(remainder)?; - remainder = next_remainder; - length = (length << 8) + usize::from(next); + fn new(value: super::TlsVarInt) -> Result { + #[cfg(feature = "mls")] + if Self::MAX < value.value() { + return Err(Error::InvalidVectorLength); + } + Ok(Self(value)) } - check_min_length(length, len_len)?; - - Ok(((length, len_len), remainder)) + fn from_usize(value: usize) -> Result { + Self::new(super::TlsVarInt::try_new(value.try_into()?)?) + } } -#[inline(always)] -fn length_encoding_bytes(length: u64) -> Result { - if !cfg!(fuzzing) { - debug_assert!(length <= MAX_LEN); - } - if length > MAX_LEN { - return Err(Error::InvalidVectorLength); +impl Size for ContentLength { + fn tls_serialized_len(&self) -> usize { + self.0.tls_serialized_len() } - - Ok(if length <= 0x3f { - 1 - } else if length <= 0x3fff { - 2 - } else if length <= 0x3fff_ffff { - 4 - } else { - 8 - }) } -#[inline(always)] -pub fn write_variable_length(content_length: usize) -> Result, Error> { - let len_len = length_encoding_bytes(content_length.try_into()?)?; - if !cfg!(fuzzing) { - debug_assert!(len_len <= 8, "Invalid vector len_len {len_len}"); - } - if len_len > 8 { - return Err(Error::LibraryError); - } - let mut length_bytes = vec![0u8; len_len]; - match len_len { - 1 => length_bytes[0] = 0x00, - 2 => length_bytes[0] = 0x40, - 4 => length_bytes[0] = 0x80, - 8 => length_bytes[0] = 0xc0, - _ => { - if !cfg!(fuzzing) { - debug_assert!(false, "Invalid vector len_len {len_len}"); - } - return Err(Error::InvalidVectorLength); - } - } - let mut len = content_length; - for b in length_bytes.iter_mut().rev() { - *b |= (len & 0xFF) as u8; - len >>= 8; +impl DeserializeBytes for ContentLength { + fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> { + let (value, remainder) = super::TlsVarInt::tls_deserialize_bytes(bytes)?; + Ok((Self(value), remainder)) } - - Ok(length_bytes) } impl Size for Vec { @@ -152,7 +84,9 @@ impl Size for &Vec { impl DeserializeBytes for Vec { #[inline(always)] fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> { - let ((length, len_len), mut remainder) = read_variable_length_bytes(bytes)?; + let (length, mut remainder) = ContentLength::tls_deserialize_bytes(bytes)?; + let len_len = length.0.bytes_len(); + let length: usize = length.0.value().try_into()?; if length == 0 { // An empty vector. @@ -178,11 +112,12 @@ impl SerializeBytes for &[T] { // This requires more computations but the other option would be to buffer // the entire content, which can end up requiring a lot of memory. let content_length = self.iter().fold(0, |acc, e| acc + e.tls_serialized_len()); - let mut length = write_variable_length(content_length)?; - let len_len = length.len(); + let length = ContentLength::from_usize(content_length)?; + let len_len = length.0.bytes_len(); let mut out = Vec::with_capacity(content_length + len_len); - out.append(&mut length); + out.resize(len_len, 0); + length.0.write_bytes(&mut out)?; // Serialize the elements for e in self.iter() { @@ -214,11 +149,13 @@ impl Size for &[T] { #[inline(always)] fn tls_serialized_len(&self) -> usize { let content_length = self.iter().fold(0, |acc, e| acc + e.tls_serialized_len()); - let len_len = length_encoding_bytes(content_length as u64).unwrap_or({ - // We can't do anything about the error unless we change the trait. - // Let's say there's no content for now. - 0 - }); + let len_len = ContentLength::from_usize(content_length) + .map(|content_length| content_length.0.bytes_len()) + .unwrap_or({ + // We can't do anything about the error unless we change the trait. + // Let's say there's no content for now. + 0 + }); content_length + len_len } } @@ -327,10 +264,12 @@ impl From for Vec { #[inline(always)] fn tls_serialize_bytes_len(bytes: &[u8]) -> usize { let content_length = bytes.len(); - let len_len = length_encoding_bytes(content_length as u64).unwrap_or({ - // We can't do anything about the error. Let's say there's no content. - 0 - }); + let len_len = ContentLength::from_usize(content_length) + .map(|content_length| content_length.0.bytes_len()) + .unwrap_or({ + // We can't do anything about the error. Let's say there's no content. + 0 + }); content_length + len_len } @@ -344,22 +283,13 @@ impl Size for VLBytes { impl DeserializeBytes for VLBytes { #[inline(always)] fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> { - let ((length, _), remainder) = read_variable_length_bytes(bytes)?; + let (length, remainder) = ContentLength::tls_deserialize_bytes(bytes)?; + let length: usize = length.0.value().try_into()?; + if length == 0 { return Ok((Self::new(vec![]), remainder)); } - if !cfg!(fuzzing) { - debug_assert!( - length <= MAX_LEN as usize, - "Trying to allocate {length} bytes. Only {MAX_LEN} allowed.", - ); - } - if length > MAX_LEN as usize { - return Err(Error::DecodingError(format!( - "Trying to allocate {length} bytes. Only {MAX_LEN} allowed.", - ))); - } match remainder.get(..length).ok_or(Error::EndOfStream) { Ok(vec) => Ok((Self { vec: vec.to_vec() }, &remainder[length..])), Err(_e) => { @@ -422,6 +352,19 @@ pub mod rw { use super::*; use crate::{Deserialize, Serialize}; + impl Deserialize for ContentLength { + fn tls_deserialize(bytes: &mut R) -> Result { + ContentLength::new(crate::TlsVarInt::tls_deserialize(bytes)?) + } + } + + impl Serialize for ContentLength { + #[inline(always)] + fn tls_serialize(&self, writer: &mut W) -> Result { + self.0.tls_serialize(writer) + } + } + /// Read the length of a variable-length vector. /// /// This function assumes that the reader is at the start of a variable length @@ -430,26 +373,9 @@ pub mod rw { /// The length and number of bytes read are returned. #[inline] pub fn read_length(bytes: &mut R) -> Result<(usize, usize), Error> { - // The length is encoded in the first two bits of the first byte. - let mut len_len_byte = [0u8; 1]; - if bytes.read(&mut len_len_byte)? == 0 { - // There must be at least one byte for the length. - // If we don't even have a length byte, this is not a valid - // variable-length encoded vector. - return Err(Error::InvalidVectorLength); - } - let len_len_byte = len_len_byte[0]; - - let (mut length, len_len) = calculate_length(len_len_byte)?; - - for _ in 1..len_len { - let mut next = [0u8; 1]; - bytes.read_exact(&mut next)?; - length = (length << 8) + usize::from(next[0]); - } - - check_min_length(length, len_len)?; - + let length = ContentLength::tls_deserialize(bytes)?; + let len_len = length.0.bytes_len(); + let length: usize = length.0.value().try_into()?; Ok((length, len_len)) } @@ -479,10 +405,7 @@ pub mod rw { writer: &mut W, content_length: usize, ) -> Result { - let buf = super::write_variable_length(content_length)?; - let buf_len = buf.len(); - writer.write_all(&buf)?; - Ok(buf_len) + ContentLength::from_usize(content_length)?.tls_serialize(writer) } impl Serialize for Vec { @@ -538,19 +461,7 @@ mod rw_bytes { // large and write it out. let content_length = bytes.len(); - if !cfg!(fuzzing) { - debug_assert!( - content_length as u64 <= MAX_LEN, - "Vector can't be encoded. It's too large. {content_length} >= {MAX_LEN}", - ); - } - if content_length as u64 > MAX_LEN { - return Err(Error::InvalidVectorLength); - } - - let length_bytes = write_variable_length(content_length)?; - let len_len = length_bytes.len(); - writer.write_all(&length_bytes)?; + let len_len = ContentLength::from_usize(content_length)?.tls_serialize(writer)?; // Now serialize the elements writer.write_all(bytes)?; @@ -574,24 +485,14 @@ mod rw_bytes { impl Deserialize for VLBytes { fn tls_deserialize(bytes: &mut R) -> Result { - let (length, _) = rw::read_length(bytes)?; - if length == 0 { + let length = ContentLength::tls_deserialize(bytes)?; + + if length.0.value() == 0 { return Ok(Self::new(vec![])); } - if !cfg!(fuzzing) { - debug_assert!( - length <= MAX_LEN as usize, - "Trying to allocate {length} bytes. Only {MAX_LEN} allowed.", - ); - } - if length > MAX_LEN as usize { - return Err(Error::DecodingError(format!( - "Trying to allocate {length} bytes. Only {MAX_LEN} allowed.", - ))); - } let mut result = Self { - vec: vec![0u8; length], + vec: vec![0u8; length.0.value().try_into()?], }; bytes.read_exact(result.vec.as_mut_slice())?; Ok(result) @@ -682,7 +583,7 @@ impl<'a> Arbitrary<'a> for VLBytes { // We generate an arbitrary `Vec` ... let mut vec = Vec::arbitrary(u)?; // ... and truncate it to `MAX_LEN`. - vec.truncate(MAX_LEN as usize); + vec.truncate(ContentLength::MAX as usize); // We probably won't exceed `MAX_LEN` in practice, e.g., during fuzzing, // but better make sure that we generate valid instances. diff --git a/tls_codec/src/varint.rs b/tls_codec/src/varint.rs new file mode 100644 index 000000000..573dfbcda --- /dev/null +++ b/tls_codec/src/varint.rs @@ -0,0 +1,274 @@ +use crate::{Deserialize, DeserializeBytes, Error, Serialize, Size}; + +/// Variable-length encoded unsigned integer as defined in [RFC 9000]. +/// +/// [RFC 9000]: https://www.rfc-editor.org/rfc/rfc9000#name-variable-length-integer-enc +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct TlsVarInt(u64); + +impl TlsVarInt { + /// The largest value that can be represented by this type. + pub const MAX: u64 = (1 << 62) - 1; + const MAX_LOG: usize = 3; + + /// Wraps an unsinged integer as variable-length int. + /// + /// Returns `None` if the value is larger than [`Self::MAX`]. + pub const fn new(value: u64) -> Option { + if Self::MAX < value { + None + } else { + Some(Self(value)) + } + } + + pub(crate) fn try_new(value: u64) -> Result { + Self::new(value).ok_or(Error::InvalidVectorLength) + } + + /// Returns the value of this variable-length int. + pub const fn value(&self) -> u64 { + self.0 + } + + /// Returns the number of bytes required to encode this variable-length int. + pub(crate) const fn bytes_len(&self) -> usize { + let value = self.0; + if !cfg!(fuzzing) { + debug_assert!(value <= Self::MAX); + } + if value <= 0x3f { + 1 + } else if value <= 0x3fff { + 2 + } else if value <= 0x3fff_ffff { + 4 + } else { + 8 + } + } + + /// Writes the bytes of this variable-length at the beginning of the buffer. + /// + /// The buffer must be at least of the length returned by [`Self::bytes_len`]. + pub(crate) fn write_bytes(&self, buf: &mut [u8]) -> Result { + let len = self.bytes_len(); + if !cfg!(fuzzing) { + debug_assert!(len <= 8, "Invalid varint len {len}"); + } + if len > 8 { + return Err(Error::LibraryError); + } + + if buf.len() < len { + return Err(Error::InvalidVectorLength); + } + let bytes = &mut buf[..len]; + + match len { + 1 => bytes[0] = 0x00, + 2 => bytes[0] = 0x40, + 4 => bytes[0] = 0x80, + 8 => bytes[0] = 0xc0, + _ => { + if !cfg!(fuzzing) { + debug_assert!(false, "Invalid varint len {len}"); + } + return Err(Error::InvalidVectorLength); + } + } + let mut value = self.0; + for b in bytes.iter_mut().rev() { + *b |= (value & 0xFF) as u8; + value >>= 8; + } + + Ok(len) + } +} + +impl TryFrom for TlsVarInt { + type Error = Error; + + #[inline] + fn try_from(value: u64) -> Result { + Self::try_new(value) + } +} + +impl From for u64 { + #[inline] + fn from(value: TlsVarInt) -> Self { + value.0 + } +} + +#[inline(always)] +fn check_min_len(value: u64, len: usize) -> Result<(), Error> { + if cfg!(feature = "mls") { + // ensure that `len` is minimal for the given `value` + let min_len = TlsVarInt::try_new(value)?.bytes_len(); + if min_len != len { + return Err(Error::InvalidVectorLength); + } + }; + Ok(()) +} + +impl Deserialize for TlsVarInt { + #[cfg(feature = "std")] + #[inline] + fn tls_deserialize(bytes: &mut R) -> Result { + let mut len_byte = [0u8; 1]; + if bytes.read(&mut len_byte)? == 0 { + return Err(Error::EndOfStream); + }; + let len_byte = len_byte[0]; + + let (value, len) = calculate_value(len_byte)?; + let mut value: u64 = value.try_into().map_err(|_| Error::InvalidInput)?; + + for _ in 1..len { + let mut next = [0u8; 1]; + bytes.read_exact(&mut next)?; + value = (value << 8) + u64::from(next[0]); + } + + check_min_len(value, len)?; + + Ok(TlsVarInt(value)) + } +} + +impl DeserializeBytes for TlsVarInt { + #[cfg(feature = "std")] + #[inline] + fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> + where + Self: Sized, + { + let (len_byte, mut remainder) = u8::tls_deserialize_bytes(bytes)?; + + let (value, value_len) = calculate_value(len_byte)?; + let mut value: u64 = value.try_into().map_err(|_| Error::InvalidInput)?; + + for _ in 1..value_len { + let (next, next_remainder) = u8::tls_deserialize_bytes(remainder)?; + remainder = next_remainder; + value = (value << 8) + u64::from(next); + } + + Ok((TlsVarInt(value), remainder)) + } +} + +impl Serialize for TlsVarInt { + #[cfg(feature = "std")] + #[inline] + fn tls_serialize(&self, writer: &mut W) -> Result { + let mut bytes = [0u8; 8]; + let len = self.write_bytes(&mut bytes)?; + writer.write_all(&bytes[..len])?; + Ok(len) + } +} + +impl Size for TlsVarInt { + #[inline] + fn tls_serialized_len(&self) -> usize { + self.bytes_len() + } +} + +/// Calculates the value and the length from the first byte. +#[inline(always)] +pub(crate) fn calculate_value(byte: u8) -> Result<(usize, usize), Error> { + let value: usize = (byte & 0x3F).into(); + let len_log = (byte >> 6).into(); + if !cfg!(fuzzing) { + debug_assert!(len_log <= TlsVarInt::MAX_LOG); + } + if len_log > TlsVarInt::MAX_LOG { + return Err(Error::InvalidVectorLength); + } + let len = match len_log { + 0 => 1, + 1 => 2, + 2 => 4, + 3 => 8, + _ => unreachable!(), + }; + Ok((value, len)) +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + use std::vec::Vec; + + use super::*; + + // (value, var length, encoded bytes) + const TESTS: [(u64, usize, &[u8]); 5] = [ + (37, 1, &[0x25]), + (15_293, 2, &[0x7b, 0xbd]), + (494_878_333, 4, &[0x9d, 0x7f, 0x3e, 0x7d]), + ( + 151_288_809_941_952_652, + 8, + &[0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c], + ), + ( + TlsVarInt::MAX, + 8, + &[0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff], + ), + ]; + + #[test] + fn test_size() { + for (value, len, _) in TESTS { + assert_eq!( + TlsVarInt::try_from(value) + .expect("value too large") + .tls_serialized_len(), + len + ); + } + } + + #[test] + fn test_tls_serde() { + for (value, len, bytes) in TESTS { + let mut buf = Vec::new(); + let written = TlsVarInt::try_from(value) + .expect("value too large") + .tls_serialize(&mut buf) + .expect("tls serialize failed"); + assert_eq!(written, len, "{value}"); + assert_eq!(buf.len(), len, "{value}"); + assert_eq!(&buf[..], bytes, "{value}"); + + let out = TlsVarInt::tls_deserialize(&mut Cursor::new(bytes)) + .expect("tls deserialize failed"); + assert_eq!(out, TlsVarInt::try_from(value).expect("value too large")); + + let (out, remainder) = + TlsVarInt::tls_deserialize_bytes(bytes).expect("tls deserialize bytes failed"); + assert_eq!(remainder.len(), 0); + assert_eq!(out, TlsVarInt::try_from(value).expect("value too large")); + } + } + + #[test] + fn test_non_canonical_encoding() { + let out = TlsVarInt::tls_deserialize(&mut Cursor::new(&[0x40, 0x25])) + .expect("tls deserialize failed"); + assert_eq!(out, TlsVarInt(37)); + + let (out, remaining) = + TlsVarInt::tls_deserialize_bytes(&[0x40, 0x25]).expect("tls deserialize bytes failed"); + assert_eq!(remaining.len(), 0); + assert_eq!(out, TlsVarInt(37)); + } +} From 17b19b68a1cb21cee8782468735c925f084d5f8b Mon Sep 17 00:00:00 2001 From: boxdot Date: Thu, 13 Feb 2025 11:49:39 +0100 Subject: [PATCH 2/5] remove std feature from DeserializeBytes impl for TlsVarInt --- tls_codec/src/varint.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tls_codec/src/varint.rs b/tls_codec/src/varint.rs index 573dfbcda..a300468c5 100644 --- a/tls_codec/src/varint.rs +++ b/tls_codec/src/varint.rs @@ -141,7 +141,6 @@ impl Deserialize for TlsVarInt { } impl DeserializeBytes for TlsVarInt { - #[cfg(feature = "std")] #[inline] fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> where @@ -149,15 +148,17 @@ impl DeserializeBytes for TlsVarInt { { let (len_byte, mut remainder) = u8::tls_deserialize_bytes(bytes)?; - let (value, value_len) = calculate_value(len_byte)?; + let (value, len) = calculate_value(len_byte)?; let mut value: u64 = value.try_into().map_err(|_| Error::InvalidInput)?; - for _ in 1..value_len { + for _ in 1..len { let (next, next_remainder) = u8::tls_deserialize_bytes(remainder)?; remainder = next_remainder; value = (value << 8) + u64::from(next); } + check_min_len(value, len)?; + Ok((TlsVarInt(value), remainder)) } } From 8ff32d6a17963c72d948a5c6bdb7d009cef940f8 Mon Sep 17 00:00:00 2001 From: boxdot Date: Thu, 13 Feb 2025 11:59:36 +0100 Subject: [PATCH 3/5] split tests in std and non-std parts --- tls_codec/src/varint.rs | 46 ++++++++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/tls_codec/src/varint.rs b/tls_codec/src/varint.rs index a300468c5..f012c0e03 100644 --- a/tls_codec/src/varint.rs +++ b/tls_codec/src/varint.rs @@ -204,8 +204,6 @@ pub(crate) fn calculate_value(byte: u8) -> Result<(usize, usize), Error> { #[cfg(test)] mod tests { - use std::io::Cursor; - use std::vec::Vec; use super::*; @@ -227,7 +225,7 @@ mod tests { ]; #[test] - fn test_size() { + fn tls_serialized_len() { for (value, len, _) in TESTS { assert_eq!( TlsVarInt::try_from(value) @@ -238,8 +236,11 @@ mod tests { } } + #[cfg(feature = "std")] #[test] - fn test_tls_serde() { + fn tls_serialize() { + use crate::alloc::vec::Vec; + for (value, len, bytes) in TESTS { let mut buf = Vec::new(); let written = TlsVarInt::try_from(value) @@ -249,11 +250,13 @@ mod tests { assert_eq!(written, len, "{value}"); assert_eq!(buf.len(), len, "{value}"); assert_eq!(&buf[..], bytes, "{value}"); + } + } - let out = TlsVarInt::tls_deserialize(&mut Cursor::new(bytes)) - .expect("tls deserialize failed"); - assert_eq!(out, TlsVarInt::try_from(value).expect("value too large")); - + #[test] + fn tls_deserialize_bytes() { + for (value, len, bytes) in TESTS { + assert_eq!(len, bytes.len()); let (out, remainder) = TlsVarInt::tls_deserialize_bytes(bytes).expect("tls deserialize bytes failed"); assert_eq!(remainder.len(), 0); @@ -261,15 +264,34 @@ mod tests { } } + #[cfg(feature = "std")] #[test] - fn test_non_canonical_encoding() { - let out = TlsVarInt::tls_deserialize(&mut Cursor::new(&[0x40, 0x25])) - .expect("tls deserialize failed"); - assert_eq!(out, TlsVarInt(37)); + fn tls_deserialize() { + use std::io::Cursor; + + for (value, len, bytes) in TESTS { + assert_eq!(len, bytes.len()); + let out = TlsVarInt::tls_deserialize(&mut Cursor::new(bytes)) + .expect("tls deserialize failed"); + assert_eq!(out, TlsVarInt::try_from(value).expect("value too large")); + } + } + #[test] + fn non_canonical_encoding_deserialize_bytes() { let (out, remaining) = TlsVarInt::tls_deserialize_bytes(&[0x40, 0x25]).expect("tls deserialize bytes failed"); assert_eq!(remaining.len(), 0); assert_eq!(out, TlsVarInt(37)); } + + #[cfg(feature = "std")] + #[test] + fn non_canonical_encoding_deserialize() { + use std::io::Cursor; + + let out = TlsVarInt::tls_deserialize(&mut Cursor::new(&[0x40, 0x25])) + .expect("tls deserialize failed"); + assert_eq!(out, TlsVarInt(37)); + } } From 2e7e024e30f3c9f4b15f1d84ddd1152126c4b752 Mon Sep 17 00:00:00 2001 From: boxdot Date: Thu, 13 Feb 2025 12:07:53 +0100 Subject: [PATCH 4/5] non minimum-size encoding should panic when MLS flag is enabled --- tls_codec/src/varint.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tls_codec/src/varint.rs b/tls_codec/src/varint.rs index f012c0e03..efa5adbd8 100644 --- a/tls_codec/src/varint.rs +++ b/tls_codec/src/varint.rs @@ -278,6 +278,9 @@ mod tests { } #[test] + // Note: MLS requires minimum-size encoding + // + #[cfg_attr(feature = "mls", should_panic)] fn non_canonical_encoding_deserialize_bytes() { let (out, remaining) = TlsVarInt::tls_deserialize_bytes(&[0x40, 0x25]).expect("tls deserialize bytes failed"); @@ -287,6 +290,9 @@ mod tests { #[cfg(feature = "std")] #[test] + // Note: MLS requires minimum-size encoding + // + #[cfg_attr(feature = "mls", should_panic)] fn non_canonical_encoding_deserialize() { use std::io::Cursor; From 999625176e598543770177fecff6b2b5976ec06b Mon Sep 17 00:00:00 2001 From: boxdot Date: Thu, 13 Feb 2025 12:14:54 +0100 Subject: [PATCH 5/5] better test name --- tls_codec/src/varint.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tls_codec/src/varint.rs b/tls_codec/src/varint.rs index efa5adbd8..a1e844195 100644 --- a/tls_codec/src/varint.rs +++ b/tls_codec/src/varint.rs @@ -281,7 +281,7 @@ mod tests { // Note: MLS requires minimum-size encoding // #[cfg_attr(feature = "mls", should_panic)] - fn non_canonical_encoding_deserialize_bytes() { + fn non_minimum_size_deserialize_bytes() { let (out, remaining) = TlsVarInt::tls_deserialize_bytes(&[0x40, 0x25]).expect("tls deserialize bytes failed"); assert_eq!(remaining.len(), 0); @@ -293,7 +293,7 @@ mod tests { // Note: MLS requires minimum-size encoding // #[cfg_attr(feature = "mls", should_panic)] - fn non_canonical_encoding_deserialize() { + fn non_minimum_size_tls_deserialize() { use std::io::Cursor; let out = TlsVarInt::tls_deserialize(&mut Cursor::new(&[0x40, 0x25]))