diff --git a/elliptic-curve/src/hash2curve/group_digest.rs b/elliptic-curve/src/hash2curve/group_digest.rs index 2a663a11..8fa27c46 100644 --- a/elliptic-curve/src/hash2curve/group_digest.rs +++ b/elliptic-curve/src/hash2curve/group_digest.rs @@ -3,6 +3,7 @@ use super::{ExpandMsg, FromOkm, MapToCurve, hash_to_field}; use crate::{CurveArithmetic, ProjectivePoint, Result}; use group::cofactor::CofactorGroup; +use hybrid_array::typenum::Unsigned; /// Adds hashing arbitrary byte sequences to a valid group element pub trait GroupDigest: CurveArithmetic @@ -12,6 +13,11 @@ where /// The field element representation for a group value with multiple elements type FieldElement: FromOkm + MapToCurve> + Default + Copy; + /// The target security level in bytes: + /// + /// + type K: Unsigned; + /// Computes the hash to curve routine. /// /// From : diff --git a/elliptic-curve/src/hash2curve/hash2field.rs b/elliptic-curve/src/hash2curve/hash2field.rs index 946c8a39..4ffcf806 100644 --- a/elliptic-curve/src/hash2curve/hash2field.rs +++ b/elliptic-curve/src/hash2curve/hash2field.rs @@ -4,15 +4,20 @@ mod expand_msg; +use core::num::NonZeroUsize; + pub use expand_msg::{xmd::*, xof::*, *}; use crate::{Error, Result}; -use hybrid_array::{Array, ArraySize, typenum::Unsigned}; +use hybrid_array::{ + Array, ArraySize, + typenum::{NonZero, Unsigned}, +}; /// The trait for helping to convert to a field element. pub trait FromOkm { /// The number of bytes needed to convert to a field element. - type Length: ArraySize; + type Length: ArraySize + NonZero; /// Convert a byte sequence into a field element. fn from_okm(data: &Array) -> Self; @@ -37,7 +42,10 @@ where E: ExpandMsg<'a>, T: FromOkm + Default, { - let len_in_bytes = T::Length::to_usize().checked_mul(out.len()).ok_or(Error)?; + let len_in_bytes = T::Length::to_usize() + .checked_mul(out.len()) + .and_then(NonZeroUsize::new) + .ok_or(Error)?; let mut tmp = Array::::Length>::default(); let mut expander = E::expand_message(data, domain, len_in_bytes)?; for o in out.iter_mut() { diff --git a/elliptic-curve/src/hash2curve/hash2field/expand_msg.rs b/elliptic-curve/src/hash2curve/hash2field/expand_msg.rs index 510ce5b2..444dbf72 100644 --- a/elliptic-curve/src/hash2curve/hash2field/expand_msg.rs +++ b/elliptic-curve/src/hash2curve/hash2field/expand_msg.rs @@ -3,6 +3,8 @@ pub(super) mod xmd; pub(super) mod xof; +use core::num::NonZero; + use crate::{Error, Result}; use digest::{Digest, ExtendableOutput, Update, XofReader}; use hybrid_array::typenum::{IsLess, U256}; @@ -28,7 +30,7 @@ pub trait ExpandMsg<'a> { fn expand_message( msgs: &[&[u8]], dsts: &'a [&'a [u8]], - len_in_bytes: usize, + len_in_bytes: NonZero, ) -> Result; } diff --git a/elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs b/elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs index c1cb250b..e98843e3 100644 --- a/elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs +++ b/elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs @@ -1,6 +1,6 @@ //! `expand_message_xmd` based on a hash function. -use core::marker::PhantomData; +use core::{marker::PhantomData, num::NonZero, ops::Mul}; use super::{Domain, ExpandMsg, Expander}; use crate::{Error, Result}; @@ -8,52 +8,62 @@ use digest::{ FixedOutput, HashMarker, array::{ Array, - typenum::{IsLess, IsLessOrEqual, U256, Unsigned}, + typenum::{IsGreaterOrEqual, IsLess, IsLessOrEqual, U2, U256, Unsigned}, }, core_api::BlockSizeUser, }; -/// Placeholder type for implementing `expand_message_xmd` based on a hash function +/// Implements `expand_message_xof` via the [`ExpandMsg`] trait: +/// +/// +/// `K` is the target security level in bytes: +/// +/// /// /// # Errors /// - `dst.is_empty()` -/// - `len_in_bytes == 0` /// - `len_in_bytes > u16::MAX` /// - `len_in_bytes > 255 * HashT::OutputSize` #[derive(Debug)] -pub struct ExpandMsgXmd(PhantomData) +pub struct ExpandMsgXmd(PhantomData<(HashT, K)>) where HashT: BlockSizeUser + Default + FixedOutput + HashMarker, HashT::OutputSize: IsLess, - HashT::OutputSize: IsLessOrEqual; + HashT::OutputSize: IsLessOrEqual, + K: Mul, + HashT::OutputSize: IsGreaterOrEqual<>::Output>; -/// ExpandMsgXmd implements expand_message_xmd for the ExpandMsg trait -impl<'a, HashT> ExpandMsg<'a> for ExpandMsgXmd +impl<'a, HashT, K> ExpandMsg<'a> for ExpandMsgXmd where HashT: BlockSizeUser + Default + FixedOutput + HashMarker, - // If `len_in_bytes` is bigger then 256, length of the `DST` will depend on - // the output size of the hash, which is still not allowed to be bigger then 256: + // If DST is larger than 255 bytes, the length of the computed DST will depend on the output + // size of the hash, which is still not allowed to be larger than 256: // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5.4.1-6 HashT::OutputSize: IsLess, // Constraint set by `expand_message_xmd`: // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5.4.1-4 HashT::OutputSize: IsLessOrEqual, + // The number of bits output by `HashT` MUST be larger or equal to `K * 2`: + // https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.1-2.1 + K: Mul, + HashT::OutputSize: IsGreaterOrEqual<>::Output>, { type Expander = ExpanderXmd<'a, HashT>; fn expand_message( msgs: &[&[u8]], dsts: &'a [&'a [u8]], - len_in_bytes: usize, + len_in_bytes: NonZero, ) -> Result { - if len_in_bytes == 0 { + let len_in_bytes_u16 = u16::try_from(len_in_bytes.get()).map_err(|_| Error)?; + + // `255 * ` can not exceed `u16::MAX` + if len_in_bytes_u16 > 255 * HashT::OutputSize::to_u16() { return Err(Error); } - let len_in_bytes_u16 = u16::try_from(len_in_bytes).map_err(|_| Error)?; - let b_in_bytes = HashT::OutputSize::to_usize(); - let ell = u8::try_from(len_in_bytes.div_ceil(b_in_bytes)).map_err(|_| Error)?; + let ell = u8::try_from(len_in_bytes.get().div_ceil(b_in_bytes)).map_err(|_| Error)?; let domain = Domain::xmd::(dsts)?; let mut b_0 = HashT::default(); @@ -157,7 +167,7 @@ mod test { use hex_literal::hex; use hybrid_array::{ ArraySize, - typenum::{U32, U128}, + typenum::{U4, U8, U32, U128}, }; use sha2::Sha256; @@ -209,13 +219,17 @@ mod test { ) -> Result<()> where HashT: BlockSizeUser + Default + FixedOutput + HashMarker, - HashT::OutputSize: IsLess + IsLessOrEqual, + HashT::OutputSize: IsLess + IsLessOrEqual + Mul, + HashT::OutputSize: IsGreaterOrEqual<>::Output>, { assert_message::(self.msg, domain, L::to_u16(), self.msg_prime); let dst = [dst]; - let mut expander = - ExpandMsgXmd::::expand_message(&[self.msg], &dst, L::to_usize())?; + let mut expander = ExpandMsgXmd::::expand_message( + &[self.msg], + &dst, + NonZero::new(L::to_usize()).ok_or(Error)?, + )?; let mut uniform_bytes = Array::::default(); expander.fill_bytes(&mut uniform_bytes); diff --git a/elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs b/elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs index 6a5c1462..9d40ed2c 100644 --- a/elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs +++ b/elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs @@ -2,26 +2,38 @@ use super::{Domain, ExpandMsg, Expander}; use crate::{Error, Result}; -use core::fmt; -use digest::{ExtendableOutput, Update, XofReader}; -use hybrid_array::typenum::U32; - -/// Placeholder type for implementing `expand_message_xof` based on an extendable output function +use core::{fmt, marker::PhantomData, num::NonZero, ops::Mul}; +use digest::{ExtendableOutput, HashMarker, Update, XofReader}; +use hybrid_array::{ + ArraySize, + typenum::{IsLess, U2, U256}, +}; + +/// Implements `expand_message_xof` via the [`ExpandMsg`] trait: +/// +/// +/// `K` is the target security level in bytes: +/// +/// /// /// # Errors /// - `dst.is_empty()` -/// - `len_in_bytes == 0` /// - `len_in_bytes > u16::MAX` -pub struct ExpandMsgXof +pub struct ExpandMsgXof where - HashT: Default + ExtendableOutput + Update, + HashT: Default + ExtendableOutput + Update + HashMarker, + K: Mul, + >::Output: ArraySize + IsLess, { reader: ::Reader, + _k: PhantomData, } -impl fmt::Debug for ExpandMsgXof +impl fmt::Debug for ExpandMsgXof where - HashT: Default + ExtendableOutput + Update, + HashT: Default + ExtendableOutput + Update + HashMarker, + K: Mul, + >::Output: ArraySize + IsLess, ::Reader: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -31,25 +43,24 @@ where } } -/// ExpandMsgXof implements `expand_message_xof` for the [`ExpandMsg`] trait -impl<'a, HashT> ExpandMsg<'a> for ExpandMsgXof +impl<'a, HashT, K> ExpandMsg<'a> for ExpandMsgXof where - HashT: Default + ExtendableOutput + Update, + HashT: Default + ExtendableOutput + Update + HashMarker, + // If DST is larger than 255 bytes, the length of the computed DST is calculated by `K * 2`. + // https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.1-2.1 + K: Mul, + >::Output: ArraySize + IsLess, { type Expander = Self; fn expand_message( msgs: &[&[u8]], dsts: &'a [&'a [u8]], - len_in_bytes: usize, + len_in_bytes: NonZero, ) -> Result { - if len_in_bytes == 0 { - return Err(Error); - } - - let len_in_bytes = u16::try_from(len_in_bytes).map_err(|_| Error)?; + let len_in_bytes = u16::try_from(len_in_bytes.get()).map_err(|_| Error)?; - let domain = Domain::::xof::(dsts)?; + let domain = Domain::<>::Output>::xof::(dsts)?; let mut reader = HashT::default(); for msg in msgs { @@ -60,13 +71,18 @@ where domain.update_hash(&mut reader); reader.update(&[domain.len()]); let reader = reader.finalize_xof(); - Ok(Self { reader }) + Ok(Self { + reader, + _k: PhantomData, + }) } } -impl Expander for ExpandMsgXof +impl Expander for ExpandMsgXof where - HashT: Default + ExtendableOutput + Update, + HashT: Default + ExtendableOutput + Update + HashMarker, + K: Mul, + >::Output: ArraySize + IsLess, { fn fill_bytes(&mut self, okm: &mut [u8]) { self.reader.read(okm); @@ -78,7 +94,10 @@ mod test { use super::*; use core::mem::size_of; use hex_literal::hex; - use hybrid_array::{Array, ArraySize, typenum::U128}; + use hybrid_array::{ + Array, ArraySize, + typenum::{U16, U32, U128}, + }; use sha3::Shake128; fn assert_message(msg: &[u8], domain: &Domain<'_, U32>, len_in_bytes: u16, bytes: &[u8]) { @@ -110,13 +129,16 @@ mod test { #[allow(clippy::panic_in_result_fn)] fn assert(&self, dst: &'static [u8], domain: &Domain<'_, U32>) -> Result<()> where - HashT: Default + ExtendableOutput + Update, + HashT: Default + ExtendableOutput + Update + HashMarker, L: ArraySize, { assert_message(self.msg, domain, L::to_u16(), self.msg_prime); - let mut expander = - ExpandMsgXof::::expand_message(&[self.msg], &[dst], L::to_usize())?; + let mut expander = ExpandMsgXof::::expand_message( + &[self.msg], + &[dst], + NonZero::new(L::to_usize()).ok_or(Error)?, + )?; let mut uniform_bytes = Array::::default(); expander.fill_bytes(&mut uniform_bytes);