Skip to content

Commit f867b5d

Browse files
committed Jun 4, 2018
Add CUDA example
1 parent 46736bc commit f867b5d

File tree

3 files changed

+84
-6
lines changed

3 files changed

+84
-6
lines changed
 

‎.gitignore

+1
Original file line number | Diff line number | Diff line change
@@ -5,3 +5,4 @@ Cargo.lock
55
*.spv
66
*.so
77
*.png
8+
*.ppm

‎utility/mandelbrot.cu

+71
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,71 @@
1+
#include <stdio.h>
2+
#include <stdlib.h>
3+
#include <assert.h>
4+
#include <stdint.h>
5+
#include <sys/time.h>
6+
7+
// Number of z -> z*z + c steps the kernel performs for every pixel.
#define ITERATIONS 2000
8+
// Image width in pixels; also the row length of the flat output buffer.
#define WIDTH 3840
9+
// Image height in pixels.
#define HEIGHT 2160
10+
// Offset added to the scaled x coordinate of every pixel (real part of c).
#define CENTER_X -0.75
11+
// Offset added to the scaled y coordinate of every pixel (imaginary part of c).
#define CENTER_Y 0.0
12+
// Divisor applied to pixel offsets from the image centre, i.e. pixels per
// unit of the complex plane. Note: 2.5 is a double literal, so this
// expression has type double.
#define ZOOM (float(HEIGHT) / 2.5)
13+
14+
/*
 * Mandelbrot kernel.
 *
 * For every pixel of a width x height image, runs exactly ITERATIONS steps of
 * z <- z*z + c (z starts at 0, c derived from the pixel position, ZOOM and
 * CENTER_X / CENTER_Y) and stores |z|^2 of the final z in `output`.
 *
 * Parameters:
 *   dim    - device pointer to two unsigned values: {width, height}.
 *   output - device pointer to width * height floats, row-major.
 *
 * Launch layout: 1D grid of 1D blocks. The pixel loop is grid-stride, so any
 * grid/block size yields a complete image. No shared memory is used.
 * If either pointer is null the kernel returns without doing any work.
 *
 * Values for diverging points overflow to +inf and then NaN; callers that
 * threshold the result must treat NaN as "outside the set" (a `value <= t`
 * comparison does, since it is false for NaN).
 *
 * NOTE(review): there is no early escape test, so every pixel costs the same
 * fixed amount of work. Presumably deliberate for benchmarking - confirm
 * before adding a bailout.
 */
__global__ void mandelbrot(unsigned* dim, float* output) {
    // Nothing sensible can be read or written through a null pointer.
    if (dim == NULL || output == NULL) {
        return;
    }

    const unsigned width = dim[0];
    const unsigned height = dim[1];

    // 64-bit pixel count and index: width * height (and index + stride) can
    // wrap in 32-bit arithmetic for large images.
    const size_t total = (size_t)width * height;
    const size_t stride = (size_t)blockDim.x * gridDim.x;

    // The configuration macros are double-typed. Fold them to float once so
    // the per-pixel math below stays in single precision instead of being
    // silently promoted to double.
    const float zoom = (float)(ZOOM);
    const float center_x = (float)(CENTER_X);
    const float center_y = (float)(CENTER_Y);
    const float half_width = width / 2.0f;
    const float half_height = height / 2.0f;

    for (size_t tid = (size_t)blockDim.x * blockIdx.x + threadIdx.x;
         tid < total;
         tid += stride) {
        // Map the flat index to image coordinates, then to the complex plane.
        const float x = ((float)(tid % width) - half_width) / zoom + center_x;
        const float y = ((float)(tid / width) - half_height) / zoom + center_y;

        float a = 0.0f;
        float b = 0.0f;
        for (unsigned i = 0; i < ITERATIONS; i++) {
            // (a + bi)^2 + (x + yi); the new real part must be computed from
            // the old a and b before b is overwritten.
            const float next_a = a * a - b * b + x;
            b = 2.0f * a * b + y;
            a = next_a;
        }
        output[tid] = a * a + b * b;
    }
}
36+
37+
/* Aborts with a diagnostic if a CUDA runtime call did not succeed. */
static void check_cuda(cudaError_t status, const char* what) {
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA error in %s: %s\n", what, cudaGetErrorString(status));
        exit(EXIT_FAILURE);
    }
}

/*
 * Renders a WIDTH x HEIGHT Mandelbrot image on the GPU, writes it to
 * "out.ppm" as an ASCII greymap (magic number P2) and prints the kernel
 * execution time in seconds.
 *
 * Only the kernel launch plus the following device synchronisation is timed;
 * allocations, host<->device copies and file output are excluded.
 *
 * Returns 0 on success, EXIT_FAILURE if a host allocation or file operation
 * fails. CUDA failures terminate the process via check_cuda().
 */
int main() {
    const size_t pixel_count = (size_t)WIDTH * HEIGHT;
    const size_t output_bytes = pixel_count * sizeof(float);
    const unsigned threads_per_block = 256;
    // Ceil-division so a pixel count that is an exact multiple of the block
    // size does not launch a spare block.
    const unsigned block_count =
        (unsigned)((pixel_count + threads_per_block - 1) / threads_per_block);

    unsigned host_dim[2] = {WIDTH, HEIGHT};
    unsigned* device_dim = NULL;
    float* device_output = NULL;
    struct timespec start, end;

    // Explicit check instead of assert(): assert disappears under NDEBUG.
    float* host_output = (float*)malloc(output_bytes);
    if (host_output == NULL) {
        fprintf(stderr, "Failed to allocate %zu bytes for the host image\n", output_bytes);
        return EXIT_FAILURE;
    }

    check_cuda(cudaMalloc(&device_dim, sizeof(host_dim)), "cudaMalloc(device_dim)");
    check_cuda(cudaMalloc(&device_output, output_bytes), "cudaMalloc(device_output)");
    check_cuda(cudaMemcpy(device_dim, host_dim, sizeof(host_dim), cudaMemcpyHostToDevice),
               "cudaMemcpy(host -> device dim)");

    clock_gettime(CLOCK_MONOTONIC_RAW, &start);
    mandelbrot<<<block_count, threads_per_block>>>(device_dim, device_output);
    // Launch-configuration errors are only visible through cudaGetLastError();
    // faults during execution surface at the synchronisation point.
    check_cuda(cudaGetLastError(), "mandelbrot launch");
    check_cuda(cudaDeviceSynchronize(), "mandelbrot execution");
    clock_gettime(CLOCK_MONOTONIC_RAW, &end);

    check_cuda(cudaMemcpy(host_output, device_output, output_bytes, cudaMemcpyDeviceToHost),
               "cudaMemcpy(device -> host output)");
    check_cuda(cudaFree(device_output), "cudaFree(device_output)");
    check_cuda(cudaFree(device_dim), "cudaFree(device_dim)");

    FILE* output = fopen("out.ppm", "w");
    if (output == NULL) {
        perror("out.ppm");
        free(host_output);
        return EXIT_FAILURE;
    }
    fprintf(output, "P2\n%u %u\n255\n", host_dim[0], host_dim[1]);
    for (size_t i = 0; i < pixel_count; i++) {
        // Black where the final |z|^2 stayed at or below 2, white otherwise.
        // NaN/inf from diverged points compare false and therefore map to white.
        fprintf(output, "%d\n", (host_output[i] <= 2.0f) ? (0) : (255));
    }
    // fclose flushes buffered data, so a failure here means a truncated image.
    if (fclose(output) != 0) {
        perror("out.ppm");
        free(host_output);
        return EXIT_FAILURE;
    }
    free(host_output);

    // Combine seconds and nanoseconds before scaling: the nanosecond
    // difference alone may be negative.
    const int64_t delta_ns =
        (int64_t)(end.tv_sec - start.tv_sec) * 1000000000LL +
        (int64_t)(end.tv_nsec - start.tv_nsec);
    const double delta = (double)delta_ns / 1e9;
    printf("Time: %.3lf\n", delta);
    return 0;
}

‎wyvern-python/examples/mandelbrot.py

+12-6
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,11 @@
1010
from time import time
1111
import png
1212

13+
WIDTH = 3840
14+
HEIGHT = 2160
15+
ITERATIONS = 2000
16+
17+
1318
def mandelbrot(g_ctx: wy.builder.Context, id: str, a0: str, b0: str,
1419
iterations: int):
1520
ctx = g_ctx.getProgramBuilder().newContext()
@@ -66,27 +71,28 @@ def loop():
6671
ctx.While(lambda: ctx.id < ctx.width * ctx.height, loop)
6772
return builder.finalize()
6873

74+
6975
if __name__ == "__main__":
70-
result = program(1920*1080, -0.75, 0.0, 1080/2.5, 2000)
76+
result = program(WIDTH * HEIGHT, -0.75, 0.0, HEIGHT / 2.5, ITERATIONS)
7177
executor = wy.WyVkExecutor()
7278
executable = executor.compile(result)
79+
start = time()
7380
input = executor.newResource()
7481
output = executor.newResource()
75-
input.set_data_array_uint32([1920, 1080])
76-
output.set_data_array_float32([0.0] * 1920 * 1080)
82+
input.set_data_array_uint32([WIDTH, HEIGHT])
83+
output.set_data_array_float32([0.0] * WIDTH * HEIGHT)
7784
executable.bind("input", wy.IoType.input.value, input)
7885
executable.bind("output", wy.IoType.output.value, output)
79-
start = time()
8086
executable.run()
81-
print("Time: %.3fs" % (time() - start),)
8287
result = output.get_data_array_float32()
88+
print("Time: %.3fs" % (time() - start),)
8389
def mapper(x):
8490
if x <= 2.0:
8591
return 0
8692
else:
8793
return 255
8894
result = [mapper(x) for x in result]
8995
out = open("out.png", "wb")
90-
w = png.Writer(1920, 1080, greyscale=True)
96+
w = png.Writer(WIDTH, HEIGHT, greyscale=True)
9197
w.write_array(out, result)
9298
out.close()

0 commit comments

Comments
 (0)