Implement arithmetic operation traits for x86 SIMD types #1898

Open · wants to merge 1 commit into master
58 changes: 58 additions & 0 deletions crates/core_arch/src/macros.rs
@@ -163,3 +163,61 @@ macro_rules! simd_extract {
($x:expr, $idx:expr $(,)?) => {{ $crate::intrinsics::simd::simd_extract($x, const { $idx }) }};
($x:expr, $idx:expr, $ty:ty $(,)?) => {{ $crate::intrinsics::simd::simd_extract::<_, $ty>($x, const { $idx }) }};
}

#[allow(unused)]
macro_rules! impl_arith_op {
(__internal $op:ident, $intrinsic:ident $_:ident) => {
#[inline]
fn $op(self, rhs: Self) -> Self {
Member
These are missing a #[target_feature(enable = "...")] matching the _mm* caller, so the MIR inliner will not inline them and LLVM may not inline them either.

Contributor Author
The simd_* intrinsics map to platform-independent LLVM intrinsics/instructions. When LLVM lowers them to assembly, it checks which feature flags are available and uses the best instructions it can. See rust-lang/libs-team#628 (comment)
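
For illustration, here is roughly what one expansion step produces (a sketch, not code from the diff; the `BitAnd` impl is representative):

```rust
// Sketch: `impl_arith_op!(__m128i: BitAnd, bitand, BitAndAssign, bitand_assign = simd_and)`
// generates an impl along these lines.
impl crate::ops::BitAnd for __m128i {
    type Output = Self;

    #[inline]
    fn bitand(self, rhs: Self) -> Self {
        // simd_and is the portable vector AND; the backend selects
        // pand/vpand/vpandq according to the features enabled at codegen.
        unsafe { crate::intrinsics::simd::simd_and(self, rhs) }
    }
}
```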

unsafe { crate::intrinsics::simd::$intrinsic(self, rhs) }
}
};
(__internal $op:ident, $intrinsic:ident) => {
#[inline]
fn $op(self) -> Self {
unsafe { crate::intrinsics::simd::$intrinsic(self) }
}
};
(: $($tt:tt)*) => {};
(
$type:ty $(, $other_types:ty )* : $(
$Trait:ident, $op:ident $(, $TraitAssign:ident, $op_assign:ident)? = $intrinsic:ident
);* $(;)?
) => {
$(
#[stable(feature = "stdarch_arith_ops", since = "CURRENT_RUSTC_VERSION")]
impl crate::ops::$Trait for $type {
type Output = Self;

impl_arith_op!(__internal $op, $intrinsic $( $TraitAssign )?);
}

$(
#[stable(feature = "stdarch_arith_ops", since = "CURRENT_RUSTC_VERSION")]
impl crate::ops::$TraitAssign for $type {
#[inline]
fn $op_assign(&mut self, rhs: Self) {
*self = crate::ops::$Trait::$op(*self, rhs)
}
}
)?
)*

impl_arith_op!($($other_types),* : $($Trait, $op $(, $TraitAssign, $op_assign)? = $intrinsic);*);
};
}

#[allow(unused)]
macro_rules! impl_not {
($($type:ty),*) => {$(
#[stable(feature = "stdarch_arith_ops", since = "CURRENT_RUSTC_VERSION")]
impl crate::ops::Not for $type {
type Output = Self;

#[inline]
fn not(self) -> Self {
unsafe { crate::intrinsics::simd::simd_xor(<$type>::splat(!0), self) }
}
}
)*};
}
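
A note on the mechanism: `impl_arith_op!` recurses over its type list, implementing every listed trait for the first type and then re-invoking itself on the remaining types; the `(: $($tt:tt)*) => {};` arm terminates the recursion once the list is empty. A standalone sketch of the same recursion scheme (hypothetical names, illustration only):

```rust
// Toy version of the peel-one-type-and-recurse pattern used above.
macro_rules! impl_for_each {
    // Empty type list: stop recursing.
    (: $($rest:tt)*) => {};
    ($ty:ty $(, $others:ty)* : $msg:literal) => {
        impl $ty {
            pub fn describe() -> &'static str { $msg }
        }
        // Re-invoke with the remaining types and the same trailing spec.
        impl_for_each!($($others),* : $msg);
    };
}

struct A;
struct B;
impl_for_each!(A, B: "vector-like");

fn main() {
    assert_eq!(A::describe(), "vector-like");
    assert_eq!(B::describe(), "vector-like");
}
```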
14 changes: 4 additions & 10 deletions crates/core_arch/src/x86/avx2.rs
@@ -248,7 +248,7 @@ pub fn _mm256_alignr_epi8<const IMM8: i32>(a: __m256i, b: __m256i) -> __m256i {
#[cfg_attr(test, assert_instr(vandps))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub fn _mm256_and_si256(a: __m256i, b: __m256i) -> __m256i {
unsafe { transmute(simd_and(a.as_i64x4(), b.as_i64x4())) }
a & b
}

/// Computes the bitwise NOT of 256 bits (representing integer data)
@@ -260,13 +260,7 @@ pub fn _mm256_and_si256(a: __m256i, b: __m256i) -> __m256i {
#[cfg_attr(test, assert_instr(vandnps))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub fn _mm256_andnot_si256(a: __m256i, b: __m256i) -> __m256i {
unsafe {
let all_ones = _mm256_set1_epi8(-1);
transmute(simd_and(
simd_xor(a.as_i64x4(), all_ones.as_i64x4()),
b.as_i64x4(),
))
}
!a & b
}

/// Averages packed unsigned 16-bit integers in `a` and `b`.
@@ -2184,7 +2178,7 @@ pub fn _mm256_mulhrs_epi16(a: __m256i, b: __m256i) -> __m256i {
#[cfg_attr(test, assert_instr(vorps))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub fn _mm256_or_si256(a: __m256i, b: __m256i) -> __m256i {
unsafe { transmute(simd_or(a.as_i32x8(), b.as_i32x8())) }
a | b
}

/// Converts packed 16-bit integers from `a` and `b` to packed 8-bit integers
@@ -3557,7 +3551,7 @@ pub fn _mm256_unpacklo_epi64(a: __m256i, b: __m256i) -> __m256i {
#[cfg_attr(test, assert_instr(vxorps))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub fn _mm256_xor_si256(a: __m256i, b: __m256i) -> __m256i {
unsafe { transmute(simd_xor(a.as_i64x4(), b.as_i64x4())) }
a ^ b
}

/// Extracts an 8-bit integer from `a`, selected with `INDEX`. Returns a 32-bit
40 changes: 20 additions & 20 deletions crates/core_arch/src/x86/avx512f.rs
@@ -28149,7 +28149,7 @@ pub fn _mm_maskz_alignr_epi64<const IMM8: i32>(k: __mmask8, a: __m128i, b: __m12
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpandq))] // should be vpandd, but generates vpandq
pub fn _mm512_and_epi32(a: __m512i, b: __m512i) -> __m512i {
unsafe { transmute(simd_and(a.as_i32x16(), b.as_i32x16())) }
a & b
}

/// Performs element-by-element bitwise AND between packed 32-bit integer elements of a and b, storing the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28244,7 +28244,7 @@ pub fn _mm_maskz_and_epi32(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpandq))]
pub fn _mm512_and_epi64(a: __m512i, b: __m512i) -> __m512i {
unsafe { transmute(simd_and(a.as_i64x8(), b.as_i64x8())) }
a & b
}

/// Compute the bitwise AND of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28339,7 +28339,7 @@ pub fn _mm_maskz_and_epi64(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpandq))]
pub fn _mm512_and_si512(a: __m512i, b: __m512i) -> __m512i {
unsafe { transmute(simd_and(a.as_i32x16(), b.as_i32x16())) }
a & b
}

/// Compute the bitwise OR of packed 32-bit integers in a and b, and store the results in dst.
@@ -28350,7 +28350,7 @@ pub fn _mm512_and_si512(a: __m512i, b: __m512i) -> __m512i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vporq))]
pub fn _mm512_or_epi32(a: __m512i, b: __m512i) -> __m512i {
unsafe { transmute(simd_or(a.as_i32x16(), b.as_i32x16())) }
a | b
}

/// Compute the bitwise OR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28389,7 +28389,7 @@ pub fn _mm512_maskz_or_epi32(k: __mmask16, a: __m512i, b: __m512i) -> __m512i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vor))] //should be vpord
pub fn _mm256_or_epi32(a: __m256i, b: __m256i) -> __m256i {
unsafe { transmute(simd_or(a.as_i32x8(), b.as_i32x8())) }
a | b
}

/// Compute the bitwise OR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28428,7 +28428,7 @@ pub fn _mm256_maskz_or_epi32(k: __mmask8, a: __m256i, b: __m256i) -> __m256i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vor))] //should be vpord
pub fn _mm_or_epi32(a: __m128i, b: __m128i) -> __m128i {
unsafe { transmute(simd_or(a.as_i32x4(), b.as_i32x4())) }
a | b
}

/// Compute the bitwise OR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28467,7 +28467,7 @@ pub fn _mm_maskz_or_epi32(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vporq))]
pub fn _mm512_or_epi64(a: __m512i, b: __m512i) -> __m512i {
unsafe { transmute(simd_or(a.as_i64x8(), b.as_i64x8())) }
a | b
}

/// Compute the bitwise OR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28506,7 +28506,7 @@ pub fn _mm512_maskz_or_epi64(k: __mmask8, a: __m512i, b: __m512i) -> __m512i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vor))] //should be vporq
pub fn _mm256_or_epi64(a: __m256i, b: __m256i) -> __m256i {
unsafe { transmute(simd_or(a.as_i64x4(), b.as_i64x4())) }
a | b
}

/// Compute the bitwise OR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
Expand Down Expand Up @@ -28545,7 +28545,7 @@ pub fn _mm256_maskz_or_epi64(k: __mmask8, a: __m256i, b: __m256i) -> __m256i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vor))] //should be vporq
pub fn _mm_or_epi64(a: __m128i, b: __m128i) -> __m128i {
unsafe { transmute(simd_or(a.as_i64x2(), b.as_i64x2())) }
a | b
}

/// Compute the bitwise OR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28584,7 +28584,7 @@ pub fn _mm_maskz_or_epi64(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vporq))]
pub fn _mm512_or_si512(a: __m512i, b: __m512i) -> __m512i {
unsafe { transmute(simd_or(a.as_i32x16(), b.as_i32x16())) }
a | b
}

/// Compute the bitwise XOR of packed 32-bit integers in a and b, and store the results in dst.
@@ -28595,7 +28595,7 @@ pub fn _mm512_or_si512(a: __m512i, b: __m512i) -> __m512i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpxorq))] //should be vpxord
pub fn _mm512_xor_epi32(a: __m512i, b: __m512i) -> __m512i {
unsafe { transmute(simd_xor(a.as_i32x16(), b.as_i32x16())) }
a ^ b
}

/// Compute the bitwise XOR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
Expand Down Expand Up @@ -28634,7 +28634,7 @@ pub fn _mm512_maskz_xor_epi32(k: __mmask16, a: __m512i, b: __m512i) -> __m512i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vxor))] //should be vpxord
pub fn _mm256_xor_epi32(a: __m256i, b: __m256i) -> __m256i {
unsafe { transmute(simd_xor(a.as_i32x8(), b.as_i32x8())) }
a ^ b
}

/// Compute the bitwise XOR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
Expand Down Expand Up @@ -28673,7 +28673,7 @@ pub fn _mm256_maskz_xor_epi32(k: __mmask8, a: __m256i, b: __m256i) -> __m256i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vxor))] //should be vpxord
pub fn _mm_xor_epi32(a: __m128i, b: __m128i) -> __m128i {
unsafe { transmute(simd_xor(a.as_i32x4(), b.as_i32x4())) }
a ^ b
}

/// Compute the bitwise XOR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
Expand Down Expand Up @@ -28712,7 +28712,7 @@ pub fn _mm_maskz_xor_epi32(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpxorq))]
pub fn _mm512_xor_epi64(a: __m512i, b: __m512i) -> __m512i {
unsafe { transmute(simd_xor(a.as_i64x8(), b.as_i64x8())) }
a ^ b
}

/// Compute the bitwise XOR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
Expand Down Expand Up @@ -28751,7 +28751,7 @@ pub fn _mm512_maskz_xor_epi64(k: __mmask8, a: __m512i, b: __m512i) -> __m512i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vxor))] //should be vpxorq
pub fn _mm256_xor_epi64(a: __m256i, b: __m256i) -> __m256i {
unsafe { transmute(simd_xor(a.as_i64x4(), b.as_i64x4())) }
a ^ b
}

/// Compute the bitwise XOR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
Expand Down Expand Up @@ -28790,7 +28790,7 @@ pub fn _mm256_maskz_xor_epi64(k: __mmask8, a: __m256i, b: __m256i) -> __m256i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vxor))] //should be vpxorq
pub fn _mm_xor_epi64(a: __m128i, b: __m128i) -> __m128i {
unsafe { transmute(simd_xor(a.as_i64x2(), b.as_i64x2())) }
a ^ b
}

/// Compute the bitwise XOR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
Expand Down Expand Up @@ -28829,7 +28829,7 @@ pub fn _mm_maskz_xor_epi64(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpxorq))]
pub fn _mm512_xor_si512(a: __m512i, b: __m512i) -> __m512i {
unsafe { transmute(simd_xor(a.as_i32x16(), b.as_i32x16())) }
a ^ b
}

/// Compute the bitwise NOT of packed 32-bit integers in a and then AND with b, and store the results in dst.
Expand All @@ -28840,7 +28840,7 @@ pub fn _mm512_xor_si512(a: __m512i, b: __m512i) -> __m512i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpandnq))] //should be vpandnd
pub fn _mm512_andnot_epi32(a: __m512i, b: __m512i) -> __m512i {
_mm512_and_epi32(_mm512_xor_epi32(a, _mm512_set1_epi32(u32::MAX as i32)), b)
!a & b
}

/// Compute the bitwise NOT of packed 32-bit integers in a and then AND with b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
Expand Down Expand Up @@ -28939,7 +28939,7 @@ pub fn _mm_maskz_andnot_epi32(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpandnq))] //should be vpandnd
pub fn _mm512_andnot_epi64(a: __m512i, b: __m512i) -> __m512i {
_mm512_and_epi64(_mm512_xor_epi64(a, _mm512_set1_epi64(u64::MAX as i64)), b)
!a & b
}

/// Compute the bitwise NOT of packed 64-bit integers in a and then AND with b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
Expand Down Expand Up @@ -29038,7 +29038,7 @@ pub fn _mm_maskz_andnot_epi64(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpandnq))]
pub fn _mm512_andnot_si512(a: __m512i, b: __m512i) -> __m512i {
_mm512_and_epi64(_mm512_xor_epi64(a, _mm512_set1_epi64(u64::MAX as i64)), b)
!a & b
}

/// Convert 16-bit mask a into an integer value, and store the result in dst.
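
The `!a & b` rewrites above rely on the new `Not` and `BitAnd` impls; the `assert_instr(vpandn*)` attributes verify that LLVM still folds the separate NOT and AND into a single and-not instruction. A minimal sketch of the pattern (illustrative only, assuming this PR's impls):

```rust
use core::arch::x86_64::*;

// Illustrative: with the operator impls, and-not reads directly as `!a & b`.
// LLVM pattern-matches (xor a, -1) & b and emits vpandn/vpandnq on AVX targets.
fn andnot512(a: __m512i, b: __m512i) -> __m512i {
    !a & b
}
```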
23 changes: 23 additions & 0 deletions crates/core_arch/src/x86/mod.rs
@@ -471,6 +471,29 @@ impl bf16 {
}
}

impl_arith_op!(
__m128, __m128d, __m128h,
__m256, __m256d, __m256h,
__m512, __m512d, __m512h:
Add, add, AddAssign, add_assign = simd_add;
Sub, sub, SubAssign, sub_assign = simd_sub;
Mul, mul, MulAssign, mul_assign = simd_mul;
Div, div, DivAssign, div_assign = simd_div;
Rem, rem, RemAssign, rem_assign = simd_rem;
Neg, neg = simd_neg;
);

impl_arith_op!(
__m128i, __m256i, __m512i:
BitOr, bitor, BitOrAssign, bitor_assign = simd_or;
BitAnd, bitand, BitAndAssign, bitand_assign = simd_and;
BitXor, bitxor, BitXorAssign, bitxor_assign = simd_xor;
);

impl_not!(__m128i, __m256i, __m512i);

// TODO: should we have `Rem` and `Not`?

/// The `__mmask64` type used in AVX-512 intrinsics, a 64-bit integer
#[allow(non_camel_case_types)]
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
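
With the invocations above in place, packed-float math can use ordinary operators. A usage sketch (illustrative, assuming this PR; `poly` is a hypothetical helper, not from the diff):

```rust
use core::arch::x86_64::*;

// Illustrative: Horner evaluation with the new Add/Mul impls instead of
// _mm256_add_ps / _mm256_mul_ps.
fn poly(x: __m256, c0: __m256, c1: __m256, c2: __m256) -> __m256 {
    c0 + x * (c1 + x * c2)
}
```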
8 changes: 4 additions & 4 deletions crates/core_arch/src/x86/sse2.rs
@@ -823,7 +823,7 @@ pub fn _mm_srl_epi64(a: __m128i, count: __m128i) -> __m128i {
#[cfg_attr(test, assert_instr(andps))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub fn _mm_and_si128(a: __m128i, b: __m128i) -> __m128i {
unsafe { simd_and(a, b) }
a & b
}

/// Computes the bitwise NOT of 128 bits (representing integer data) in `a` and
@@ -835,7 +835,7 @@ pub fn _mm_and_si128(a: __m128i, b: __m128i) -> __m128i {
#[cfg_attr(test, assert_instr(andnps))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub fn _mm_andnot_si128(a: __m128i, b: __m128i) -> __m128i {
unsafe { simd_and(simd_xor(_mm_set1_epi8(-1), a), b) }
!a & b
}

/// Computes the bitwise OR of 128 bits (representing integer data) in `a` and
@@ -847,7 +847,7 @@ pub fn _mm_andnot_si128(a: __m128i, b: __m128i) -> __m128i {
#[cfg_attr(test, assert_instr(orps))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub fn _mm_or_si128(a: __m128i, b: __m128i) -> __m128i {
unsafe { simd_or(a, b) }
a | b
}

/// Computes the bitwise XOR of 128 bits (representing integer data) in `a` and
@@ -859,7 +859,7 @@ pub fn _mm_or_si128(a: __m128i, b: __m128i) -> __m128i {
#[cfg_attr(test, assert_instr(xorps))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub fn _mm_xor_si128(a: __m128i, b: __m128i) -> __m128i {
unsafe { simd_xor(a, b) }
a ^ b
}

/// Compares packed 8-bit integers in `a` and `b` for equality.
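
Taken together, the sse2.rs changes mean bit-selection code like the following can be written with operators instead of `_mm_and_si128`/`_mm_andnot_si128`/`_mm_or_si128` (illustrative only; `select_bits` is a hypothetical function name):

```rust
use core::arch::x86_64::*;

// Illustrative: take bits of `b` where `mask` is set, else bits of `a`,
// using the operator impls added by this PR.
fn select_bits(mask: __m128i, a: __m128i, b: __m128i) -> __m128i {
    (mask & b) | (!mask & a)
}
```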