diff --git a/compiler/rustc_abi/src/lib.rs b/compiler/rustc_abi/src/lib.rs index 5bd73502d980a..8e346706877de 100644 --- a/compiler/rustc_abi/src/lib.rs +++ b/compiler/rustc_abi/src/lib.rs @@ -1376,6 +1376,28 @@ impl WrappingRange { } } + /// Returns `true` if all the values in `other` are contained in this range, + /// when the values are considered as having width `size`. + #[inline(always)] + pub fn contains_range(&self, other: Self, size: Size) -> bool { + if self.is_full_for(size) { + true + } else { + let trunc = |x| size.truncate(x); + + let delta = self.start; + let max = trunc(self.end.wrapping_sub(delta)); + + let other_start = trunc(other.start.wrapping_sub(delta)); + let other_end = trunc(other.end.wrapping_sub(delta)); + + // Having shifted both input ranges by `delta`, now we only need to check + // whether `0..=max` contains `other_start..=other_end`, which can only + // happen if the other doesn't wrap since `self` isn't everything. + (other_start <= other_end) && (other_end <= max) + } + } + /// Returns `self` with replaced `start` #[inline(always)] fn with_start(mut self, start: u128) -> Self { diff --git a/compiler/rustc_abi/src/tests.rs b/compiler/rustc_abi/src/tests.rs index d993012378c81..d49c2d44af84d 100644 --- a/compiler/rustc_abi/src/tests.rs +++ b/compiler/rustc_abi/src/tests.rs @@ -5,3 +5,66 @@ fn align_constants() { assert_eq!(Align::ONE, Align::from_bytes(1).unwrap()); assert_eq!(Align::EIGHT, Align::from_bytes(8).unwrap()); } + +#[test] +fn wrapping_range_contains_range() { + let size16 = Size::from_bytes(16); + + let a = WrappingRange { start: 10, end: 20 }; + assert!(a.contains_range(a, size16)); + assert!(a.contains_range(WrappingRange { start: 11, end: 19 }, size16)); + assert!(a.contains_range(WrappingRange { start: 10, end: 10 }, size16)); + assert!(a.contains_range(WrappingRange { start: 20, end: 20 }, size16)); + assert!(!a.contains_range(WrappingRange { start: 10, end: 21 }, size16)); + assert!(!a.contains_range(WrappingRange { start: 9, end: 20 }, size16)); + assert!(!a.contains_range(WrappingRange { start: 4, end: 6 }, size16)); + assert!(!a.contains_range(WrappingRange { start: 24, end: 26 }, size16)); + + assert!(!a.contains_range(WrappingRange { start: 16, end: 14 }, size16)); + + let b = WrappingRange { start: 20, end: 10 }; + assert!(b.contains_range(b, size16)); + assert!(b.contains_range(WrappingRange { start: 20, end: 20 }, size16)); + assert!(b.contains_range(WrappingRange { start: 10, end: 10 }, size16)); + assert!(b.contains_range(WrappingRange { start: 0, end: 10 }, size16)); + assert!(b.contains_range(WrappingRange { start: 20, end: 30 }, size16)); + assert!(b.contains_range(WrappingRange { start: 20, end: 9 }, size16)); + assert!(b.contains_range(WrappingRange { start: 21, end: 10 }, size16)); + assert!(b.contains_range(WrappingRange { start: 999, end: 9999 }, size16)); + assert!(b.contains_range(WrappingRange { start: 999, end: 9 }, size16)); + assert!(!b.contains_range(WrappingRange { start: 19, end: 19 }, size16)); + assert!(!b.contains_range(WrappingRange { start: 11, end: 11 }, size16)); + assert!(!b.contains_range(WrappingRange { start: 19, end: 11 }, size16)); + assert!(!b.contains_range(WrappingRange { start: 11, end: 19 }, size16)); + + let f = WrappingRange { start: 0, end: u128::MAX }; + assert!(f.contains_range(WrappingRange { start: 10, end: 20 }, size16)); + assert!(f.contains_range(WrappingRange { start: 20, end: 10 }, size16)); + + let g = WrappingRange { start: 2, end: 1 }; + assert!(g.contains_range(WrappingRange { start: 10, end: 20 }, size16)); + assert!(g.contains_range(WrappingRange { start: 20, end: 10 }, size16)); + + let size1 = Size::from_bytes(1); + let u8r = WrappingRange { start: 0, end: 255 }; + let i8r = WrappingRange { start: 128, end: 127 }; + assert!(u8r.contains_range(i8r, size1)); + assert!(i8r.contains_range(u8r, size1)); + assert!(!u8r.contains_range(i8r, size16)); + assert!(i8r.contains_range(u8r, size16)); + + let boolr = WrappingRange { start: 0, end: 1 }; + assert!(u8r.contains_range(boolr, size1)); + assert!(i8r.contains_range(boolr, size1)); + assert!(!boolr.contains_range(u8r, size1)); + assert!(!boolr.contains_range(i8r, size1)); + + let cmpr = WrappingRange { start: 255, end: 1 }; + assert!(u8r.contains_range(cmpr, size1)); + assert!(i8r.contains_range(cmpr, size1)); + assert!(!cmpr.contains_range(u8r, size1)); + assert!(!cmpr.contains_range(i8r, size1)); + + assert!(!boolr.contains_range(cmpr, size1)); + assert!(cmpr.contains_range(boolr, size1)); +} diff --git a/compiler/rustc_codegen_ssa/src/mir/rvalue.rs b/compiler/rustc_codegen_ssa/src/mir/rvalue.rs index 610e2fd231117..a5759b79be45a 100644 --- a/compiler/rustc_codegen_ssa/src/mir/rvalue.rs +++ b/compiler/rustc_codegen_ssa/src/mir/rvalue.rs @@ -288,7 +288,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { // valid ranges. For example, `char`s are passed as just `i32`, with no // way for LLVM to know that they're 0x10FFFF at most. Thus we assume // the range of the input value too, not just the output range. - assume_scalar_range(bx, imm, from_scalar, from_backend_ty); + assume_scalar_range(bx, imm, from_scalar, from_backend_ty, None); imm = match (from_scalar.primitive(), to_scalar.primitive()) { (Int(_, is_signed), Int(..)) => bx.intcast(imm, to_backend_ty, is_signed), @@ -1064,7 +1064,7 @@ pub(super) fn transmute_scalar<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( // That said, last time we tried removing this, it didn't actually help // the rustc-perf results, so might as well keep doing it // - assume_scalar_range(bx, imm, from_scalar, from_backend_ty); + assume_scalar_range(bx, imm, from_scalar, from_backend_ty, Some(&to_scalar)); imm = match (from_scalar.primitive(), to_scalar.primitive()) { (Int(..) | Float(_), Int(..) | Float(_)) => bx.bitcast(imm, to_backend_ty), @@ -1092,22 +1092,42 @@ pub(super) fn transmute_scalar<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( // since it's never passed to something with parameter metadata (especially // after MIR inlining) so the only way to tell the backend about the // constraint that the `transmute` introduced is to `assume` it. - assume_scalar_range(bx, imm, to_scalar, to_backend_ty); + assume_scalar_range(bx, imm, to_scalar, to_backend_ty, Some(&from_scalar)); imm = bx.to_immediate_scalar(imm, to_scalar); imm } +/// Emits an `assume` call that `imm`'s value is within the known range of `scalar`. +/// +/// If `known` is `Some`, only emits the assume if it's more specific than +/// whatever is already known from the range of *that* scalar. fn assume_scalar_range<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( bx: &mut Bx, imm: Bx::Value, scalar: abi::Scalar, backend_ty: Bx::Type, + known: Option<&abi::Scalar>, ) { - if matches!(bx.cx().sess().opts.optimize, OptLevel::No) || scalar.is_always_valid(bx.cx()) { + if matches!(bx.cx().sess().opts.optimize, OptLevel::No) { return; } + match (scalar, known) { + (abi::Scalar::Union { .. }, _) => return, + (_, None) => { + if scalar.is_always_valid(bx.cx()) { + return; + } + } + (abi::Scalar::Initialized { valid_range, .. }, Some(known)) => { + let known_range = known.valid_range(bx.cx()); + if valid_range.contains_range(known_range, scalar.size(bx.cx())) { + return; + } + } + } + match scalar.primitive() { abi::Primitive::Int(..) => { let range = scalar.valid_range(bx.cx()); diff --git a/tests/codegen-llvm/intrinsics/transmute-niched.rs b/tests/codegen-llvm/intrinsics/transmute-niched.rs index 8ff5cc8ee4f4c..a886d9eee5909 100644 --- a/tests/codegen-llvm/intrinsics/transmute-niched.rs +++ b/tests/codegen-llvm/intrinsics/transmute-niched.rs @@ -163,11 +163,8 @@ pub unsafe fn check_swap_pair(x: (char, NonZero)) -> (NonZero, char) { pub unsafe fn check_bool_from_ordering(x: std::cmp::Ordering) -> bool { // CHECK-NOT: icmp // CHECK-NOT: assume - // OPT: %0 = sub i8 %x, -1 - // OPT: %1 = icmp ule i8 %0, 2 - // OPT: call void @llvm.assume(i1 %1) - // OPT: %2 = icmp ule i8 %x, 1 - // OPT: call void @llvm.assume(i1 %2) + // OPT: %0 = icmp ule i8 %x, 1 + // OPT: call void @llvm.assume(i1 %0) // CHECK-NOT: icmp // CHECK-NOT: assume // CHECK: %[[R:.+]] = trunc{{( nuw)?}} i8 %x to i1 @@ -184,9 +181,6 @@ pub unsafe fn check_bool_to_ordering(x: bool) -> std::cmp::Ordering { // CHECK-NOT: assume // OPT: %0 = icmp ule i8 %_0, 1 // OPT: call void @llvm.assume(i1 %0) - // OPT: %1 = sub i8 %_0, -1 - // OPT: %2 = icmp ule i8 %1, 2 - // OPT: call void @llvm.assume(i1 %2) // CHECK-NOT: icmp // CHECK-NOT: assume // CHECK: ret i8 %_0 @@ -221,3 +215,42 @@ pub unsafe fn check_ptr_to_nonnull(x: *const u8) -> NonNull { transmute(x) } + +#[repr(usize)] +pub enum FourOrEight { + Four = 4, + Eight = 8, +} + +// CHECK-LABEL: @check_nonnull_to_four_or_eight( +#[no_mangle] +pub unsafe fn check_nonnull_to_four_or_eight(x: NonNull) -> FourOrEight { + // CHECK: start + // CHECK-NEXT: %[[RET:.+]] = ptrtoint ptr %x to i64 + // CHECK-NOT: icmp + // CHECK-NOT: assume + // OPT: %0 = sub i64 %[[RET]], 4 + // OPT: %1 = icmp ule i64 %0, 4 + // OPT: call void @llvm.assume(i1 %1) + // CHECK-NOT: icmp + // CHECK-NOT: assume + // CHECK: ret i64 %[[RET]] + + transmute(x) +} + +// CHECK-LABEL: @check_four_or_eight_to_nonnull( +#[no_mangle] +pub unsafe fn check_four_or_eight_to_nonnull(x: FourOrEight) -> NonNull { + // CHECK-NOT: icmp + // CHECK-NOT: assume + // OPT: %0 = sub i64 %x, 4 + // OPT: %1 = icmp ule i64 %0, 4 + // OPT: call void @llvm.assume(i1 %1) + // CHECK-NOT: icmp + // CHECK-NOT: assume + // CHECK: %[[RET:.+]] = getelementptr i8, ptr null, i64 %x + // CHECK-NEXT: ret ptr %[[RET]] + + transmute(x) +}