Skip to content

Don't emit two assumes in transmutes when one is a subset of the other #144209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions compiler/rustc_abi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1376,6 +1376,28 @@ impl WrappingRange {
}
}

/// Returns `true` if all the values in `other` are contained in this range,
/// when the values are considered as having width `size`.
#[inline(always)]
pub fn contains_range(&self, other: Self, size: Size) -> bool {
if self.is_full_for(size) {
true
} else {
let trunc = |x| size.truncate(x);

let delta = self.start;
let max = trunc(self.end.wrapping_sub(delta));

let other_start = trunc(other.start.wrapping_sub(delta));
let other_end = trunc(other.end.wrapping_sub(delta));

// Having shifted both input ranges by `delta`, now we only need to check
// whether `0..=max` contains `other_start..=other_end`, which can only
// happen if the other doesn't wrap since `self` isn't everything.
(other_start <= other_end) && (other_end <= max)
}
}

/// Returns `self` with replaced `start`
#[inline(always)]
fn with_start(mut self, start: u128) -> Self {
Expand Down
63 changes: 63 additions & 0 deletions compiler/rustc_abi/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,66 @@ fn align_constants() {
assert_eq!(Align::ONE, Align::from_bytes(1).unwrap());
assert_eq!(Align::EIGHT, Align::from_bytes(8).unwrap());
}

#[test]
fn wrapping_range_contains_range() {
let size16 = Size::from_bytes(16);

let a = WrappingRange { start: 10, end: 20 };
assert!(a.contains_range(a, size16));
assert!(a.contains_range(WrappingRange { start: 11, end: 19 }, size16));
assert!(a.contains_range(WrappingRange { start: 10, end: 10 }, size16));
assert!(a.contains_range(WrappingRange { start: 20, end: 20 }, size16));
assert!(!a.contains_range(WrappingRange { start: 10, end: 21 }, size16));
assert!(!a.contains_range(WrappingRange { start: 9, end: 20 }, size16));
assert!(!a.contains_range(WrappingRange { start: 4, end: 6 }, size16));
assert!(!a.contains_range(WrappingRange { start: 24, end: 26 }, size16));

assert!(!a.contains_range(WrappingRange { start: 16, end: 14 }, size16));

let b = WrappingRange { start: 20, end: 10 };
assert!(b.contains_range(b, size16));
assert!(b.contains_range(WrappingRange { start: 20, end: 20 }, size16));
assert!(b.contains_range(WrappingRange { start: 10, end: 10 }, size16));
assert!(b.contains_range(WrappingRange { start: 0, end: 10 }, size16));
assert!(b.contains_range(WrappingRange { start: 20, end: 30 }, size16));
assert!(b.contains_range(WrappingRange { start: 20, end: 9 }, size16));
assert!(b.contains_range(WrappingRange { start: 21, end: 10 }, size16));
assert!(b.contains_range(WrappingRange { start: 999, end: 9999 }, size16));
assert!(b.contains_range(WrappingRange { start: 999, end: 9 }, size16));
assert!(!b.contains_range(WrappingRange { start: 19, end: 19 }, size16));
assert!(!b.contains_range(WrappingRange { start: 11, end: 11 }, size16));
assert!(!b.contains_range(WrappingRange { start: 19, end: 11 }, size16));
assert!(!b.contains_range(WrappingRange { start: 11, end: 19 }, size16));

let f = WrappingRange { start: 0, end: u128::MAX };
assert!(f.contains_range(WrappingRange { start: 10, end: 20 }, size16));
assert!(f.contains_range(WrappingRange { start: 20, end: 10 }, size16));

let g = WrappingRange { start: 2, end: 1 };
assert!(g.contains_range(WrappingRange { start: 10, end: 20 }, size16));
assert!(g.contains_range(WrappingRange { start: 20, end: 10 }, size16));

let size1 = Size::from_bytes(1);
let u8r = WrappingRange { start: 0, end: 255 };
let i8r = WrappingRange { start: 128, end: 127 };
assert!(u8r.contains_range(i8r, size1));
assert!(i8r.contains_range(u8r, size1));
assert!(!u8r.contains_range(i8r, size16));
assert!(i8r.contains_range(u8r, size16));

let boolr = WrappingRange { start: 0, end: 1 };
assert!(u8r.contains_range(boolr, size1));
assert!(i8r.contains_range(boolr, size1));
assert!(!boolr.contains_range(u8r, size1));
assert!(!boolr.contains_range(i8r, size1));

let cmpr = WrappingRange { start: 255, end: 1 };
assert!(u8r.contains_range(cmpr, size1));
assert!(i8r.contains_range(cmpr, size1));
assert!(!cmpr.contains_range(u8r, size1));
assert!(!cmpr.contains_range(i8r, size1));

assert!(!boolr.contains_range(cmpr, size1));
assert!(cmpr.contains_range(boolr, size1));
}
28 changes: 24 additions & 4 deletions compiler/rustc_codegen_ssa/src/mir/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
// valid ranges. For example, `char`s are passed as just `i32`, with no
// way for LLVM to know that they're 0x10FFFF at most. Thus we assume
// the range of the input value too, not just the output range.
assume_scalar_range(bx, imm, from_scalar, from_backend_ty);
assume_scalar_range(bx, imm, from_scalar, from_backend_ty, None);

imm = match (from_scalar.primitive(), to_scalar.primitive()) {
(Int(_, is_signed), Int(..)) => bx.intcast(imm, to_backend_ty, is_signed),
Expand Down Expand Up @@ -1064,7 +1064,7 @@ pub(super) fn transmute_scalar<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
// That said, last time we tried removing this, it didn't actually help
// the rustc-perf results, so might as well keep doing it
// <https://github.com/rust-lang/rust/pull/135610#issuecomment-2599275182>
assume_scalar_range(bx, imm, from_scalar, from_backend_ty);
assume_scalar_range(bx, imm, from_scalar, from_backend_ty, Some(&to_scalar));

imm = match (from_scalar.primitive(), to_scalar.primitive()) {
(Int(..) | Float(_), Int(..) | Float(_)) => bx.bitcast(imm, to_backend_ty),
Expand Down Expand Up @@ -1092,22 +1092,42 @@ pub(super) fn transmute_scalar<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
// since it's never passed to something with parameter metadata (especially
// after MIR inlining) so the only way to tell the backend about the
// constraint that the `transmute` introduced is to `assume` it.
assume_scalar_range(bx, imm, to_scalar, to_backend_ty);
assume_scalar_range(bx, imm, to_scalar, to_backend_ty, Some(&from_scalar));

imm = bx.to_immediate_scalar(imm, to_scalar);
imm
}

/// Emits an `assume` call that `imm`'s value is within the known range of `scalar`.
///
/// If `known` is `Some`, only emits the assume if it's more specific than
/// whatever is already known from the range of *that* scalar.
fn assume_scalar_range<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
bx: &mut Bx,
imm: Bx::Value,
scalar: abi::Scalar,
backend_ty: Bx::Type,
known: Option<&abi::Scalar>,
) {
if matches!(bx.cx().sess().opts.optimize, OptLevel::No) || scalar.is_always_valid(bx.cx()) {
if matches!(bx.cx().sess().opts.optimize, OptLevel::No) {
return;
}

match (scalar, known) {
(abi::Scalar::Union { .. }, _) => return,
(_, None) => {
if scalar.is_always_valid(bx.cx()) {
return;
}
}
(abi::Scalar::Initialized { valid_range, .. }, Some(known)) => {
let known_range = known.valid_range(bx.cx());
if valid_range.contains_range(known_range, scalar.size(bx.cx())) {
return;
}
}
}

match scalar.primitive() {
abi::Primitive::Int(..) => {
let range = scalar.valid_range(bx.cx());
Expand Down
49 changes: 41 additions & 8 deletions tests/codegen-llvm/intrinsics/transmute-niched.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,8 @@ pub unsafe fn check_swap_pair(x: (char, NonZero<u32>)) -> (NonZero<u32>, char) {
pub unsafe fn check_bool_from_ordering(x: std::cmp::Ordering) -> bool {
// CHECK-NOT: icmp
// CHECK-NOT: assume
// OPT: %0 = sub i8 %x, -1
// OPT: %1 = icmp ule i8 %0, 2
// OPT: call void @llvm.assume(i1 %1)
// OPT: %2 = icmp ule i8 %x, 1
// OPT: call void @llvm.assume(i1 %2)
// OPT: %0 = icmp ule i8 %x, 1
// OPT: call void @llvm.assume(i1 %0)
// CHECK-NOT: icmp
// CHECK-NOT: assume
// CHECK: %[[R:.+]] = trunc{{( nuw)?}} i8 %x to i1
Expand All @@ -184,9 +181,6 @@ pub unsafe fn check_bool_to_ordering(x: bool) -> std::cmp::Ordering {
// CHECK-NOT: assume
// OPT: %0 = icmp ule i8 %_0, 1
// OPT: call void @llvm.assume(i1 %0)
// OPT: %1 = sub i8 %_0, -1
// OPT: %2 = icmp ule i8 %1, 2
// OPT: call void @llvm.assume(i1 %2)
// CHECK-NOT: icmp
// CHECK-NOT: assume
// CHECK: ret i8 %_0
Expand Down Expand Up @@ -221,3 +215,42 @@ pub unsafe fn check_ptr_to_nonnull(x: *const u8) -> NonNull<u8> {

transmute(x)
}

#[repr(usize)]
pub enum FourOrEight {
Four = 4,
Eight = 8,
}

// CHECK-LABEL: @check_nonnull_to_four_or_eight(
#[no_mangle]
pub unsafe fn check_nonnull_to_four_or_eight(x: NonNull<u8>) -> FourOrEight {
// CHECK: start
// CHECK-NEXT: %[[RET:.+]] = ptrtoint ptr %x to i64
// CHECK-NOT: icmp
// CHECK-NOT: assume
// OPT: %0 = sub i64 %[[RET]], 4
// OPT: %1 = icmp ule i64 %0, 4
// OPT: call void @llvm.assume(i1 %1)
// CHECK-NOT: icmp
// CHECK-NOT: assume
// CHECK: ret i64 %[[RET]]

transmute(x)
}

// CHECK-LABEL: @check_four_or_eight_to_nonnull(
#[no_mangle]
pub unsafe fn check_four_or_eight_to_nonnull(x: FourOrEight) -> NonNull<u8> {
// CHECK-NOT: icmp
// CHECK-NOT: assume
// OPT: %0 = sub i64 %x, 4
// OPT: %1 = icmp ule i64 %0, 4
// OPT: call void @llvm.assume(i1 %1)
// CHECK-NOT: icmp
// CHECK-NOT: assume
// CHECK: %[[RET:.+]] = getelementptr i8, ptr null, i64 %x
// CHECK-NEXT: ret ptr %[[RET]]

transmute(x)
}
Loading