Skip to content

Commit 18f7ac5

Browse files
feat: Resolvable guarded compute pipelines (#1874)
1 parent b6e1199 commit 18f7ac5

File tree

4 files changed

+170
-0
lines changed

4 files changed

+170
-0
lines changed

apps/typegpu-docs/src/content/docs/fundamentals/utils.mdx

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,16 @@ The default workgroup sizes are:
118118
The callback is not called if the global invocation id of a thread would exceed the size in any dimension.
119119
:::
120120

121+
:::tip
122+
`TgpuGuardedComputePipeline` provides getters for the underlying pipeline and the size buffer.
123+
Those might be useful for `tgpu.resolve`, since you cannot resolve a guarded pipeline directly.
124+
125+
```ts
126+
const innerPipeline = doubleUpPipeline.with(bindGroup1).pipeline;
127+
tgpu.resolve({ externals: { innerPipeline } });
128+
```
129+
:::
130+
121131
## *console.log*
122132

123133
Yes, you read that correctly, TypeGPU implements logging to the console on the GPU!

packages/typegpu/src/core/root/init.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,14 @@ export class TgpuGuardedComputePipelineImpl<TArgs extends number[]>
185185
workgroupCount.z,
186186
);
187187
}
188+
189+
get pipeline() {
190+
return this.#pipeline;
191+
}
192+
193+
get sizeUniform() {
194+
return this.#sizeUniform;
195+
}
188196
}
189197

190198
class WithBindingImpl implements WithBinding {

packages/typegpu/src/core/root/rootTypes.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import type {
1313
AnyWgslData,
1414
U16,
1515
U32,
16+
Vec3u,
1617
Void,
1718
WgslArray,
1819
} from '../../data/wgslTypes.ts';
@@ -94,6 +95,17 @@ export interface TgpuGuardedComputePipeline<TArgs extends number[] = number[]> {
9495
* "guarded" by a bounds check.
9596
*/
9697
dispatchThreads(...args: TArgs): void;
98+
99+
/**
100+
* The underlying pipeline used during `dispatchThreads`.
101+
*/
102+
pipeline: TgpuComputePipeline;
103+
104+
/**
105+
* The buffer used to automatically pass the thread count to the underlying pipeline during `dispatchThreads`.
106+
* For pipelines with a dimension count lower than 3, the remaining coordinates are expected to be 1.
107+
*/
108+
sizeUniform: TgpuUniform<Vec3u>;
97109
}
98110

99111
export interface WithCompute {
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import { describe, expect } from 'vitest';
2+
import * as d from '../src/data/index.ts';
3+
import tgpu from '../src/index.ts';
4+
import { it } from './utils/extendedIt.ts';
5+
6+
describe('resolve', () => {
7+
const Boid = d.struct({
8+
position: d.vec2f,
9+
color: d.vec4f,
10+
});
11+
12+
const computeFn = tgpu['~unstable'].computeFn({
13+
workgroupSize: [1, 1, 1],
14+
in: { gid: d.builtin.globalInvocationId },
15+
})(() => {
16+
const myBoid = Boid({
17+
position: d.vec2f(0, 0),
18+
color: d.vec4f(1, 0, 0, 1),
19+
});
20+
});
21+
22+
const vertexFn = tgpu['~unstable'].vertexFn({
23+
out: { pos: d.builtin.position, color: d.vec4f },
24+
})(() => {
25+
const myBoid = Boid();
26+
return { pos: d.vec4f(myBoid.position, 0, 1), color: myBoid.color };
27+
});
28+
29+
const fragmentFn = tgpu['~unstable'].fragmentFn({
30+
in: { color: d.vec4f },
31+
out: d.vec4f,
32+
})((input) => {
33+
return input.color;
34+
});
35+
36+
it('can resolve a render pipeline', ({ root }) => {
37+
const pipeline = root
38+
.withVertex(vertexFn, {})
39+
.withFragment(fragmentFn, { format: 'rgba8unorm' })
40+
.createPipeline();
41+
42+
expect(tgpu.resolve({ externals: { pipeline } })).toMatchInlineSnapshot(`
43+
"struct Boid_1 {
44+
position: vec2f,
45+
color: vec4f,
46+
}
47+
48+
struct vertexFn_Output_2 {
49+
@builtin(position) pos: vec4f,
50+
@location(0) color: vec4f,
51+
}
52+
53+
@vertex fn vertexFn_0() -> vertexFn_Output_2 {
54+
var myBoid = Boid_1();
55+
return vertexFn_Output_2(vec4f(myBoid.position, 0f, 1f), myBoid.color);
56+
}
57+
58+
struct fragmentFn_Input_4 {
59+
@location(0) color: vec4f,
60+
}
61+
62+
@fragment fn fragmentFn_3(input: fragmentFn_Input_4) -> @location(0) vec4f {
63+
return input.color;
64+
}"
65+
`);
66+
});
67+
68+
it('can resolve a compute pipeline', ({ root }) => {
69+
const pipeline = root
70+
.withCompute(computeFn)
71+
.createPipeline();
72+
73+
expect(tgpu.resolve({ externals: { pipeline } })).toMatchInlineSnapshot(`
74+
"struct Boid_1 {
75+
position: vec2f,
76+
color: vec4f,
77+
}
78+
79+
struct computeFn_Input_2 {
80+
@builtin(global_invocation_id) gid: vec3u,
81+
}
82+
83+
@compute @workgroup_size(1, 1, 1) fn computeFn_0(_arg_0: computeFn_Input_2) {
84+
var myBoid = Boid_1(vec2f(), vec4f(1, 0, 0, 1));
85+
}"
86+
`);
87+
});
88+
89+
it('can resolve a guarded compute pipeline', ({ root }) => {
90+
const pipelineGuard = root.createGuardedComputePipeline((x, y, z) => {
91+
'use gpu';
92+
const myBoid = Boid({
93+
position: d.vec2f(0, 0),
94+
color: d.vec4f(x, y, z, 1),
95+
});
96+
});
97+
98+
expect(tgpu.resolve({ externals: { pipeline: pipelineGuard.pipeline } }))
99+
.toMatchInlineSnapshot(`
100+
"@group(0) @binding(0) var<uniform> sizeUniform_1: vec3u;
101+
102+
struct Boid_3 {
103+
position: vec2f,
104+
color: vec4f,
105+
}
106+
107+
fn wrappedCallback_2(x: u32, y: u32, z: u32) {
108+
var myBoid = Boid_3(vec2f(), vec4f(f32(x), f32(y), f32(z), 1f));
109+
}
110+
111+
struct mainCompute_Input_4 {
112+
@builtin(global_invocation_id) id: vec3u,
113+
}
114+
115+
@compute @workgroup_size(8, 8, 4) fn mainCompute_0(in: mainCompute_Input_4) {
116+
if (any(in.id >= sizeUniform_1)) {
117+
return;
118+
}
119+
wrappedCallback_2(in.id.x, in.id.y, in.id.z);
120+
}"
121+
`);
122+
});
123+
124+
it('throws when resolving multiple pipelines', ({ root }) => {
125+
const renderPipeline = root
126+
.withVertex(vertexFn, {})
127+
.withFragment(fragmentFn, { format: 'rgba8unorm' })
128+
.createPipeline();
129+
130+
const computePipeline = root
131+
.withCompute(computeFn)
132+
.createPipeline();
133+
134+
expect(() =>
135+
tgpu.resolve({ externals: { renderPipeline, computePipeline } })
136+
).toThrowErrorMatchingInlineSnapshot(
137+
`[Error: Found 2 pipelines but can only resolve one at a time.]`,
138+
);
139+
});
140+
});

0 commit comments

Comments
 (0)