Skip to content

[hal metal] ray tracing acceleration structures #7660

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 25 commits into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
898c357
Removes Option<> around AccelerationStructureTriangleIndices::buffer.
Lichtso May 3, 2025
3be8c76
Removes Option<> around AccelerationStructureTriangles::vertex_buffer.
Lichtso May 3, 2025
f3e14ff
Removes Option<> around AccelerationStructureAABBs::buffer.
Lichtso May 3, 2025
819fae2
Removes Option<> around AccelerationStructureInstances::buffer.
Lichtso May 3, 2025
517ba01
Fixes index_buffer label in ray_traced_triangle example.
Lichtso May 2, 2025
ef47bc0
Fixes min_push_constant_size in ray_shadows example.
Lichtso May 3, 2025
5fe6530
Updates CHANGELOG.
Lichtso Apr 30, 2025
60c3576
Adds feature detection.
Lichtso Apr 30, 2025
540d040
Sets raw_tlas_instance_size.
Lichtso Apr 30, 2025
d54f3c6
Sets ray_tracing_scratch_buffer_alignment.
Lichtso May 2, 2025
26c5acb
Adds conv::map_index_format().
Lichtso Apr 30, 2025
483bcff
Adds conv::map_acceleration_structure_descriptor().
Lichtso Apr 30, 2025
bd3366c
Adds AccelerationStructurePtr.
Lichtso Apr 30, 2025
dda8095
Implements AccelerationStructure.
Lichtso Apr 30, 2025
ec3d18e
Adds CommandState::acceleration_structure_builder.
Lichtso Apr 30, 2025
d885189
Implements CommandEncoder::copy_acceleration_structure_to_acceleratio…
Lichtso Apr 30, 2025
5ccd34f
Implements CommandEncoder::build_acceleration_structures().
Lichtso Apr 30, 2025
b4d2b75
Implements CommandEncoder::place_acceleration_structure_barrier().
Lichtso May 2, 2025
aca2014
Implements CommandEncoder::read_acceleration_structure_compact_size().
Lichtso Apr 30, 2025
2212733
Implements Device::get_acceleration_structure_build_sizes().
Lichtso Apr 30, 2025
8615600
Implements Device::get_acceleration_structure_device_address().
Lichtso May 2, 2025
438742e
Implements Device::create_acceleration_structure().
Lichtso Apr 30, 2025
6902c45
Implements Device::destroy_acceleration_structure().
Lichtso Apr 30, 2025
4915a8f
Implements Device::tlas_instance_to_bytes().
Lichtso Apr 30, 2025
982103f
Implements resource binding.
Lichtso Apr 30, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ Naga now infers the correct binding layout when a resource appears only in an as

- Use highest SPIR-V version supported by Vulkan API version. By @robamler in [#7595](https://github.com/gfx-rs/wgpu/pull/7595)

#### Metal

- Implements ray-tracing acceleration structures for metal backend. By @lichtso in [#7660](https://github.com/gfx-rs/wgpu/pull/7660)

### Bug Fixes

#### Naga
Expand Down
4 changes: 2 additions & 2 deletions examples/features/src/ray_shadows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl crate::framework::Example for Example {

fn required_limits() -> wgpu::Limits {
wgpu::Limits {
max_push_constant_size: 12,
max_push_constant_size: 16,
..wgpu::Limits::default()
}
}
Expand Down Expand Up @@ -209,7 +209,7 @@ impl crate::framework::Example for Example {
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[wgpu::PushConstantRange {
stages: wgpu::ShaderStages::FRAGMENT,
range: 0..12,
range: 0..16,
}],
});

Expand Down
1 change: 1 addition & 0 deletions examples/features/src/ray_shadows/shader.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ var acc_struct: acceleration_structure;

struct PushConstants {
light: vec3<f32>,
padding: f32,
Copy link
Contributor Author

@Lichtso Lichtso May 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that metal always sends at least 16 bytes for push constants, even if we only pass in 12 bytes. And then the shader validation complains that the receiver here only expects 12 bytes.

}
var<push_constant> pc: PushConstants;

Expand Down
2 changes: 1 addition & 1 deletion examples/features/src/ray_traced_triangle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ impl crate::framework::Example for Example {
});

let index_buffer = device.create_buffer_init(&BufferInitDescriptor {
label: Some("vertex buffer"),
label: Some("index buffer"),
contents: bytemuck::cast_slice(&indices),
usage: BufferUsages::BLAS_INPUT,
});
Expand Down
8 changes: 4 additions & 4 deletions wgpu-core/src/command/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ impl Global {
tlas,
entries: hal::AccelerationStructureEntries::Instances(
hal::AccelerationStructureInstances {
buffer: Some(instance_buffer),
buffer: instance_buffer,
offset: 0,
count: entry.instance_count,
},
Expand Down Expand Up @@ -602,7 +602,7 @@ impl Global {
tlas: tlas.clone(),
entries: hal::AccelerationStructureEntries::Instances(
hal::AccelerationStructureInstances {
buffer: Some(tlas.instance_buffer.as_ref()),
buffer: tlas.instance_buffer.as_ref(),
offset: 0,
count: instance_count,
},
Expand Down Expand Up @@ -1141,7 +1141,7 @@ fn iter_buffers<'a, 'b>(
};

let triangles = hal::AccelerationStructureTriangles {
vertex_buffer: Some(vertex_buffer),
vertex_buffer,
vertex_format: mesh.size.vertex_format,
first_vertex: mesh.first_vertex,
vertex_count: mesh.size.vertex_count,
Expand All @@ -1150,7 +1150,7 @@ fn iter_buffers<'a, 'b>(
let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
format: mesh.size.index_format.unwrap(),
buffer: Some(index_buffer),
buffer: index_buffer,
offset: mesh.first_index.unwrap() * index_stride,
count: mesh.size.index_count.unwrap(),
}
Expand Down
6 changes: 3 additions & 3 deletions wgpu-core/src/device/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl Device {
dyn hal::DynBuffer,
> {
format: desc.index_format.unwrap(),
buffer: None,
buffer: self.zero_buffer.as_ref(),
offset: 0,
count,
});
Expand Down Expand Up @@ -78,7 +78,7 @@ impl Device {
}

entries.push(hal::AccelerationStructureTriangles::<dyn hal::DynBuffer> {
vertex_buffer: None,
vertex_buffer: self.zero_buffer.as_ref(),
vertex_format: desc.vertex_format,
first_vertex: 0,
vertex_count: desc.vertex_count,
Expand Down Expand Up @@ -158,7 +158,7 @@ impl Device {
&hal::GetAccelerationStructureBuildSizesDescriptor {
entries: &hal::AccelerationStructureEntries::Instances(
hal::AccelerationStructureInstances {
buffer: None,
buffer: self.zero_buffer.as_ref(),
offset: 0,
count: desc.max_instances,
},
Expand Down
1 change: 1 addition & 0 deletions wgpu-hal/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ metal = [
"naga/msl-out",
"dep:arrayvec",
"dep:block",
"dep:bytemuck",
"dep:core-graphics-types",
"dep:hashbrown",
"dep:libc",
Expand Down
182 changes: 87 additions & 95 deletions wgpu-hal/examples/ray-traced-triangle/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,15 +473,15 @@ impl<A: hal::Api> Example<A> {
};

let blas_triangles = vec![hal::AccelerationStructureTriangles {
vertex_buffer: Some(&vertices_buffer),
vertex_buffer: &vertices_buffer,
first_vertex: 0,
vertex_format: wgpu_types::VertexFormat::Float32x3,
// each vertex is 3 floats, and floats are stored raw in the array
vertex_count: vertices.len() as u32 / 3,
vertex_stride: 3 * 4,
indices: indices_buffer.as_ref().map(|(buf, len)| {
indices: indices_buffer.as_ref().map(|(buffer, len)| {
hal::AccelerationStructureTriangleIndices {
buffer: Some(buf),
buffer,
format: wgpu_types::IndexFormat::Uint32,
offset: 0,
count: *len as u32,
Expand All @@ -493,13 +493,6 @@ impl<A: hal::Api> Example<A> {
}];
let blas_entries = hal::AccelerationStructureEntries::Triangles(blas_triangles);

let mut tlas_entries =
hal::AccelerationStructureEntries::Instances(hal::AccelerationStructureInstances {
buffer: None,
count: 3,
offset: 0,
});

let blas_sizes = unsafe {
device.get_acceleration_structure_build_sizes(
&hal::GetAccelerationStructureBuildSizesDescriptor {
Expand All @@ -509,6 +502,89 @@ impl<A: hal::Api> Example<A> {
)
};

let blas = unsafe {
device.create_acceleration_structure(&hal::AccelerationStructureDescriptor {
label: Some("blas"),
size: blas_sizes.acceleration_structure_size,
format: hal::AccelerationStructureFormat::BottomLevel,
allow_compaction: false,
})
}
.unwrap();

let instances = [
AccelerationStructureInstance::new(
&Affine3A::from_translation(Vec3 {
x: 0.0,
y: 0.0,
z: 0.0,
}),
0,
0xff,
0,
0,
unsafe { device.get_acceleration_structure_device_address(&blas) },
),
AccelerationStructureInstance::new(
&Affine3A::from_translation(Vec3 {
x: -1.0,
y: -1.0,
z: -2.0,
}),
0,
0xff,
0,
0,
unsafe { device.get_acceleration_structure_device_address(&blas) },
),
AccelerationStructureInstance::new(
&Affine3A::from_translation(Vec3 {
x: 1.0,
y: -1.0,
z: -2.0,
}),
0,
0xff,
0,
0,
unsafe { device.get_acceleration_structure_device_address(&blas) },
),
];

let instances_buffer_size = instances.len() * size_of::<AccelerationStructureInstance>();

let instances_buffer = unsafe {
let instances_buffer = device
.create_buffer(&hal::BufferDescriptor {
label: Some("instances_buffer"),
size: instances_buffer_size as u64,
usage: wgpu_types::BufferUses::MAP_WRITE
| wgpu_types::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
memory_flags: hal::MemoryFlags::TRANSIENT | hal::MemoryFlags::PREFER_COHERENT,
})
.unwrap();

let mapping = device
.map_buffer(&instances_buffer, 0..instances_buffer_size as u64)
.unwrap();
ptr::copy_nonoverlapping(
instances.as_ptr() as *const u8,
mapping.ptr.as_ptr(),
instances_buffer_size,
);
device.unmap_buffer(&instances_buffer);
assert!(mapping.is_coherent);

instances_buffer
};

let tlas_entries =
hal::AccelerationStructureEntries::Instances(hal::AccelerationStructureInstances {
buffer: &instances_buffer,
count: 3,
offset: 0,
});

let tlas_flags = hal::AccelerationStructureBuildFlags::PREFER_FAST_TRACE
| hal::AccelerationStructureBuildFlags::ALLOW_UPDATE;

Expand All @@ -521,16 +597,6 @@ impl<A: hal::Api> Example<A> {
)
};

let blas = unsafe {
device.create_acceleration_structure(&hal::AccelerationStructureDescriptor {
label: Some("blas"),
size: blas_sizes.acceleration_structure_size,
format: hal::AccelerationStructureFormat::BottomLevel,
allow_compaction: false,
})
}
.unwrap();

let tlas = unsafe {
device.create_acceleration_structure(&hal::AccelerationStructureDescriptor {
label: Some("tlas"),
Expand Down Expand Up @@ -653,80 +719,6 @@ impl<A: hal::Api> Example<A> {
.unwrap()
};

let instances = [
AccelerationStructureInstance::new(
&Affine3A::from_translation(Vec3 {
x: 0.0,
y: 0.0,
z: 0.0,
}),
0,
0xff,
0,
0,
unsafe { device.get_acceleration_structure_device_address(&blas) },
),
AccelerationStructureInstance::new(
&Affine3A::from_translation(Vec3 {
x: -1.0,
y: -1.0,
z: -2.0,
}),
0,
0xff,
0,
0,
unsafe { device.get_acceleration_structure_device_address(&blas) },
),
AccelerationStructureInstance::new(
&Affine3A::from_translation(Vec3 {
x: 1.0,
y: -1.0,
z: -2.0,
}),
0,
0xff,
0,
0,
unsafe { device.get_acceleration_structure_device_address(&blas) },
),
];

let instances_buffer_size = instances.len() * size_of::<AccelerationStructureInstance>();

let instances_buffer = unsafe {
let instances_buffer = device
.create_buffer(&hal::BufferDescriptor {
label: Some("instances_buffer"),
size: instances_buffer_size as u64,
usage: wgpu_types::BufferUses::MAP_WRITE
| wgpu_types::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
memory_flags: hal::MemoryFlags::TRANSIENT | hal::MemoryFlags::PREFER_COHERENT,
})
.unwrap();

let mapping = device
.map_buffer(&instances_buffer, 0..instances_buffer_size as u64)
.unwrap();
ptr::copy_nonoverlapping(
instances.as_ptr() as *const u8,
mapping.ptr.as_ptr(),
instances_buffer_size,
);
device.unmap_buffer(&instances_buffer);
assert!(mapping.is_coherent);

instances_buffer
};

if let hal::AccelerationStructureEntries::Instances(ref mut i) = tlas_entries {
i.buffer = Some(&instances_buffer);
assert!(
instances.len() <= i.count as usize,
"Tlas allocation to small"
);
}

let cmd_encoder_desc = hal::CommandEncoderDescriptor {
label: None,
queue: &queue,
Expand Down Expand Up @@ -903,7 +895,7 @@ impl<A: hal::Api> Example<A> {
ctx.encoder.begin_encoding(Some("frame")).unwrap();

let instances = hal::AccelerationStructureInstances {
buffer: Some(&self.instances_buffer),
buffer: &self.instances_buffer,
count: self.instances.len() as u32,
offset: 0,
};
Expand Down
27 changes: 5 additions & 22 deletions wgpu-hal/src/dx12/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1476,13 +1476,8 @@ impl crate::CommandEncoder for super::CommandEncoder {
let num_desc;
match descriptor.entries {
AccelerationStructureEntries::Instances(instances) => {
let desc_address = unsafe {
instances
.buffer
.expect("needs buffer to build")
.resource
.GetGPUVirtualAddress()
} + instances.offset as u64;
let desc_address = unsafe { instances.buffer.resource.GetGPUVirtualAddress() }
+ instances.offset as u64;
ty = Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL;
inputs0 = Direct3D12::D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS_0 {
InstanceDescs: desc_address,
Expand All @@ -1508,19 +1503,10 @@ impl crate::CommandEncoder for super::CommandEncoder {
let index_count =
triangle.indices.as_ref().map_or(0, |indices| indices.count);
let index_address = triangle.indices.as_ref().map_or(0, |indices| unsafe {
indices
.buffer
.expect("needs buffer to build")
.resource
.GetGPUVirtualAddress()
+ indices.offset as u64
indices.buffer.resource.GetGPUVirtualAddress() + indices.offset as u64
});
let vertex_address = unsafe {
triangle
.vertex_buffer
.expect("needs buffer to build")
.resource
.GetGPUVirtualAddress()
triangle.vertex_buffer.resource.GetGPUVirtualAddress()
+ (triangle.first_vertex as u64 * triangle.vertex_stride)
};

Expand Down Expand Up @@ -1555,10 +1541,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
geometry_desc = Vec::with_capacity(aabbs.len());
for aabb in aabbs {
let aabb_address = unsafe {
aabb.buffer
.expect("needs buffer to build")
.resource
.GetGPUVirtualAddress()
aabb.buffer.resource.GetGPUVirtualAddress()
+ (aabb.offset as u64 * aabb.stride)
};

Expand Down
Loading
Loading