diff --git a/Cargo.lock b/Cargo.lock
index 11574624..aa3317d4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -436,6 +436,7 @@ dependencies = [
 name = "scrypt"
 version = "0.12.0-rc.1"
 dependencies = [
+ "cfg-if",
  "password-hash",
  "pbkdf2",
  "salsa20",
diff --git a/scrypt/Cargo.toml b/scrypt/Cargo.toml
index 5d05607f..ed86407c 100644
--- a/scrypt/Cargo.toml
+++ b/scrypt/Cargo.toml
@@ -14,6 +14,7 @@ edition = "2024"
 rust-version = "1.85"
 
 [dependencies]
+cfg-if = "1.0"
 pbkdf2 = { version = "0.13.0-rc.0", path = "../pbkdf2" }
 salsa20 = { version = "0.11.0-rc.0", default-features = false }
 sha2 = { version = "0.11.0-rc.0", default-features = false }
diff --git a/scrypt/src/block_mix.rs b/scrypt/src/block_mix.rs
new file mode 100644
index 00000000..66082888
--- /dev/null
+++ b/scrypt/src/block_mix.rs
@@ -0,0 +1,56 @@
+cfg_if::cfg_if! {
+    if #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] {
+        mod pivot;
+        mod simd128;
+        pub(crate) use simd128::{scrypt_block_mix, shuffle_in, shuffle_out};
+    } else if #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "sse2"))] {
+        mod pivot;
+        mod sse2;
+        pub(crate) use sse2::{scrypt_block_mix, shuffle_in, shuffle_out};
+    } else {
+        mod soft;
+        pub(crate) use soft::scrypt_block_mix;
+
+        pub(crate) fn shuffle_in(_input: &mut [u8]) {}
+        pub(crate) fn shuffle_out(_input: &mut [u8]) {}
+    }
+}
+
+#[cfg(test)]
+#[path = "block_mix/soft.rs"]
+mod soft_test;
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_scrypt_block_mix_abcd_against_soft() {
+        let mut input: [u8; 128] = core::array::from_fn(|i| i as u8);
+        for _round in 0..10 {
+            let mut output = [0u8; 128];
+
+            let mut expected0 = [0u8; 128];
+            let mut expected1 = [0u8; 128]; // check shuffle_out is a correct inverse of shuffle_in
+            soft_test::scrypt_block_mix(&input, &mut expected0);
+            shuffle_in(&mut input);
+            scrypt_block_mix(&input, &mut output);
+            shuffle_out(&mut input);
+            soft_test::scrypt_block_mix(&input, &mut expected1);
+            shuffle_out(&mut output);
+            assert_eq!(
+                expected0, expected1,
+                "expected0 != expected1, shuffle_out is not a correct inverse of shuffle_in?"
+            );
+            assert_eq!(
+                output, expected0,
+                "output != expected0, scrypt_block_mix is not correct?"
+            );
+
+            input
+                .iter_mut()
+                .zip(output.iter())
+                .for_each(|(a, b)| *a = a.wrapping_add(*b));
+        }
+    }
+}
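Note on the test above: it runs whichever backend the target selects and checks it against the scalar soft implementation, chaining each round's output back into the input. A smaller, self-contained property check (a sketch, not part of this diff) would pin down the shuffle pair in isolation:

    #[test]
    fn shuffle_round_trip() {
        // shuffle_out must restore exactly the bytes that shuffle_in
        // permuted, for every 64-byte Salsa20 block in the buffer.
        let original: [u8; 128] = core::array::from_fn(|i| (i * 31) as u8);
        let mut buf = original;
        shuffle_in(&mut buf);
        shuffle_out(&mut buf);
        assert_eq!(buf, original);
    }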
diff --git a/scrypt/src/block_mix/pivot.rs b/scrypt/src/block_mix/pivot.rs
new file mode 100644
index 00000000..3839ad70
--- /dev/null
+++ b/scrypt/src/block_mix/pivot.rs
@@ -0,0 +1,20 @@
+/// Permute Salsa20 block to column major order
+pub(crate) const PIVOT_ABCD: [usize; 16] = [0, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12, 1, 6, 11];
+
+/// Inverse of PIVOT_ABCD
+pub(crate) const INVERSE_PIVOT_ABCD: [usize; 16] = const {
+    let mut index = [0; 16];
+    let mut i = 0;
+    while i < 16 {
+        let mut inverse = 0;
+        while inverse < 16 {
+            if PIVOT_ABCD[inverse] == i {
+                index[i] = inverse;
+                break;
+            }
+            inverse += 1;
+        }
+        i += 1;
+    }
+    index
+};
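PIVOT_ABCD moves each 64-byte Salsa20 block into the column-major word order the SIMD kernels expect, and INVERSE_PIVOT_ABCD undoes it; the const block computes the inverse by brute-force search at compile time. If a compile-time sanity check were wanted in pivot.rs (a sketch, not in this diff), the defining property is:

    // Evaluated at compile time: composing the two tables in either order
    // must give the identity permutation on all 16 word positions.
    const _: () = {
        let mut i = 0;
        while i < 16 {
            assert!(INVERSE_PIVOT_ABCD[PIVOT_ABCD[i]] == i);
            assert!(PIVOT_ABCD[INVERSE_PIVOT_ABCD[i]] == i);
            i += 1;
        }
    };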
u32x4_rol { + ($w:expr, $amt:literal) => {{ + let w = $w; + v128_or(u32x4_shl(w, $amt), u32x4_shr(w, 32 - $amt)) + }}; + } + + let last_block = &input[input.len() - 64..]; + + let mut a = unsafe { v128_load(last_block.as_ptr().cast()) }; + let mut b = unsafe { v128_load(last_block.as_ptr().add(16).cast()) }; + let mut c = unsafe { v128_load(last_block.as_ptr().add(32).cast()) }; + let mut d = unsafe { v128_load(last_block.as_ptr().add(48).cast()) }; + + for (i, chunk) in input.chunks(64).enumerate() { + let pos = if i % 2 == 0 { + (i / 2) * 64 + } else { + (i / 2) * 64 + input.len() / 2 + }; + + unsafe { + let chunk_a = v128_load(chunk.as_ptr().cast()); + let chunk_b = v128_load(chunk.as_ptr().add(16).cast()); + let chunk_c = v128_load(chunk.as_ptr().add(32).cast()); + let chunk_d = v128_load(chunk.as_ptr().add(48).cast()); + + a = v128_xor(a, chunk_a); + b = v128_xor(b, chunk_b); + c = v128_xor(c, chunk_c); + d = v128_xor(d, chunk_d); + + let saves = [a, b, c, d]; + + for _ in 0..8 { + b = v128_xor(b, u32x4_rol!(u32x4_add(a, d), 7)); + c = v128_xor(c, u32x4_rol!(u32x4_add(b, a), 9)); + d = v128_xor(d, u32x4_rol!(u32x4_add(c, b), 13)); + a = v128_xor(a, u32x4_rol!(u32x4_add(d, c), 18)); + + d = i32x4_shuffle::<1, 2, 3, 0>(d, d); + c = i32x4_shuffle::<2, 3, 0, 1>(c, c); + b = i32x4_shuffle::<3, 0, 1, 2>(b, b); + + (b, d) = (d, b); + } + + a = u32x4_add(a, saves[0]); + b = u32x4_add(b, saves[1]); + c = u32x4_add(c, saves[2]); + d = u32x4_add(d, saves[3]); + + v128_store(output.as_mut_ptr().add(pos).cast(), a); + v128_store(output.as_mut_ptr().add(pos + 16).cast(), b); + v128_store(output.as_mut_ptr().add(pos + 32).cast(), c); + v128_store(output.as_mut_ptr().add(pos + 48).cast(), d); + } + } +} diff --git a/scrypt/src/block_mix/soft.rs b/scrypt/src/block_mix/soft.rs new file mode 100644 index 00000000..9dbdd984 --- /dev/null +++ b/scrypt/src/block_mix/soft.rs @@ -0,0 +1,42 @@ +/// Execute the BlockMix operation +/// input - the input vector. The length must be a multiple of 128. +/// output - the output vector. Must be the same length as input. 
diff --git a/scrypt/src/block_mix/soft.rs b/scrypt/src/block_mix/soft.rs
new file mode 100644
index 00000000..9dbdd984
--- /dev/null
+++ b/scrypt/src/block_mix/soft.rs
@@ -0,0 +1,42 @@
+/// Execute the BlockMix operation
+/// input - the input vector. The length must be a multiple of 128.
+/// output - the output vector. Must be the same length as input.
+pub(crate) fn scrypt_block_mix(input: &[u8], output: &mut [u8]) {
+    use salsa20::{
+        SalsaCore,
+        cipher::{StreamCipherCore, typenum::U4},
+    };
+
+    type Salsa20_8 = SalsaCore<U4>;
+
+    let mut x = [0u8; 64];
+    x.copy_from_slice(&input[input.len() - 64..]);
+
+    let mut t = [0u8; 64];
+
+    for (i, chunk) in input.chunks(64).enumerate() {
+        xor(&x, chunk, &mut t);
+
+        let mut t2 = [0u32; 16];
+
+        for (c, b) in t.chunks_exact(4).zip(t2.iter_mut()) {
+            *b = u32::from_le_bytes(c.try_into().unwrap());
+        }
+
+        Salsa20_8::from_raw_state(t2).write_keystream_block((&mut x).into());
+
+        let pos = if i % 2 == 0 {
+            (i / 2) * 64
+        } else {
+            (i / 2) * 64 + input.len() / 2
+        };
+
+        output[pos..pos + 64].copy_from_slice(&x);
+    }
+}
+
+fn xor(x: &[u8], y: &[u8], output: &mut [u8]) {
+    for ((out, &x_i), &y_i) in output.iter_mut().zip(x.iter()).zip(y.iter()) {
+        *out = x_i ^ y_i;
+    }
+}
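The scalar path and the vectorized paths place their output the same way: BlockMix writes even-indexed 64-byte blocks into the first half of the output and odd-indexed blocks into the second half (the interleaving step of RFC 7914). The shared rule, factored into a hypothetical helper for illustration:

    // Byte offset in output of the i-th 64-byte block produced by
    // BlockMix over an input of input_len bytes (a multiple of 128).
    fn block_mix_output_pos(i: usize, input_len: usize) -> usize {
        if i % 2 == 0 {
            (i / 2) * 64
        } else {
            (i / 2) * 64 + input_len / 2
        }
    }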
mm_rol_epi32x { + ($w:expr, $amt:literal) => {{ + let w = $w; + _mm_or_si128(_mm_slli_epi32(w, $amt), _mm_srli_epi32(w, 32 - $amt)) + }}; + } + + let last_block = &input[input.len() - 64..]; + + let mut a = unsafe { _mm_loadu_si128(last_block.as_ptr().cast()) }; + let mut b = unsafe { _mm_loadu_si128(last_block.as_ptr().add(16).cast()) }; + let mut c = unsafe { _mm_loadu_si128(last_block.as_ptr().add(32).cast()) }; + let mut d = unsafe { _mm_loadu_si128(last_block.as_ptr().add(48).cast()) }; + + for (i, chunk) in input.chunks(64).enumerate() { + let pos = if i % 2 == 0 { + (i / 2) * 64 + } else { + (i / 2) * 64 + input.len() / 2 + }; + + unsafe { + a = _mm_xor_si128(a, _mm_loadu_si128(chunk.as_ptr().cast())); + b = _mm_xor_si128(b, _mm_loadu_si128(chunk.as_ptr().add(16).cast())); + c = _mm_xor_si128(c, _mm_loadu_si128(chunk.as_ptr().add(32).cast())); + d = _mm_xor_si128(d, _mm_loadu_si128(chunk.as_ptr().add(48).cast())); + + let saves = [a, b, c, d]; + + for _ in 0..8 { + b = _mm_xor_si128(b, mm_rol_epi32x!(_mm_add_epi32(a, d), 7)); + c = _mm_xor_si128(c, mm_rol_epi32x!(_mm_add_epi32(b, a), 9)); + d = _mm_xor_si128(d, mm_rol_epi32x!(_mm_add_epi32(c, b), 13)); + a = _mm_xor_si128(a, mm_rol_epi32x!(_mm_add_epi32(d, c), 18)); + + // a stays in place + // b = left shuffle d by 1 element + d = _mm_shuffle_epi32(d, 0b00_11_10_01); + // c = left shuffle c by 2 elements + c = _mm_shuffle_epi32(c, 0b01_00_11_10); + // d = left shuffle b by 3 elements + b = _mm_shuffle_epi32(b, 0b10_01_00_11); + (b, d) = (d, b); + } + + a = _mm_add_epi32(a, saves[0]); + b = _mm_add_epi32(b, saves[1]); + c = _mm_add_epi32(c, saves[2]); + d = _mm_add_epi32(d, saves[3]); + + _mm_storeu_si128(output.as_mut_ptr().add(pos).cast(), a); + _mm_storeu_si128(output.as_mut_ptr().add(pos + 16).cast(), b); + _mm_storeu_si128(output.as_mut_ptr().add(pos + 32).cast(), c); + _mm_storeu_si128(output.as_mut_ptr().add(pos + 48).cast(), d); + } + } +} diff --git a/scrypt/src/lib.rs b/scrypt/src/lib.rs index 7c5fff80..dd426499 100644 --- a/scrypt/src/lib.rs +++ b/scrypt/src/lib.rs @@ -55,6 +55,7 @@ extern crate alloc; use pbkdf2::pbkdf2_hmac; use sha2::Sha256; +mod block_mix; /// Errors for `scrypt` operations. pub mod errors; mod params; diff --git a/scrypt/src/romix.rs b/scrypt/src/romix.rs index bd3c56dd..1be655dc 100644 --- a/scrypt/src/romix.rs +++ b/scrypt/src/romix.rs @@ -18,53 +18,21 @@ pub(crate) fn scrypt_ro_mix(b: &mut [u8], v: &mut [u8], t: &mut [u8], n: usize) let len = b.len(); + crate::block_mix::shuffle_in(b); + for chunk in v.chunks_mut(len) { chunk.copy_from_slice(b); - scrypt_block_mix(chunk, b); + crate::block_mix::scrypt_block_mix(chunk, b); } for _ in 0..n { let j = integerify(b, n); xor(b, &v[j * len..(j + 1) * len], t); - scrypt_block_mix(t, b); - } -} - -/// Execute the BlockMix operation -/// input - the input vector. The length must be a multiple of 128. -/// output - the output vector. Must be the same length as input. 
diff --git a/scrypt/src/romix.rs b/scrypt/src/romix.rs
index bd3c56dd..1be655dc 100644
--- a/scrypt/src/romix.rs
+++ b/scrypt/src/romix.rs
@@ -18,53 +18,21 @@ pub(crate) fn scrypt_ro_mix(b: &mut [u8], v: &mut [u8], t: &mut [u8], n: usize)
     }
 
     let len = b.len();
 
+    crate::block_mix::shuffle_in(b);
+
     for chunk in v.chunks_mut(len) {
         chunk.copy_from_slice(b);
-        scrypt_block_mix(chunk, b);
+        crate::block_mix::scrypt_block_mix(chunk, b);
     }
 
     for _ in 0..n {
         let j = integerify(b, n);
         xor(b, &v[j * len..(j + 1) * len], t);
-        scrypt_block_mix(t, b);
-    }
-}
-
-/// Execute the BlockMix operation
-/// input - the input vector. The length must be a multiple of 128.
-/// output - the output vector. Must be the same length as input.
-fn scrypt_block_mix(input: &[u8], output: &mut [u8]) {
-    use salsa20::{
-        SalsaCore,
-        cipher::{StreamCipherCore, typenum::U4},
-    };
-
-    type Salsa20_8 = SalsaCore<U4>;
-
-    let mut x = [0u8; 64];
-    x.copy_from_slice(&input[input.len() - 64..]);
-
-    let mut t = [0u8; 64];
-    for (i, chunk) in input.chunks(64).enumerate() {
-        xor(&x, chunk, &mut t);
-
-        let mut t2 = [0u32; 16];
-
-        for (c, b) in t.chunks_exact(4).zip(t2.iter_mut()) {
-            *b = u32::from_le_bytes(c.try_into().unwrap());
-        }
-
-        Salsa20_8::from_raw_state(t2).write_keystream_block((&mut x).into());
-
-        let pos = if i % 2 == 0 {
-            (i / 2) * 64
-        } else {
-            (i / 2) * 64 + input.len() / 2
-        };
-
-        output[pos..pos + 64].copy_from_slice(&x);
+        crate::block_mix::scrypt_block_mix(t, b);
     }
+
+    crate::block_mix::shuffle_out(b);
 }
 
 fn xor(x: &[u8], y: &[u8], output: &mut [u8]) {
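The shuffles bracket the whole RoMix rather than each BlockMix call, so the layout conversion is paid twice per RoMix instead of on every one of the 2n BlockMix invocations. This is only sound if integerify computes the same index j on pivoted data; it does, because PIVOT_ABCD[0] == 0 keeps word 0 of every 64-byte block in place, and integerify in this crate derives j from exactly that word of the last block. A scalar model of that invariant (a sketch, assuming integerify's usual definition here):

    // integerify reads word 0 of b's last 64-byte block; shuffle_in leaves
    // word 0 of each block untouched (PIVOT_ABCD[0] == 0), so j is the same
    // before and after pivoting. n must be a power of two.
    fn integerify_model(x: &[u8], n: usize) -> usize {
        let word0 = u32::from_le_bytes(x[x.len() - 64..x.len() - 60].try_into().unwrap());
        (word0 as usize) & (n - 1)
    }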