Skip to content

Commit 85c8ff6

Browse files
committed
Auto merge of rust-lang#150606 - JonathanBrouwer:rollup-lue4jqz, r=JonathanBrouwer
Rollup of 6 pull requests

Successful merges:

- rust-lang#150425 (mapping an error from cmd.spawn() in npm::install)
- rust-lang#150444 (Expose kernel launch options as offload intrinsic args)
- rust-lang#150495 (Correct hexagon "unwinder_private_data_size")
- rust-lang#150578 (Fix a typo in the docs of AsMut for rust-lang#149609)
- rust-lang#150581 (mir_build: Separate match lowering for string-equality and scalar-equality)
- rust-lang#150594 (Fix typo in the docs of `CString::from_vec_with_nul`)

r? `@ghost`
`@rustbot` modify labels: rollup
2 parents 5497a36 + 3dc296a commit 85c8ff6

16 files changed

Lines changed: 233 additions & 97 deletions

File tree

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ use std::ffi::CString;
22

33
use llvm::Linkage::*;
44
use rustc_abi::Align;
5+
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
56
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
7+
use rustc_middle::bug;
68
use rustc_middle::ty::offload_meta::OffloadMetadata;
79

810
use crate::builder::Builder;
@@ -69,6 +71,57 @@ impl<'ll> OffloadGlobals<'ll> {
6971
}
7072
}
7173

74+
/// Kernel launch dimensions for one offload call, held as LLVM values.
///
/// Constructed by `OffloadKernelDims::from_operands` from two three-element
/// arrays (workgroups and threads). Both the raw arrays and the products of
/// their entries are kept, because the launch code consumes each form.
pub(crate) struct OffloadKernelDims<'ll> {
75+
/// Product of the three entries of `workgroup_dims`.
num_workgroups: &'ll Value,
76+
/// Product of the three entries of `thread_dims`.
threads_per_block: &'ll Value,
77+
/// The workgroup array, loaded as a 3-element array of LLVM `i32`.
workgroup_dims: &'ll Value,
78+
/// The thread array, loaded as a 3-element array of LLVM `i32`.
thread_dims: &'ll Value,
79+
}
80+
81+
impl<'ll> OffloadKernelDims<'ll> {
82+
/// Builds the kernel launch dimensions from two array operands.
///
/// `workgroup_op` and `thread_op` are each loaded through `builder` as a
/// 3-element array of LLVM `i32` with 4-byte alignment. The three entries
/// of each array are then multiplied together to obtain the total number
/// of workgroups and the number of threads per block.
///
/// # Panics
///
/// Raises a compiler bug (`bug!`) if either operand is not an
/// `OperandValue::Ref`, i.e. is not passed by reference.
pub(crate) fn from_operands<'tcx>(
83+
builder: &mut Builder<'_, 'll, 'tcx>,
84+
workgroup_op: &OperandRef<'tcx, &'ll llvm::Value>,
85+
thread_op: &OperandRef<'tcx, &'ll llvm::Value>,
86+
) -> Self {
87+
let cx = builder.cx;
88+
// LLVM type used to load both dimension arrays: three `i32` entries.
let arr_ty = cx.type_array(cx.type_i32(), 3);
89+
let four = Align::from_bytes(4).unwrap();
90+
91+
// The array is expected to live in memory; load the whole aggregate.
let OperandValue::Ref(place) = workgroup_op.val else {
92+
bug!("expected array operand by reference");
93+
};
94+
let workgroup_val = builder.load(arr_ty, place.llval, four);
95+
96+
// Same handling for the per-workgroup thread dimensions.
let OperandValue::Ref(place) = thread_op.val else {
97+
bug!("expected array operand by reference");
98+
};
99+
let thread_val = builder.load(arr_ty, place.llval, four);
100+
101+
// Emits `arr[0] * arr[1] * arr[2]` for a loaded 3-element array value.
fn mul_dim3<'ll, 'tcx>(
102+
builder: &mut Builder<'_, 'll, 'tcx>,
103+
arr: &'ll Value,
104+
) -> &'ll Value {
105+
let x = builder.extract_value(arr, 0);
106+
let y = builder.extract_value(arr, 1);
107+
let z = builder.extract_value(arr, 2);
108+
109+
let xy = builder.mul(x, y);
110+
builder.mul(xy, z)
111+
}
112+
113+
// Flattened totals, kept alongside the raw per-dimension arrays.
let num_workgroups = mul_dim3(builder, workgroup_val);
114+
let threads_per_block = mul_dim3(builder, thread_val);
115+
116+
OffloadKernelDims {
117+
workgroup_dims: workgroup_val,
118+
thread_dims: thread_val,
119+
num_workgroups,
120+
threads_per_block,
121+
}
122+
}
123+
}
124+
72125
// ; Function Attrs: nounwind
73126
// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
74127
fn generate_launcher<'ll>(cx: &CodegenCx<'ll, '_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
@@ -204,12 +257,12 @@ impl KernelArgsTy {
204257
num_args: u64,
205258
memtransfer_types: &'ll Value,
206259
geps: [&'ll Value; 3],
260+
workgroup_dims: &'ll Value,
261+
thread_dims: &'ll Value,
207262
) -> [(Align, &'ll Value); 13] {
208263
let four = Align::from_bytes(4).expect("4 Byte alignment should work");
209264
let eight = Align::EIGHT;
210265

211-
let ti32 = cx.type_i32();
212-
let ci32_0 = cx.get_const_i32(0);
213266
[
214267
(four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
215268
(four, cx.get_const_i32(num_args)),
@@ -222,8 +275,8 @@ impl KernelArgsTy {
222275
(eight, cx.const_null(cx.type_ptr())), // dbg
223276
(eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
224277
(eight, cx.get_const_i64(KernelArgsTy::FLAGS)),
225-
(four, cx.const_array(ti32, &[cx.get_const_i32(2097152), ci32_0, ci32_0])),
226-
(four, cx.const_array(ti32, &[cx.get_const_i32(256), ci32_0, ci32_0])),
278+
(four, workgroup_dims),
279+
(four, thread_dims),
227280
(four, cx.get_const_i32(0)),
228281
]
229282
}
@@ -413,10 +466,13 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
413466
types: &[&Type],
414467
metadata: &[OffloadMetadata],
415468
offload_globals: &OffloadGlobals<'ll>,
469+
offload_dims: &OffloadKernelDims<'ll>,
416470
) {
417471
let cx = builder.cx;
418472
let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } =
419473
offload_data;
474+
let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
475+
offload_dims;
420476

421477
let tgt_decl = offload_globals.launcher_fn;
422478
let tgt_target_kernel_ty = offload_globals.launcher_ty;
@@ -554,7 +610,8 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
554610
num_args,
555611
s_ident_t,
556612
);
557-
let values = KernelArgsTy::new(&cx, num_args, memtransfer_types, geps);
613+
let values =
614+
KernelArgsTy::new(&cx, num_args, memtransfer_types, geps, workgroup_dims, thread_dims);
558615

559616
// Step 3)
560617
// Here we fill the KernelArgsTy, see the documentation above
@@ -567,9 +624,8 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
567624
s_ident_t,
568625
// FIXME(offload) give users a way to select which GPU to use.
569626
cx.get_const_i64(u64::MAX), // MAX == -1.
570-
// FIXME(offload): Don't hardcode the numbers of threads in the future.
571-
cx.get_const_i32(2097152),
572-
cx.get_const_i32(256),
627+
num_workgroups,
628+
threads_per_block,
573629
region_id,
574630
a5,
575631
];

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use tracing::debug;
3030
use crate::abi::FnAbiLlvmExt;
3131
use crate::builder::Builder;
3232
use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call};
33-
use crate::builder::gpu_offload::{gen_call_handling, gen_define_handling};
33+
use crate::builder::gpu_offload::{OffloadKernelDims, gen_call_handling, gen_define_handling};
3434
use crate::context::CodegenCx;
3535
use crate::declare::declare_raw_fn;
3636
use crate::errors::{
@@ -1384,7 +1384,8 @@ fn codegen_offload<'ll, 'tcx>(
13841384
}
13851385
};
13861386

1387-
let args = get_args_from_tuple(bx, args[1], fn_target);
1387+
let offload_dims = OffloadKernelDims::from_operands(bx, &args[1], &args[2]);
1388+
let args = get_args_from_tuple(bx, args[3], fn_target);
13881389
let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target, LOCAL_CRATE);
13891390

13901391
let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder();
@@ -1403,7 +1404,7 @@ fn codegen_offload<'ll, 'tcx>(
14031404
}
14041405
};
14051406
let offload_data = gen_define_handling(&cx, &metadata, &types, target_symbol, offload_globals);
1406-
gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals);
1407+
gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals, &offload_dims);
14071408
}
14081409

14091410
fn get_args_from_tuple<'ll, 'tcx>(

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use rustc_abi::ExternAbi;
44
use rustc_errors::DiagMessage;
55
use rustc_hir::{self as hir, LangItem};
66
use rustc_middle::traits::{ObligationCause, ObligationCauseCode};
7-
use rustc_middle::ty::{self, Ty, TyCtxt};
7+
use rustc_middle::ty::{self, Const, Ty, TyCtxt};
88
use rustc_span::def_id::LocalDefId;
99
use rustc_span::{Span, Symbol, sym};
1010

@@ -315,7 +315,17 @@ pub(crate) fn check_intrinsic_type(
315315
let type_id = tcx.type_of(tcx.lang_items().type_id().unwrap()).instantiate_identity();
316316
(0, 0, vec![type_id, type_id], tcx.types.bool)
317317
}
318-
sym::offload => (3, 0, vec![param(0), param(1)], param(2)),
318+
sym::offload => (
319+
3,
320+
0,
321+
vec![
322+
param(0),
323+
Ty::new_array_with_const_len(tcx, tcx.types.u32, Const::from_target_usize(tcx, 3)),
324+
Ty::new_array_with_const_len(tcx, tcx.types.u32, Const::from_target_usize(tcx, 3)),
325+
param(1),
326+
],
327+
param(2),
328+
),
319329
sym::offset => (2, 0, vec![param(0), param(1)], param(0)),
320330
sym::arith_offset => (
321331
1,

compiler/rustc_middle/src/ty/sty.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,6 +1218,12 @@ impl<'tcx> Ty<'tcx> {
12181218
*self.kind() == Str
12191219
}
12201220

1221+
/// Returns true if this type is `&str`. The reference's lifetime is ignored.
///
/// Only immutable references match (`hir::Mutability::Not`), so `&mut str`
/// yields false, as does a bare `str`.
1222+
#[inline]
1223+
pub fn is_imm_ref_str(self) -> bool {
1224+
matches!(self.kind(), ty::Ref(_, inner, hir::Mutability::Not) if inner.is_str())
1225+
}
1226+
12211227
#[inline]
12221228
pub fn is_param(self, index: u32) -> bool {
12231229
match self.kind() {

compiler/rustc_mir_build/src/builder/matches/buckets.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
314314
}
315315

316316
(
317-
TestKind::Eq { value: test_val, .. },
317+
TestKind::StringEq { value: test_val, .. },
318+
TestableCase::Constant { value: case_val, kind: PatConstKind::String },
319+
)
320+
| (
321+
TestKind::ScalarEq { value: test_val, .. },
318322
TestableCase::Constant {
319323
value: case_val,
320324
kind: PatConstKind::Float | PatConstKind::Other,
@@ -347,7 +351,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
347351
| TestKind::If
348352
| TestKind::SliceLen { .. }
349353
| TestKind::Range { .. }
350-
| TestKind::Eq { .. }
354+
| TestKind::StringEq { .. }
355+
| TestKind::ScalarEq { .. }
351356
| TestKind::Deref { .. },
352357
_,
353358
) => {

compiler/rustc_mir_build/src/builder/matches/match_pair.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::sync::Arc;
22

33
use rustc_abi::FieldIdx;
44
use rustc_middle::mir::*;
5+
use rustc_middle::span_bug;
56
use rustc_middle::thir::*;
67
use rustc_middle::ty::{self, Ty, TypeVisitableExt};
78

@@ -173,9 +174,21 @@ impl<'tcx> MatchPairTree<'tcx> {
173174
PatConstKind::IntOrChar
174175
} else if pat_ty.is_floating_point() {
175176
PatConstKind::Float
177+
} else if pat_ty.is_str() {
178+
// Deref-patterns can cause string-literal patterns to have
179+
// type `str` instead of the usual `&str`.
180+
if !cx.tcx.features().deref_patterns() {
181+
span_bug!(
182+
pattern.span,
183+
"const pattern has type `str` but deref_patterns is not enabled"
184+
);
185+
}
186+
PatConstKind::String
187+
} else if pat_ty.is_imm_ref_str() {
188+
PatConstKind::String
176189
} else {
177190
// FIXME(Zalathar): This still covers several different
178-
// categories (e.g. raw pointer, string, pattern-type)
191+
// categories (e.g. raw pointer, pattern-type)
179192
// which could be split out into their own kinds.
180193
PatConstKind::Other
181194
};

compiler/rustc_mir_build/src/builder/matches/mod.rs

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,9 +1290,10 @@ enum PatConstKind {
12901290
/// These types don't support `SwitchInt` and require an equality test,
12911291
/// but can also interact with range pattern tests.
12921292
Float,
1293+
/// Constant string values, tested via string equality.
1294+
String,
12931295
/// Any other constant-pattern is usually tested via some kind of equality
12941296
/// check. Types that might be encountered here include:
1295-
/// - `&str`
12961297
/// - raw pointers derived from integer values
12971298
/// - pattern types, e.g. `pattern_type!(u32 is 1..)`
12981299
Other,
@@ -1368,14 +1369,20 @@ enum TestKind<'tcx> {
13681369
/// Test whether a `bool` is `true` or `false`.
13691370
If,
13701371

1371-
/// Test for equality with value, possibly after an unsizing coercion to
1372-
/// `cast_ty`,
1373-
Eq {
1372+
/// Tests the place against a string constant using string equality.
1373+
StringEq {
1374+
/// Constant `&str` value to test against.
13741375
value: ty::Value<'tcx>,
1375-
// Integer types are handled by `SwitchInt`, and constants with ADT
1376-
// types and `&[T]` types are converted back into patterns, so this can
1377-
// only be `&str` or floats.
1378-
cast_ty: Ty<'tcx>,
1376+
/// Type of the corresponding pattern node. Usually `&str`, but could
1377+
/// be `str` for patterns like `deref!("..."): String`.
1378+
pat_ty: Ty<'tcx>,
1379+
},
1380+
1381+
/// Tests the place against a constant using scalar equality.
1382+
ScalarEq {
1383+
value: ty::Value<'tcx>,
1384+
/// Type of the corresponding pattern node.
1385+
pat_ty: Ty<'tcx>,
13791386
},
13801387

13811388
/// Test whether the value falls within an inclusive or exclusive range.

0 commit comments

Comments
 (0)