From 4565cb2cff0f2a474d1379e2cccddf7e5e65f3f1 Mon Sep 17 00:00:00 2001 From: sayantn Date: Fri, 8 Aug 2025 02:43:48 +0530 Subject: [PATCH] Implement arithmetic operation traits for x86 SIMD types --- crates/core_arch/src/macros.rs | 58 +++++++++++++++++++++++++++++ crates/core_arch/src/x86/avx2.rs | 14 ++----- crates/core_arch/src/x86/avx512f.rs | 40 ++++++++++---------- crates/core_arch/src/x86/mod.rs | 23 ++++++++++++ crates/core_arch/src/x86/sse2.rs | 8 ++-- 5 files changed, 109 insertions(+), 34 deletions(-) diff --git a/crates/core_arch/src/macros.rs b/crates/core_arch/src/macros.rs index e00b433536..0a43fd7088 100644 --- a/crates/core_arch/src/macros.rs +++ b/crates/core_arch/src/macros.rs @@ -163,3 +163,61 @@ macro_rules! simd_extract { ($x:expr, $idx:expr $(,)?) => {{ $crate::intrinsics::simd::simd_extract($x, const { $idx }) }}; ($x:expr, $idx:expr, $ty:ty $(,)?) => {{ $crate::intrinsics::simd::simd_extract::<_, $ty>($x, const { $idx }) }}; } + +#[allow(unused)] +macro_rules! impl_arith_op { + (__internal $op:ident, $intrinsic:ident $_:ident) => { + #[inline] + fn $op(self, rhs: Self) -> Self { + unsafe { crate::intrinsics::simd::$intrinsic(self, rhs) } + } + }; + (__internal $op:ident, $intrinsic:ident) => { + #[inline] + fn $op(self) -> Self { + unsafe { crate::intrinsics::simd::$intrinsic(self) } + } + }; + (: $($tt:tt)*) => {}; + ( + $type:ty $(, $other_types:ty )* : $( + $Trait:ident, $op:ident $(, $TraitAssign:ident, $op_assign:ident)? = $intrinsic:ident + );* $(;)? + ) => { + $( + #[stable(feature = "stdarch_arith_ops", since = "CURRENT_RUSTC_VERSION")] + impl crate::ops::$Trait for $type { + type Output = Self; + + impl_arith_op!(__internal $op, $intrinsic $( $TraitAssign )?); + } + + $( + #[stable(feature = "stdarch_arith_ops", since = "CURRENT_RUSTC_VERSION")] + impl crate::ops::$TraitAssign for $type { + #[inline] + fn $op_assign(&mut self, rhs: Self) { + *self = crate::ops::$Trait::$op(*self, rhs) + } + } + )? + )* + + impl_arith_op!($($other_types),* : $($Trait, $op $(, $TraitAssign, $op_assign)? = $intrinsic);*); + }; +} + +#[allow(unused)] +macro_rules! impl_not { + ($($type:ty),*) => {$( + #[stable(feature = "stdarch_arith_ops", since = "CURRENT_RUSTC_VERSION")] + impl crate::ops::Not for $type { + type Output = Self; + + #[inline] + fn not(self) -> Self { + unsafe { crate::intrinsics::simd::simd_xor(<$type>::splat(!0), self) } + } + } + )*}; +} diff --git a/crates/core_arch/src/x86/avx2.rs b/crates/core_arch/src/x86/avx2.rs index 739de2b341..2e489bb2e2 100644 --- a/crates/core_arch/src/x86/avx2.rs +++ b/crates/core_arch/src/x86/avx2.rs @@ -248,7 +248,7 @@ pub fn _mm256_alignr_epi8(a: __m256i, b: __m256i) -> __m256i { #[cfg_attr(test, assert_instr(vandps))] #[stable(feature = "simd_x86", since = "1.27.0")] pub fn _mm256_and_si256(a: __m256i, b: __m256i) -> __m256i { - unsafe { transmute(simd_and(a.as_i64x4(), b.as_i64x4())) } + a & b } /// Computes the bitwise NOT of 256 bits (representing integer data) @@ -260,13 +260,7 @@ pub fn _mm256_and_si256(a: __m256i, b: __m256i) -> __m256i { #[cfg_attr(test, assert_instr(vandnps))] #[stable(feature = "simd_x86", since = "1.27.0")] pub fn _mm256_andnot_si256(a: __m256i, b: __m256i) -> __m256i { - unsafe { - let all_ones = _mm256_set1_epi8(-1); - transmute(simd_and( - simd_xor(a.as_i64x4(), all_ones.as_i64x4()), - b.as_i64x4(), - )) - } + !a & b } /// Averages packed unsigned 16-bit integers in `a` and `b`. 
@@ -2184,7 +2178,7 @@ pub fn _mm256_mulhrs_epi16(a: __m256i, b: __m256i) -> __m256i { #[cfg_attr(test, assert_instr(vorps))] #[stable(feature = "simd_x86", since = "1.27.0")] pub fn _mm256_or_si256(a: __m256i, b: __m256i) -> __m256i { - unsafe { transmute(simd_or(a.as_i32x8(), b.as_i32x8())) } + a | b } /// Converts packed 16-bit integers from `a` and `b` to packed 8-bit integers @@ -3557,7 +3551,7 @@ pub fn _mm256_unpacklo_epi64(a: __m256i, b: __m256i) -> __m256i { #[cfg_attr(test, assert_instr(vxorps))] #[stable(feature = "simd_x86", since = "1.27.0")] pub fn _mm256_xor_si256(a: __m256i, b: __m256i) -> __m256i { - unsafe { transmute(simd_xor(a.as_i64x4(), b.as_i64x4())) } + a ^ b } /// Extracts an 8-bit integer from `a`, selected with `INDEX`. Returns a 32-bit diff --git a/crates/core_arch/src/x86/avx512f.rs b/crates/core_arch/src/x86/avx512f.rs index d53f83c0a1..583aa2e849 100644 --- a/crates/core_arch/src/x86/avx512f.rs +++ b/crates/core_arch/src/x86/avx512f.rs @@ -28149,7 +28149,7 @@ pub fn _mm_maskz_alignr_epi64(k: __mmask8, a: __m128i, b: __m12 #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vpandq))] //should be vpandd, but generate vpandq pub fn _mm512_and_epi32(a: __m512i, b: __m512i) -> __m512i { - unsafe { transmute(simd_and(a.as_i32x16(), b.as_i32x16())) } + a & b } /// Performs element-by-element bitwise AND between packed 32-bit integer elements of a and b, storing the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -28244,7 +28244,7 @@ pub fn _mm_maskz_and_epi32(k: __mmask8, a: __m128i, b: __m128i) -> __m128i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vpandq))] pub fn _mm512_and_epi64(a: __m512i, b: __m512i) -> __m512i { - unsafe { transmute(simd_and(a.as_i64x8(), b.as_i64x8())) } + a & b } /// Compute the bitwise AND of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -28339,7 +28339,7 @@ pub fn _mm_maskz_and_epi64(k: __mmask8, a: __m128i, b: __m128i) -> __m128i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vpandq))] pub fn _mm512_and_si512(a: __m512i, b: __m512i) -> __m512i { - unsafe { transmute(simd_and(a.as_i32x16(), b.as_i32x16())) } + a & b } /// Compute the bitwise OR of packed 32-bit integers in a and b, and store the results in dst. @@ -28350,7 +28350,7 @@ pub fn _mm512_and_si512(a: __m512i, b: __m512i) -> __m512i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vporq))] pub fn _mm512_or_epi32(a: __m512i, b: __m512i) -> __m512i { - unsafe { transmute(simd_or(a.as_i32x16(), b.as_i32x16())) } + a | b } /// Compute the bitwise OR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -28389,7 +28389,7 @@ pub fn _mm512_maskz_or_epi32(k: __mmask16, a: __m512i, b: __m512i) -> __m512i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vor))] //should be vpord pub fn _mm256_or_epi32(a: __m256i, b: __m256i) -> __m256i { - unsafe { transmute(simd_or(a.as_i32x8(), b.as_i32x8())) } + a | b } /// Compute the bitwise OR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). 
@@ -28428,7 +28428,7 @@ pub fn _mm256_maskz_or_epi32(k: __mmask8, a: __m256i, b: __m256i) -> __m256i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vor))] //should be vpord pub fn _mm_or_epi32(a: __m128i, b: __m128i) -> __m128i { - unsafe { transmute(simd_or(a.as_i32x4(), b.as_i32x4())) } + a | b } /// Compute the bitwise OR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -28467,7 +28467,7 @@ pub fn _mm_maskz_or_epi32(k: __mmask8, a: __m128i, b: __m128i) -> __m128i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vporq))] pub fn _mm512_or_epi64(a: __m512i, b: __m512i) -> __m512i { - unsafe { transmute(simd_or(a.as_i64x8(), b.as_i64x8())) } + a | b } /// Compute the bitwise OR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -28506,7 +28506,7 @@ pub fn _mm512_maskz_or_epi64(k: __mmask8, a: __m512i, b: __m512i) -> __m512i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vor))] //should be vporq pub fn _mm256_or_epi64(a: __m256i, b: __m256i) -> __m256i { - unsafe { transmute(simd_or(a.as_i64x4(), b.as_i64x4())) } + a | b } /// Compute the bitwise OR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -28545,7 +28545,7 @@ pub fn _mm256_maskz_or_epi64(k: __mmask8, a: __m256i, b: __m256i) -> __m256i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vor))] //should be vporq pub fn _mm_or_epi64(a: __m128i, b: __m128i) -> __m128i { - unsafe { transmute(simd_or(a.as_i64x2(), b.as_i64x2())) } + a | b } /// Compute the bitwise OR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -28584,7 +28584,7 @@ pub fn _mm_maskz_or_epi64(k: __mmask8, a: __m128i, b: __m128i) -> __m128i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vporq))] pub fn _mm512_or_si512(a: __m512i, b: __m512i) -> __m512i { - unsafe { transmute(simd_or(a.as_i32x16(), b.as_i32x16())) } + a | b } /// Compute the bitwise XOR of packed 32-bit integers in a and b, and store the results in dst. @@ -28595,7 +28595,7 @@ pub fn _mm512_or_si512(a: __m512i, b: __m512i) -> __m512i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vpxorq))] //should be vpxord pub fn _mm512_xor_epi32(a: __m512i, b: __m512i) -> __m512i { - unsafe { transmute(simd_xor(a.as_i32x16(), b.as_i32x16())) } + a ^ b } /// Compute the bitwise XOR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). 
@@ -28634,7 +28634,7 @@ pub fn _mm512_maskz_xor_epi32(k: __mmask16, a: __m512i, b: __m512i) -> __m512i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vxor))] //should be vpxord pub fn _mm256_xor_epi32(a: __m256i, b: __m256i) -> __m256i { - unsafe { transmute(simd_xor(a.as_i32x8(), b.as_i32x8())) } + a ^ b } /// Compute the bitwise XOR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -28673,7 +28673,7 @@ pub fn _mm256_maskz_xor_epi32(k: __mmask8, a: __m256i, b: __m256i) -> __m256i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vxor))] //should be vpxord pub fn _mm_xor_epi32(a: __m128i, b: __m128i) -> __m128i { - unsafe { transmute(simd_xor(a.as_i32x4(), b.as_i32x4())) } + a ^ b } /// Compute the bitwise XOR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -28712,7 +28712,7 @@ pub fn _mm_maskz_xor_epi32(k: __mmask8, a: __m128i, b: __m128i) -> __m128i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vpxorq))] pub fn _mm512_xor_epi64(a: __m512i, b: __m512i) -> __m512i { - unsafe { transmute(simd_xor(a.as_i64x8(), b.as_i64x8())) } + a ^ b } /// Compute the bitwise XOR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -28751,7 +28751,7 @@ pub fn _mm512_maskz_xor_epi64(k: __mmask8, a: __m512i, b: __m512i) -> __m512i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vxor))] //should be vpxorq pub fn _mm256_xor_epi64(a: __m256i, b: __m256i) -> __m256i { - unsafe { transmute(simd_xor(a.as_i64x4(), b.as_i64x4())) } + a ^ b } /// Compute the bitwise XOR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -28790,7 +28790,7 @@ pub fn _mm256_maskz_xor_epi64(k: __mmask8, a: __m256i, b: __m256i) -> __m256i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vxor))] //should be vpxorq pub fn _mm_xor_epi64(a: __m128i, b: __m128i) -> __m128i { - unsafe { transmute(simd_xor(a.as_i64x2(), b.as_i64x2())) } + a ^ b } /// Compute the bitwise XOR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -28829,7 +28829,7 @@ pub fn _mm_maskz_xor_epi64(k: __mmask8, a: __m128i, b: __m128i) -> __m128i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vpxorq))] pub fn _mm512_xor_si512(a: __m512i, b: __m512i) -> __m512i { - unsafe { transmute(simd_xor(a.as_i32x16(), b.as_i32x16())) } + a ^ b } /// Compute the bitwise NOT of packed 32-bit integers in a and then AND with b, and store the results in dst. 
@@ -28840,7 +28840,7 @@ pub fn _mm512_xor_si512(a: __m512i, b: __m512i) -> __m512i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vpandnq))] //should be vpandnd pub fn _mm512_andnot_epi32(a: __m512i, b: __m512i) -> __m512i { - _mm512_and_epi32(_mm512_xor_epi32(a, _mm512_set1_epi32(u32::MAX as i32)), b) + !a & b } /// Compute the bitwise NOT of packed 32-bit integers in a and then AND with b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -28939,7 +28939,7 @@ pub fn _mm_maskz_andnot_epi32(k: __mmask8, a: __m128i, b: __m128i) -> __m128i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vpandnq))] //should be vpandnd pub fn _mm512_andnot_epi64(a: __m512i, b: __m512i) -> __m512i { - _mm512_and_epi64(_mm512_xor_epi64(a, _mm512_set1_epi64(u64::MAX as i64)), b) + !a & b } /// Compute the bitwise NOT of packed 64-bit integers in a and then AND with b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -29038,7 +29038,7 @@ pub fn _mm_maskz_andnot_epi64(k: __mmask8, a: __m128i, b: __m128i) -> __m128i { #[stable(feature = "stdarch_x86_avx512", since = "1.89")] #[cfg_attr(test, assert_instr(vpandnq))] pub fn _mm512_andnot_si512(a: __m512i, b: __m512i) -> __m512i { - _mm512_and_epi64(_mm512_xor_epi64(a, _mm512_set1_epi64(u64::MAX as i64)), b) + !a & b } /// Convert 16-bit mask a into an integer value, and store the result in dst. diff --git a/crates/core_arch/src/x86/mod.rs b/crates/core_arch/src/x86/mod.rs index 79a593e647..c479a92711 100644 --- a/crates/core_arch/src/x86/mod.rs +++ b/crates/core_arch/src/x86/mod.rs @@ -471,6 +471,29 @@ impl bf16 { } } +impl_arith_op!( + __m128, __m128d, __m128h, + __m256, __m256d, __m256h, + __m512, __m512d, __m512h: + Add, add, AddAssign, add_assign = simd_add; + Sub, sub, SubAssign, sub_assign = simd_sub; + Mul, mul, MulAssign, mul_assign = simd_mul; + Div, div, DivAssign, div_assign = simd_div; + Rem, rem, RemAssign, rem_assign = simd_rem; + Neg, neg = simd_neg; +); + +impl_arith_op!( + __m128i, __m256i, __m512i: + BitOr, bitor, BitOrAssign, bitor_assign = simd_or; + BitAnd, bitand, BitAndAssign, bitand_assign = simd_and; + BitXor, bitxor, BitXorAssign, bitxor_assign = simd_xor; +); + +impl_not!(__m128i, __m256i, __m512i); + +// TODO: should we have `Rem` and `Not`? 
+ /// The `__mmask64` type used in AVX-512 intrinsics, a 64-bit integer #[allow(non_camel_case_types)] #[stable(feature = "stdarch_x86_avx512", since = "1.89")] diff --git a/crates/core_arch/src/x86/sse2.rs b/crates/core_arch/src/x86/sse2.rs index 1eaa89663b..ad49b08ef8 100644 --- a/crates/core_arch/src/x86/sse2.rs +++ b/crates/core_arch/src/x86/sse2.rs @@ -823,7 +823,7 @@ pub fn _mm_srl_epi64(a: __m128i, count: __m128i) -> __m128i { #[cfg_attr(test, assert_instr(andps))] #[stable(feature = "simd_x86", since = "1.27.0")] pub fn _mm_and_si128(a: __m128i, b: __m128i) -> __m128i { - unsafe { simd_and(a, b) } + a & b } /// Computes the bitwise NOT of 128 bits (representing integer data) in `a` and @@ -835,7 +835,7 @@ pub fn _mm_and_si128(a: __m128i, b: __m128i) -> __m128i { #[cfg_attr(test, assert_instr(andnps))] #[stable(feature = "simd_x86", since = "1.27.0")] pub fn _mm_andnot_si128(a: __m128i, b: __m128i) -> __m128i { - unsafe { simd_and(simd_xor(_mm_set1_epi8(-1), a), b) } + !a & b } /// Computes the bitwise OR of 128 bits (representing integer data) in `a` and @@ -847,7 +847,7 @@ pub fn _mm_andnot_si128(a: __m128i, b: __m128i) -> __m128i { #[cfg_attr(test, assert_instr(orps))] #[stable(feature = "simd_x86", since = "1.27.0")] pub fn _mm_or_si128(a: __m128i, b: __m128i) -> __m128i { - unsafe { simd_or(a, b) } + a | b } /// Computes the bitwise XOR of 128 bits (representing integer data) in `a` and @@ -859,7 +859,7 @@ pub fn _mm_or_si128(a: __m128i, b: __m128i) -> __m128i { #[cfg_attr(test, assert_instr(xorps))] #[stable(feature = "simd_x86", since = "1.27.0")] pub fn _mm_xor_si128(a: __m128i, b: __m128i) -> __m128i { - unsafe { simd_xor(a, b) } + a ^ b } /// Compares packed 8-bit integers in `a` and `b` for equality.
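Usage note (illustrative sketch, not part of the patch): with these impls applied, the
operator forms below compile for the x86 vector types and lower to the same simd_*
intrinsics that the rewritten intrinsics above now route through. The module and
function names are made up for the sketch; only the operators themselves come from
the patch.

#[cfg(target_arch = "x86_64")]
mod operator_demo {
    use core::arch::x86_64::{__m128, __m256i};

    // Bitwise operators on the integer vector types: same results as
    // _mm256_and_si256 / _mm256_or_si256 / _mm256_xor_si256 / _mm256_andnot_si256.
    pub fn bitwise(a: __m256i, b: __m256i) -> __m256i {
        let mut r = a & b;  // BitAnd
        r |= a ^ b;         // BitXor + BitOrAssign
        !r & b              // Not + BitAnd (the `!a & b` pattern used for andnot above)
    }

    // Element-wise arithmetic on the float vector types.
    pub fn arith(x: __m128, mut y: __m128) -> __m128 {
        y += x;                   // AddAssign
        let q = (x * y - x) / -y; // Mul, Sub, Div, Neg
        q % x                     // Rem (element-wise remainder)
    }
}

On the macro design: impl_arith_op! peels one type off the type list per recursion
step and terminates on the `(: ...)` arm once the list is empty. The optional
$TraitAssign ident forwarded to the `__internal` arms only serves to select the
binary (self, rhs) form; `Neg`, which has no assign trait, falls through to the
unary arm.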