Skip to content

Commit fab5751

Browse files
Add metallib passthrough (#8886)
1 parent 1af7312 commit fab5751

File tree

26 files changed

+891
-154
lines changed

26 files changed

+891
-154
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ Bottom level categories:
6060
- `front::wgsl::Frontend::set_options`
6161
- `ir::Block::is_empty`
6262
- `ir::Block::len`
63+
- Changed passthrough shaders to not require an entry point parameter, so that the same shader module may be used in multiple entry points. Also added support for metallib passthrough. By @inner-daemons in #8886.
6364

6465
#### naga
6566

Cargo.lock

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ xshell = "0.2.2"
203203

204204
# Metal dependencies
205205
block2 = "0.6.2"
206+
dispatch2 = "0.3.0"
206207
objc2 = "0.6.3"
207208
objc2-core-foundation = { version = "0.3.2", default-features = false, features = [
208209
"std",
@@ -218,6 +219,7 @@ objc2-foundation = { version = "0.3.2", default-features = false, features = [
218219
objc2-metal = { version = "0.3.2", default-features = false, features = [
219220
"std",
220221
"block2",
222+
"dispatch2",
221223
"MTLAllocation",
222224
"MTLBlitCommandEncoder",
223225
"MTLBlitPass",

examples/features/src/mesh_shader/mod.rs

Lines changed: 56 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ fn compile_hlsl(device: &wgpu::Device, entry: &str, stage_str: &str) -> wgpu::Sh
2929
std::fs::remove_file(out_path).unwrap();
3030
unsafe {
3131
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
32-
entry_point: entry.to_owned(),
3332
label: None,
3433
num_workgroups: (1, 1, 1),
3534
dxil: Some(std::borrow::Cow::Owned(file)),
@@ -38,10 +37,9 @@ fn compile_hlsl(device: &wgpu::Device, entry: &str, stage_str: &str) -> wgpu::Sh
3837
}
3938
}
4039

41-
fn compile_msl(device: &wgpu::Device, entry: &str) -> wgpu::ShaderModule {
40+
fn compile_msl(device: &wgpu::Device) -> wgpu::ShaderModule {
4241
unsafe {
4342
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
44-
entry_point: entry.to_owned(),
4543
label: None,
4644
msl: Some(std::borrow::Cow::Borrowed(include_str!("shader.metal"))),
4745
num_workgroups: (1, 1, 1),
@@ -50,6 +48,53 @@ fn compile_msl(device: &wgpu::Device, entry: &str) -> wgpu::ShaderModule {
5048
}
5149
}
5250

51+
struct Shaders {
52+
ts: wgpu::ShaderModule,
53+
ms: wgpu::ShaderModule,
54+
fs: wgpu::ShaderModule,
55+
ts_name: &'static str,
56+
ms_name: &'static str,
57+
fs_name: &'static str,
58+
}
59+
60+
fn get_shaders(device: &wgpu::Device, backend: wgpu::Backend) -> Shaders {
61+
// In the case that the platform does support mesh shaders, the dummy
62+
// shader is used to avoid requiring EXPERIMENTAL_PASSTHROUGH_SHADERS.
63+
match backend {
64+
wgpu::Backend::Vulkan => {
65+
let compiled = compile_wgsl(device);
66+
Shaders {
67+
ts: compiled.clone(),
68+
ms: compiled.clone(),
69+
fs: compiled.clone(),
70+
ts_name: "ts_main",
71+
ms_name: "ms_main",
72+
fs_name: "fs_main",
73+
}
74+
}
75+
wgpu::Backend::Dx12 => Shaders {
76+
ts: compile_hlsl(device, "Task", "as"),
77+
ms: compile_hlsl(device, "Mesh", "ms"),
78+
fs: compile_hlsl(device, "Frag", "ps"),
79+
ts_name: "main",
80+
ms_name: "main",
81+
fs_name: "main",
82+
},
83+
wgpu::Backend::Metal => {
84+
let compiled = compile_msl(device);
85+
Shaders {
86+
ts: compiled.clone(),
87+
ms: compiled.clone(),
88+
fs: compiled.clone(),
89+
ts_name: "taskShader",
90+
ms_name: "meshShader",
91+
fs_name: "fragShader",
92+
}
93+
}
94+
_ => unreachable!(),
95+
}
96+
}
97+
5398
pub struct Example {
5499
pipeline: wgpu::RenderPipeline,
55100
}
@@ -60,33 +105,14 @@ impl crate::framework::Example for Example {
60105
device: &wgpu::Device,
61106
_queue: &wgpu::Queue,
62107
) -> Self {
63-
let (ts, ms, fs, ts_name, ms_name, fs_name) = match adapter.get_info().backend {
64-
wgpu::Backend::Vulkan => (
65-
compile_wgsl(device),
66-
compile_wgsl(device),
67-
compile_wgsl(device),
68-
"ts_main",
69-
"ms_main",
70-
"fs_main",
71-
),
72-
wgpu::Backend::Dx12 => (
73-
compile_hlsl(device, "Task", "as"),
74-
compile_hlsl(device, "Mesh", "ms"),
75-
compile_hlsl(device, "Frag", "ps"),
76-
"main",
77-
"main",
78-
"main",
79-
),
80-
wgpu::Backend::Metal => (
81-
compile_msl(device, "taskShader"),
82-
compile_msl(device, "meshShader"),
83-
compile_msl(device, "fragShader"),
84-
"main",
85-
"main",
86-
"main",
87-
),
88-
_ => panic!("Example can currently only run on vulkan, dx12 or metal"),
89-
};
108+
let Shaders {
109+
ts,
110+
ms,
111+
fs,
112+
ts_name,
113+
ms_name,
114+
fs_name,
115+
} = get_shaders(device, adapter.get_info().backend);
90116
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
91117
label: None,
92118
bind_group_layouts: &[],

player/src/lib.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,10 +268,8 @@ impl Player {
268268
Action::CreateShaderModulePassthrough {
269269
id,
270270
data,
271-
entry_point,
272271
label,
273272
num_workgroups,
274-
runtime_checks,
275273
} => {
276274
let spirv = data.iter().find_map(|a| {
277275
if a.kind() == DataKind::Spv {
@@ -289,6 +287,9 @@ impl Player {
289287
let hlsl = data
290288
.iter()
291289
.find_map(|a| (a.kind() == DataKind::Hlsl).then(|| loader.load_utf8(a)));
290+
let metallib = data
291+
.iter()
292+
.find_map(|a| (a.kind() == DataKind::MetalLib).then(|| loader.load(a)));
292293
let msl = data
293294
.iter()
294295
.find_map(|a| (a.kind() == DataKind::Msl).then(|| loader.load_utf8(a)));
@@ -299,14 +300,13 @@ impl Player {
299300
.iter()
300301
.find_map(|a| (a.kind() == DataKind::Wgsl).then(|| loader.load_utf8(a)));
301302
let desc = wgt::CreateShaderModuleDescriptorPassthrough {
302-
entry_point,
303303
label,
304304
num_workgroups,
305-
runtime_checks,
306305

307306
spirv,
308307
dxil,
309308
hlsl,
309+
metallib,
310310
msl,
311311
glsl,
312312
wgsl,

tests/tests/wgpu-gpu/main.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ mod occlusion_query;
4242
mod oob_indexing;
4343
mod oom;
4444
mod pass_ops;
45+
mod passthrough;
4546
mod per_vertex;
4647
mod pipeline;
4748
mod pipeline_cache;
@@ -107,6 +108,7 @@ fn all_tests() -> Vec<wgpu_test::GpuTestInitializer> {
107108
oob_indexing::all_tests(&mut tests);
108109
oom::all_tests(&mut tests);
109110
pass_ops::all_tests(&mut tests);
111+
passthrough::all_tests(&mut tests);
110112
per_vertex::all_tests(&mut tests);
111113
pipeline_cache::all_tests(&mut tests);
112114
pipeline::all_tests(&mut tests);

0 commit comments

Comments
 (0)