diff --git a/tls_codec/README.md b/tls_codec/README.md index be0b94f89..d8af87e0f 100644 --- a/tls_codec/README.md +++ b/tls_codec/README.md @@ -19,7 +19,7 @@ derived. The crate also provides the following data structures that implement TLS serialization/deserialization -- `u8`, `u16`, `u32`, `u64` +- `u8`, `u16`, `u32`, `u64`, `TlsVarInt` - `TlsVecU8`, `TlsVecU16`, `TlsVecU32` - `SecretTlsVecU8`, `SecretTlsVecU16`, `SecretTlsVecU32` The same as the `TlsVec*` versions but it implements zeroize, requiring diff --git a/tls_codec/src/lib.rs b/tls_codec/src/lib.rs index e662691d7..51016812f 100644 --- a/tls_codec/src/lib.rs +++ b/tls_codec/src/lib.rs @@ -39,6 +39,7 @@ mod arrays; mod primitives; mod quic_vec; mod tls_vec; +mod varint; pub use tls_vec::{ SecretTlsVecU16, SecretTlsVecU24, SecretTlsVecU32, SecretTlsVecU8, TlsByteSliceU16, @@ -59,6 +60,8 @@ pub use tls_codec_derive::{ #[cfg(feature = "conditional_deserialization")] pub use tls_codec_derive::conditionally_deserializable; +pub use varint::TlsVarInt; + /// Errors that are thrown by this crate. 
#[derive(Debug, Eq, PartialEq, Clone)] pub enum Error { diff --git a/tls_codec/src/quic_vec.rs b/tls_codec/src/quic_vec.rs index eb77fb479..14e26f3f4 100644 --- a/tls_codec/src/quic_vec.rs +++ b/tls_codec/src/quic_vec.rs @@ -24,115 +24,47 @@ use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize}; use crate::{DeserializeBytes, Error, SerializeBytes, Size}; -#[cfg(not(feature = "mls"))] -const MAX_LEN: u64 = (1 << 62) - 1; -#[cfg(not(feature = "mls"))] -const MAX_LEN_LEN_LOG: usize = 3; #[cfg(feature = "mls")] -const MAX_LEN: u64 = (1 << 30) - 1; -#[cfg(feature = "mls")] -const MAX_LEN_LEN_LOG: usize = 2; - -#[inline(always)] -fn check_min_length(length: usize, len_len: usize) -> Result<(), Error> { - if cfg!(feature = "mls") { - // ensure that len_len is minimal for the given length - let min_len_len = length_encoding_bytes(length as u64)?; - if min_len_len != len_len { - return Err(Error::InvalidVectorLength); - } - }; - Ok(()) -} +const MAX_MLS_LEN: u64 = (1 << 30) - 1; -#[inline(always)] -fn calculate_length(len_len_byte: u8) -> Result<(usize, usize), Error> { - let length: usize = (len_len_byte & 0x3F).into(); - let len_len_log = (len_len_byte >> 6).into(); - if !cfg!(fuzzing) { - debug_assert!(len_len_log <= MAX_LEN_LEN_LOG); - } - if len_len_log > MAX_LEN_LEN_LOG { - return Err(Error::InvalidVectorLength); - } - let len_len = match len_len_log { - 0 => 1, - 1 => 2, - 2 => 4, - 3 => 8, - _ => unreachable!(), - }; - Ok((length, len_len)) -} - -#[inline(always)] -fn read_variable_length_bytes(bytes: &[u8]) -> Result<((usize, usize), &[u8]), Error> { - // The length is encoded in the first two bits of the first byte. +/// Thin wrapper around [`TlsVarInt`] representing the length of encoded vector content in bytes. +/// +/// When `mls` feature is enabled, the maximum length is limited to 30-bit. Otherwise, this type is +/// no-op. 
+struct ContentLength(super::TlsVarInt); - let (len_len_byte, mut remainder) = u8::tls_deserialize_bytes(bytes)?; +impl ContentLength { + #[cfg(not(feature = "mls"))] + #[allow(dead_code)] // used in arbitrary + const MAX: u64 = crate::TlsVarInt::MAX; - let (mut length, len_len) = calculate_length(len_len_byte)?; + #[cfg(feature = "mls")] + const MAX: u64 = MAX_MLS_LEN; - for _ in 1..len_len { - let (next, next_remainder) = u8::tls_deserialize_bytes(remainder)?; - remainder = next_remainder; - length = (length << 8) + usize::from(next); + fn new(value: super::TlsVarInt) -> Result<Self, Error> { + #[cfg(feature = "mls")] + if Self::MAX < value.value() { + return Err(Error::InvalidVectorLength); + } + Ok(Self(value)) } - check_min_length(length, len_len)?; - - Ok(((length, len_len), remainder)) + fn from_usize(value: usize) -> Result<Self, Error> { + Self::new(super::TlsVarInt::try_new(value.try_into()?)?) + } } -#[inline(always)] -fn length_encoding_bytes(length: u64) -> Result<usize, Error> { - if !cfg!(fuzzing) { - debug_assert!(length <= MAX_LEN); - } - if length > MAX_LEN { - return Err(Error::InvalidVectorLength); +impl Size for ContentLength { + fn tls_serialized_len(&self) -> usize { + self.0.tls_serialized_len() } - - Ok(if length <= 0x3f { - 1 - } else if length <= 0x3fff { - 2 - } else if length <= 0x3fff_ffff { - 4 - } else { - 8 - }) } -#[inline(always)] -pub fn write_variable_length(content_length: usize) -> Result<Vec<u8>, Error> { - let len_len = length_encoding_bytes(content_length.try_into()?)?; - if !cfg!(fuzzing) { - debug_assert!(len_len <= 8, "Invalid vector len_len {len_len}"); - } - if len_len > 8 { - return Err(Error::LibraryError); - } - let mut length_bytes = vec![0u8; len_len]; - match len_len { - 1 => length_bytes[0] = 0x00, - 2 => length_bytes[0] = 0x40, - 4 => length_bytes[0] = 0x80, - 8 => length_bytes[0] = 0xc0, - _ => { - if !cfg!(fuzzing) { - debug_assert!(false, "Invalid vector len_len {len_len}"); - } - return 
Err(Error::InvalidVectorLength); - } - } - let mut len = content_length; - for b in length_bytes.iter_mut().rev() { - *b |= (len & 0xFF) as u8; - len >>= 8; +impl DeserializeBytes for ContentLength { + fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> { + let (value, remainder) = super::TlsVarInt::tls_deserialize_bytes(bytes)?; + Ok((Self(value), remainder)) } - - Ok(length_bytes) } impl<T: Size> Size for Vec<T> { @@ -152,7 +84,9 @@ impl<T: Size> Size for &Vec<T> { impl<T: DeserializeBytes> DeserializeBytes for Vec<T> { #[inline(always)] fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> { - let ((length, len_len), mut remainder) = read_variable_length_bytes(bytes)?; + let (length, mut remainder) = ContentLength::tls_deserialize_bytes(bytes)?; + let len_len = length.0.bytes_len(); + let length: usize = length.0.value().try_into()?; if length == 0 { // An empty vector. @@ -178,11 +112,12 @@ impl<T: SerializeBytes> SerializeBytes for &[T] { // This requires more computations but the other option would be to buffer // the entire content, which can end up requiring a lot of memory. let content_length = self.iter().fold(0, |acc, e| acc + e.tls_serialized_len()); - let mut length = write_variable_length(content_length)?; - let len_len = length.len(); + let length = ContentLength::from_usize(content_length)?; + let len_len = length.0.bytes_len(); let mut out = Vec::with_capacity(content_length + len_len); - out.append(&mut length); + out.resize(len_len, 0); + length.0.write_bytes(&mut out)?; // Serialize the elements for e in self.iter() { @@ -214,11 +149,13 @@ impl<T: Size> Size for &[T] { #[inline(always)] fn tls_serialized_len(&self) -> usize { let content_length = self.iter().fold(0, |acc, e| acc + e.tls_serialized_len()); - let len_len = length_encoding_bytes(content_length as u64).unwrap_or({ - // We can't do anything about the error unless we change the trait. - // Let's say there's no content for now. 
- 0 - }); + let len_len = ContentLength::from_usize(content_length) + .map(|content_length| content_length.0.bytes_len()) + .unwrap_or({ + // We can't do anything about the error unless we change the trait. + // Let's say there's no content for now. + 0 + }); content_length + len_len } } @@ -327,10 +264,12 @@ impl From<VLBytes> for Vec<u8> { #[inline(always)] fn tls_serialize_bytes_len(bytes: &[u8]) -> usize { let content_length = bytes.len(); - let len_len = length_encoding_bytes(content_length as u64).unwrap_or({ - // We can't do anything about the error. Let's say there's no content. - 0 - }); + let len_len = ContentLength::from_usize(content_length) + .map(|content_length| content_length.0.bytes_len()) + .unwrap_or({ + // We can't do anything about the error. Let's say there's no content. + 0 + }); content_length + len_len } @@ -344,22 +283,13 @@ impl Size for VLBytes { impl DeserializeBytes for VLBytes { #[inline(always)] fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> { - let ((length, _), remainder) = read_variable_length_bytes(bytes)?; + let (length, remainder) = ContentLength::tls_deserialize_bytes(bytes)?; + let length: usize = length.0.value().try_into()?; + if length == 0 { return Ok((Self::new(vec![]), remainder)); } - if !cfg!(fuzzing) { - debug_assert!( - length <= MAX_LEN as usize, - "Trying to allocate {length} bytes. Only {MAX_LEN} allowed.", - ); - } - if length > MAX_LEN as usize { - return Err(Error::DecodingError(format!( - "Trying to allocate {length} bytes. Only {MAX_LEN} allowed.", - ))); - } match remainder.get(..length).ok_or(Error::EndOfStream) { Ok(vec) => Ok((Self { vec: vec.to_vec() }, &remainder[length..])), Err(_e) => { @@ -422,6 +352,19 @@ pub mod rw { use super::*; use crate::{Deserialize, Serialize}; + impl Deserialize for ContentLength { + fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, Error> { + ContentLength::new(crate::TlsVarInt::tls_deserialize(bytes)?) 
+ } + } + + impl Serialize for ContentLength { + #[inline(always)] + fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> { + self.0.tls_serialize(writer) + } + } + /// Read the length of a variable-length vector. /// /// This function assumes that the reader is at the start of a variable length @@ -430,26 +373,9 @@ pub mod rw { /// The length and number of bytes read are returned. #[inline] pub fn read_length<R: std::io::Read>(bytes: &mut R) -> Result<(usize, usize), Error> { - // The length is encoded in the first two bits of the first byte. - let mut len_len_byte = [0u8; 1]; - if bytes.read(&mut len_len_byte)? == 0 { - // There must be at least one byte for the length. - // If we don't even have a length byte, this is not a valid - // variable-length encoded vector. - return Err(Error::InvalidVectorLength); - } - let len_len_byte = len_len_byte[0]; - - let (mut length, len_len) = calculate_length(len_len_byte)?; - - for _ in 1..len_len { - let mut next = [0u8; 1]; - bytes.read_exact(&mut next)?; - length = (length << 8) + usize::from(next[0]); - } - - check_min_length(length, len_len)?; - + let length = ContentLength::tls_deserialize(bytes)?; + let len_len = length.0.bytes_len(); + let length: usize = length.0.value().try_into()?; Ok((length, len_len)) } @@ -479,10 +405,7 @@ pub mod rw { writer: &mut W, content_length: usize, ) -> Result<usize, Error> { - let buf = super::write_variable_length(content_length)?; - let buf_len = buf.len(); - writer.write_all(&buf)?; - Ok(buf_len) + ContentLength::from_usize(content_length)?.tls_serialize(writer) } impl<T: Serialize + std::fmt::Debug> Serialize for Vec<T> { @@ -538,19 +461,7 @@ mod rw_bytes { // large and write it out. let content_length = bytes.len(); - if !cfg!(fuzzing) { - debug_assert!( - content_length as u64 <= MAX_LEN, - "Vector can't be encoded. It's too large. 
{content_length} >= {MAX_LEN}", - ); - } - if content_length as u64 > MAX_LEN { - return Err(Error::InvalidVectorLength); - } - - let length_bytes = write_variable_length(content_length)?; - let len_len = length_bytes.len(); - writer.write_all(&length_bytes)?; + let len_len = ContentLength::from_usize(content_length)?.tls_serialize(writer)?; // Now serialize the elements writer.write_all(bytes)?; @@ -574,24 +485,14 @@ mod rw_bytes { impl Deserialize for VLBytes { fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, Error> { - let (length, _) = rw::read_length(bytes)?; - if length == 0 { + let length = ContentLength::tls_deserialize(bytes)?; + + if length.0.value() == 0 { return Ok(Self::new(vec![])); } - if !cfg!(fuzzing) { - debug_assert!( - length <= MAX_LEN as usize, - "Trying to allocate {length} bytes. Only {MAX_LEN} allowed.", - ); - } - if length > MAX_LEN as usize { - return Err(Error::DecodingError(format!( - "Trying to allocate {length} bytes. Only {MAX_LEN} allowed.", - ))); - } let mut result = Self { - vec: vec![0u8; length], + vec: vec![0u8; length.0.value().try_into()?], }; bytes.read_exact(result.vec.as_mut_slice())?; Ok(result) @@ -682,7 +583,7 @@ impl<'a> Arbitrary<'a> for VLBytes { // We generate an arbitrary `Vec<u8>` ... let mut vec = Vec::arbitrary(u)?; // ... and truncate it to `MAX_LEN`. - vec.truncate(MAX_LEN as usize); + vec.truncate(ContentLength::MAX as usize); // We probably won't exceed `MAX_LEN` in practice, e.g., during fuzzing, // but better make sure that we generate valid instances. diff --git a/tls_codec/src/varint.rs b/tls_codec/src/varint.rs new file mode 100644 index 000000000..a1e844195 --- /dev/null +++ b/tls_codec/src/varint.rs @@ -0,0 +1,303 @@ +use crate::{Deserialize, DeserializeBytes, Error, Serialize, Size}; + +/// Variable-length encoded unsigned integer as defined in [RFC 9000]. 
+/// +/// [RFC 9000]: https://www.rfc-editor.org/rfc/rfc9000#name-variable-length-integer-enc +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct TlsVarInt(u64); + +impl TlsVarInt { + /// The largest value that can be represented by this type. + pub const MAX: u64 = (1 << 62) - 1; + const MAX_LOG: usize = 3; + + /// Wraps an unsigned integer as a variable-length int. + /// + /// Returns `None` if the value is larger than [`Self::MAX`]. + pub const fn new(value: u64) -> Option<Self> { + if Self::MAX < value { + None + } else { + Some(Self(value)) + } + } + + pub(crate) fn try_new(value: u64) -> Result<Self, Error> { + Self::new(value).ok_or(Error::InvalidVectorLength) + } + + /// Returns the value of this variable-length int. + pub const fn value(&self) -> u64 { + self.0 + } + + /// Returns the number of bytes required to encode this variable-length int. + pub(crate) const fn bytes_len(&self) -> usize { + let value = self.0; + if !cfg!(fuzzing) { + debug_assert!(value <= Self::MAX); + } + if value <= 0x3f { + 1 + } else if value <= 0x3fff { + 2 + } else if value <= 0x3fff_ffff { + 4 + } else { + 8 + } + } + + /// Writes the bytes of this variable-length int at the beginning of the buffer. + /// + /// The buffer must be at least as long as the length returned by [`Self::bytes_len`]. 
+ pub(crate) fn write_bytes(&self, buf: &mut [u8]) -> Result<usize, Error> { + let len = self.bytes_len(); + if !cfg!(fuzzing) { + debug_assert!(len <= 8, "Invalid varint len {len}"); + } + if len > 8 { + return Err(Error::LibraryError); + } + + if buf.len() < len { + return Err(Error::InvalidVectorLength); + } + let bytes = &mut buf[..len]; + + match len { + 1 => bytes[0] = 0x00, + 2 => bytes[0] = 0x40, + 4 => bytes[0] = 0x80, + 8 => bytes[0] = 0xc0, + _ => { + if !cfg!(fuzzing) { + debug_assert!(false, "Invalid varint len {len}"); + } + return Err(Error::InvalidVectorLength); + } + } + let mut value = self.0; + for b in bytes.iter_mut().rev() { + *b |= (value & 0xFF) as u8; + value >>= 8; + } + + Ok(len) + } +} + +impl TryFrom<u64> for TlsVarInt { + type Error = Error; + + #[inline] + fn try_from(value: u64) -> Result<Self, Self::Error> { + Self::try_new(value) + } +} + +impl From<TlsVarInt> for u64 { + #[inline] + fn from(value: TlsVarInt) -> Self { + value.0 + } +} + +#[inline(always)] +fn check_min_len(value: u64, len: usize) -> Result<(), Error> { + if cfg!(feature = "mls") { + // ensure that `len` is minimal for the given `value` + let min_len = TlsVarInt::try_new(value)?.bytes_len(); + if min_len != len { + return Err(Error::InvalidVectorLength); + } + }; + Ok(()) +} + +impl Deserialize for TlsVarInt { + #[cfg(feature = "std")] + #[inline] + fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, Error> { + let mut len_byte = [0u8; 1]; + if bytes.read(&mut len_byte)? 
== 0 { + return Err(Error::EndOfStream); + }; + let len_byte = len_byte[0]; + + let (value, len) = calculate_value(len_byte)?; + let mut value: u64 = value.try_into().map_err(|_| Error::InvalidInput)?; + + for _ in 1..len { + let mut next = [0u8; 1]; + bytes.read_exact(&mut next)?; + value = (value << 8) + u64::from(next[0]); + } + + check_min_len(value, len)?; + + Ok(TlsVarInt(value)) + } +} + +impl DeserializeBytes for TlsVarInt { + #[inline] + fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> + where + Self: Sized, + { + let (len_byte, mut remainder) = u8::tls_deserialize_bytes(bytes)?; + + let (value, len) = calculate_value(len_byte)?; + let mut value: u64 = value.try_into().map_err(|_| Error::InvalidInput)?; + + for _ in 1..len { + let (next, next_remainder) = u8::tls_deserialize_bytes(remainder)?; + remainder = next_remainder; + value = (value << 8) + u64::from(next); + } + + check_min_len(value, len)?; + + Ok((TlsVarInt(value), remainder)) + } +} + +impl Serialize for TlsVarInt { + #[cfg(feature = "std")] + #[inline] + fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> { + let mut bytes = [0u8; 8]; + let len = self.write_bytes(&mut bytes)?; + writer.write_all(&bytes[..len])?; + Ok(len) + } +} + +impl Size for TlsVarInt { + #[inline] + fn tls_serialized_len(&self) -> usize { + self.bytes_len() + } +} + +/// Calculates the value and the length from the first byte. 
+#[inline(always)] +pub(crate) fn calculate_value(byte: u8) -> Result<(usize, usize), Error> { + let value: usize = (byte & 0x3F).into(); + let len_log = (byte >> 6).into(); + if !cfg!(fuzzing) { + debug_assert!(len_log <= TlsVarInt::MAX_LOG); + } + if len_log > TlsVarInt::MAX_LOG { + return Err(Error::InvalidVectorLength); + } + let len = match len_log { + 0 => 1, + 1 => 2, + 2 => 4, + 3 => 8, + _ => unreachable!(), + }; + Ok((value, len)) +} + +#[cfg(test)] +mod tests { + + use super::*; + + // (value, var length, encoded bytes) + const TESTS: [(u64, usize, &[u8]); 5] = [ + (37, 1, &[0x25]), + (15_293, 2, &[0x7b, 0xbd]), + (494_878_333, 4, &[0x9d, 0x7f, 0x3e, 0x7d]), + ( + 151_288_809_941_952_652, + 8, + &[0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c], + ), + ( + TlsVarInt::MAX, + 8, + &[0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff], + ), + ]; + + #[test] + fn tls_serialized_len() { + for (value, len, _) in TESTS { + assert_eq!( + TlsVarInt::try_from(value) + .expect("value too large") + .tls_serialized_len(), + len + ); + } + } + + #[cfg(feature = "std")] + #[test] + fn tls_serialize() { + use crate::alloc::vec::Vec; + + for (value, len, bytes) in TESTS { + let mut buf = Vec::new(); + let written = TlsVarInt::try_from(value) + .expect("value too large") + .tls_serialize(&mut buf) + .expect("tls serialize failed"); + assert_eq!(written, len, "{value}"); + assert_eq!(buf.len(), len, "{value}"); + assert_eq!(&buf[..], bytes, "{value}"); + } + } + + #[test] + fn tls_deserialize_bytes() { + for (value, len, bytes) in TESTS { + assert_eq!(len, bytes.len()); + let (out, remainder) = + TlsVarInt::tls_deserialize_bytes(bytes).expect("tls deserialize bytes failed"); + assert_eq!(remainder.len(), 0); + assert_eq!(out, TlsVarInt::try_from(value).expect("value too large")); + } + } + + #[cfg(feature = "std")] + #[test] + fn tls_deserialize() { + use std::io::Cursor; + + for (value, len, bytes) in TESTS { + assert_eq!(len, bytes.len()); + let out = 
TlsVarInt::tls_deserialize(&mut Cursor::new(bytes)) + .expect("tls deserialize failed"); + assert_eq!(out, TlsVarInt::try_from(value).expect("value too large")); + } + } + + #[test] + // Note: MLS requires minimum-size encoding + // <https://www.rfc-editor.org/rfc/rfc9420.html#name-variable-size-vector-length> + #[cfg_attr(feature = "mls", should_panic)] + fn non_minimum_size_deserialize_bytes() { + let (out, remaining) = + TlsVarInt::tls_deserialize_bytes(&[0x40, 0x25]).expect("tls deserialize bytes failed"); + assert_eq!(remaining.len(), 0); + assert_eq!(out, TlsVarInt(37)); + } + + #[cfg(feature = "std")] + #[test] + // Note: MLS requires minimum-size encoding + // <https://www.rfc-editor.org/rfc/rfc9420.html#name-variable-size-vector-length> + #[cfg_attr(feature = "mls", should_panic)] + fn non_minimum_size_tls_deserialize() { + use std::io::Cursor; + + let out = TlsVarInt::tls_deserialize(&mut Cursor::new(&[0x40, 0x25])) + .expect("tls deserialize failed"); + assert_eq!(out, TlsVarInt(37)); + } +}