Skip to content

Commit 828caef

Browse files
committed
vulkan: support im2col_3d
1 parent c1c354e commit 828caef

File tree

4 files changed

+290
-7
lines changed

4 files changed

+290
-7
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 156 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,7 @@ struct vk_device_struct {
554554
vk_pipeline pipeline_argmax_f32;
555555
vk_pipeline pipeline_count_equal_i32;
556556
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
557+
vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
557558
vk_pipeline pipeline_timestep_embedding_f32;
558559
vk_pipeline pipeline_conv_transpose_1d_f32;
559560
vk_pipeline pipeline_pool2d_f32;
@@ -931,6 +932,37 @@ struct vk_op_im2col_push_constants {
931932
int32_t d0; int32_t d1;
932933
};
933934

935+
// Push constants for the im2col_3d pipelines. The members mirror, one for one and in the
// same order, the push_constant block declared by the im2col_3d compute shader, so any
// change here has to be made there as well.
struct vk_op_im2col_3d_push_constants {
    // src1 strides, expressed in elements rather than bytes
    uint32_t nb10, nb11, nb12, nb13;

    // per-axis step applied to the output index (s), offset subtracted from the
    // input coordinate (p) and step applied to the kernel index (d)
    uint32_t s0, s1, s2;
    uint32_t p0, p1, p2;
    uint32_t d0, d1, d2;

    // input extents and channel count
    uint32_t IW, IH, ID, IC;

    // kernel width and output height
    uint32_t KW;
    uint32_t OH;

    // products precomputed on the host so the shader only has to divide and multiply
    uint32_t KD_KH_KW;
    uint32_t KH_KW;
    uint32_t IC_KD_KH_KW;
    uint32_t N_OD_OH;
    uint32_t OD_OH;
    uint32_t OD_OH_OW_IC_KD_KH_KW;
    uint32_t OH_OW_IC_KD_KH_KW;
    uint32_t OW_IC_KD_KH_KW;

    // src1 misalignment in the upper 16 bits, dst misalignment in the lower 16 bits
    uint32_t misalign_offsets;
};
965+
934966
struct vk_op_timestep_embedding_push_constants {
935967
uint32_t nb1;
936968
uint32_t dim;
@@ -3329,10 +3361,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
33293361
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
33303362

33313363
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3364+
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
33323365
if (device->float_controls_rte_fp16) {
33333366
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3367+
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
33343368
} else {
33353369
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3370+
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
33363371
}
33373372

33383373
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
@@ -7666,6 +7701,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
76667701
return ctx->device->pipeline_im2col_f32_f16;
76677702
}
76687703
return nullptr;
7704+
case GGML_OP_IM2COL_3D:
7705+
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
7706+
return ctx->device->pipeline_im2col_3d_f32;
7707+
}
7708+
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
7709+
return ctx->device->pipeline_im2col_3d_f32_f16;
7710+
}
7711+
return nullptr;
76697712
case GGML_OP_TIMESTEP_EMBEDDING:
76707713
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
76717714
return ctx->device->pipeline_timestep_embedding_f32;
@@ -7781,6 +7824,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
77817824
case GGML_OP_RMS_NORM:
77827825
case GGML_OP_CONV_2D_DW:
77837826
case GGML_OP_IM2COL:
7827+
case GGML_OP_IM2COL_3D:
77847828
case GGML_OP_SET_ROWS:
77857829
case GGML_OP_SUM:
77867830
case GGML_OP_SUM_ROWS:
@@ -7829,6 +7873,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
78297873
GGML_UNUSED(src2);
78307874
}
78317875

7876+
// im2col_3d specialization: records how far src1 and dst are misaligned, measured in
// elements of their own type, packed into one 32-bit push constant (src1 in the bits
// above 16, dst OR-ed into the low bits). src0 and src2 are not consulted.
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
    GGML_UNUSED(src0);
    GGML_UNUSED(src2);

    const uint32_t src_misalign_elems = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
    const uint32_t dst_misalign_elems = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);

    p.misalign_offsets = dst_misalign_elems | (src_misalign_elems << 16);
}
7885+
78327886
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
78337887
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
78347888
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
@@ -8069,6 +8123,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
80698123

80708124
elements = { OW * KW * KH, OH, batch * IC };
80718125
} break;
8126+
case GGML_OP_IM2COL_3D:
8127+
{
8128+
const uint32_t IC = ((const uint32_t *)(dst->op_params))[9];
8129+
8130+
const uint32_t N = ne13 / IC;
8131+
8132+
const uint32_t KD = ne02;
8133+
const uint32_t KH = ne01;
8134+
const uint32_t KW = ne00;
8135+
8136+
const uint32_t OD = ned3 / N;
8137+
const uint32_t OH = ned2;
8138+
const uint32_t OW = ned1;
8139+
8140+
const uint32_t IC_KD_KH_KW = IC*KD*KH*KW;
8141+
const uint32_t N_OD_OH = N*OD*OH;
8142+
8143+
elements = { IC_KD_KH_KW, OW, N_OD_OH };
8144+
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
8145+
} break;
80728146
case GGML_OP_TIMESTEP_EMBEDDING:
80738147
{
80748148
const uint32_t dim = dst->op_params[0];
@@ -8225,7 +8299,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
82258299
}
82268300

82278301
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
8228-
} else if (op == GGML_OP_IM2COL) {
8302+
} else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
82298303
// im2col uses only src1 and dst buffers
82308304
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
82318305
} else if (op == GGML_OP_COUNT_EQUAL) {
@@ -9086,6 +9160,66 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
90869160
}, dryrun);
90879161
}
90889162

9163+
// Host-side setup for GGML_OP_IM2COL_3D: reads the ten op_params stored on dst, derives the
// remaining extents from the src0 / src1 / dst shapes, fills vk_op_im2col_3d_push_constants
// and hands everything to ggml_vk_op_f32 (dryrun is passed through unchanged).
static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
    GGML_TENSOR_BINARY_OP_LOCALS

    // op_params layout: [0..2] = s0,s1,s2   [3..5] = p0,p1,p2   [6..8] = d0,d1,d2   [9] = IC
    const int32_t * op_params = (const int32_t *)(dst->op_params);
    const int32_t IC = op_params[9];

    // src1's outermost dimension divided by IC gives N; dst's outermost dimension
    // divided by N gives OD.
    const int64_t N  = ne13 / IC;
    const int64_t OD = ne3 / N;

    // Cumulative products, built innermost first (ne00/ne01/ne02 = KW/KH/KD,
    // ne1/ne2 = OW/OH).
    const int64_t KH_KW                = ne01 * ne00;
    const int64_t KD_KH_KW             = ne02 * KH_KW;
    const int64_t IC_KD_KH_KW          = IC * KD_KH_KW;
    const int64_t OW_IC_KD_KH_KW       = ne1 * IC_KD_KH_KW;
    const int64_t OH_OW_IC_KD_KH_KW    = ne2 * OW_IC_KD_KH_KW;
    const int64_t OD_OH_OW_IC_KD_KH_KW = OD * OH_OW_IC_KD_KH_KW;
    const int64_t OD_OH                = OD * ne2;

    // The shader indexes src1 in elements, so the byte strides are divided by the type size.
    const size_t src1_type_size = ggml_type_size(src1->type);

    vk_op_im2col_3d_push_constants pc {};

    pc.nb10 = nb10 / src1_type_size;
    pc.nb11 = nb11 / src1_type_size;
    pc.nb12 = nb12 / src1_type_size;
    pc.nb13 = nb13 / src1_type_size;

    pc.s0 = op_params[0];
    pc.s1 = op_params[1];
    pc.s2 = op_params[2];
    pc.p0 = op_params[3];
    pc.p1 = op_params[4];
    pc.p2 = op_params[5];
    pc.d0 = op_params[6];
    pc.d1 = op_params[7];
    pc.d2 = op_params[8];

    pc.IW = ne10;
    pc.IH = ne11;
    pc.ID = ne12;
    pc.IC = IC;
    pc.KW = ne00;
    pc.OH = ne2;

    pc.KD_KH_KW             = KD_KH_KW;
    pc.KH_KW                = KH_KW;
    pc.IC_KD_KH_KW          = IC_KD_KH_KW;
    pc.N_OD_OH              = N * OD_OH;
    pc.OD_OH                = OD_OH;
    pc.OD_OH_OW_IC_KD_KH_KW = OD_OH_OW_IC_KD_KH_KW;
    pc.OH_OW_IC_KD_KH_KW    = OH_OW_IC_KD_KH_KW;
    pc.OW_IC_KD_KH_KW       = OW_IC_KD_KH_KW;

    ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun);
}
9222+
90899223
static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
90909224
const uint32_t dim = dst->op_params[0];
90919225
const uint32_t max_period = dst->op_params[1];
@@ -10291,6 +10425,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1029110425
case GGML_OP_ARGMAX:
1029210426
case GGML_OP_COUNT_EQUAL:
1029310427
case GGML_OP_IM2COL:
10428+
case GGML_OP_IM2COL_3D:
1029410429
case GGML_OP_TIMESTEP_EMBEDDING:
1029510430
case GGML_OP_CONV_TRANSPOSE_1D:
1029610431
case GGML_OP_POOL_2D:
@@ -10361,6 +10496,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1036110496
case GGML_OP_ARGMAX:
1036210497
case GGML_OP_COUNT_EQUAL:
1036310498
case GGML_OP_IM2COL:
10499+
case GGML_OP_IM2COL_3D:
1036410500
case GGML_OP_TIMESTEP_EMBEDDING:
1036510501
case GGML_OP_CONV_TRANSPOSE_1D:
1036610502
case GGML_OP_POOL_2D:
@@ -10656,6 +10792,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1065610792
case GGML_OP_IM2COL:
1065710793
ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);
1065810794

10795+
break;
10796+
case GGML_OP_IM2COL_3D:
10797+
ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node, dryrun);
10798+
1065910799
break;
1066010800
case GGML_OP_TIMESTEP_EMBEDDING:
1066110801
ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
@@ -10807,6 +10947,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1080710947
case GGML_OP_ARGMAX:
1080810948
case GGML_OP_COUNT_EQUAL:
1080910949
case GGML_OP_IM2COL:
10950+
case GGML_OP_IM2COL_3D:
1081010951
case GGML_OP_TIMESTEP_EMBEDDING:
1081110952
case GGML_OP_CONV_TRANSPOSE_1D:
1081210953
case GGML_OP_POOL_2D:
@@ -12092,6 +12233,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1209212233
case GGML_OP_ARGMAX:
1209312234
case GGML_OP_COUNT_EQUAL:
1209412235
case GGML_OP_IM2COL:
12236+
case GGML_OP_IM2COL_3D:
1209512237
case GGML_OP_TIMESTEP_EMBEDDING:
1209612238
case GGML_OP_CONV_2D_DW:
1209712239
case GGML_OP_POOL_2D:
@@ -12666,6 +12808,19 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1266612808

1266712809
const bool is_2D = tensor->op_params[6] == 1;
1266812810
tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
12811+
} else if (tensor->op == GGML_OP_IM2COL_3D) {
    // Rebuild the reference op from the parameters stored on the tensor.
    // op_params layout: [0..2] = s0,s1,s2   [3..5] = p0,p1,p2   [6..8] = d0,d1,d2   [9] = IC
    const int32_t s0 = tensor->op_params[0];
    const int32_t s1 = tensor->op_params[1];
    const int32_t s2 = tensor->op_params[2];
    const int32_t p0 = tensor->op_params[3];
    const int32_t p1 = tensor->op_params[4];
    const int32_t p2 = tensor->op_params[5];
    const int32_t d0 = tensor->op_params[6];
    const int32_t d1 = tensor->op_params[7];
    const int32_t d2 = tensor->op_params[8];
    const int32_t IC = tensor->op_params[9];

    // The 3D op has its own constructor; ggml_im2col is the 2D entry point and takes an
    // is_2D flag instead of the third-axis parameters.
    tensor_clone = ggml_im2col_3d(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type);
1266912824
} else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
1267012825
const int32_t dim = tensor->op_params[0];
1267112826
const int32_t max_period = tensor->op_params[1];
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
#version 450
2+
3+
#extension GL_EXT_shader_16bit_storage : require
4+
#extension GL_EXT_control_flow_attributes : require
5+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
6+
7+
#include "rte.comp"
8+
9+
// Push constants for im2col_3d. The member order must match the host-side
// vk_op_im2col_3d_push_constants struct exactly.
layout (push_constant) uniform parameter
{
    // source strides, in elements
    uint32_t nb10, nb11, nb12, nb13;

    // step applied to the output index (s), offset subtracted from the input
    // coordinate (p), step applied to the kernel index (d)
    uint32_t s0, s1, s2;
    uint32_t p0, p1, p2;
    uint32_t d0, d1, d2;

    // input extents and channel count
    uint32_t IW, IH, ID, IC;

    // kernel width, output height
    uint32_t KW;
    uint32_t OH;

    // products precomputed on the host
    uint32_t KD_KH_KW;
    uint32_t KH_KW;
    uint32_t IC_KD_KH_KW;
    uint32_t N_OD_OH;
    uint32_t OD_OH;
    uint32_t OD_OH_OW_IC_KD_KH_KW;
    uint32_t OH_OW_IC_KD_KH_KW;
    uint32_t OW_IC_KD_KH_KW;

    // source offset in the upper 16 bits, destination offset in the lower 16 bits
    uint32_t misalign_offsets;
} p;
40+
41+
#include "types.comp"
42+
43+
// Offset added to every data_a index: the upper 16 bits of misalign_offsets.
uint get_aoffset()
{
    return p.misalign_offsets >> 16;
}
44+
// Offset added to every data_d index: the lower 16 bits of misalign_offsets.
uint get_doffset()
{
    return p.misalign_offsets & 0xFFFF;
}
45+
46+
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
47+
48+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
49+
50+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
51+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
52+
53+
// Each invocation owns one (x, y) pair: x is a flattened (input channel, kernel depth,
// kernel row, kernel column) index and y is the output column. It then walks the flattened
// (batch, output depth, output row) range in steps of the z workgroup count, writing one
// destination element per step and zero-filling taps that land outside the input.
void main() {
    const uint32_t col = gl_GlobalInvocationID.x;
    if (col >= p.IC_KD_KH_KW) {
        return;
    }

    // Split the flattened x index into channel and kernel coordinates.
    const uint32_t iic = col / p.KD_KH_KW;
    const uint32_t rem = col - iic * p.KD_KH_KW;
    const uint32_t ikd = rem / p.KH_KW;
    const uint32_t ikh = (rem - ikd * p.KH_KW) / p.KW;
    const uint32_t ikw = col % p.KW;

    const uint32_t iow = gl_GlobalInvocationID.y;

    // Input coordinates use unsigned arithmetic: a coordinate that would be negative wraps
    // to a very large value and is rejected by the bounds test below. The width coordinate
    // does not depend on iz, so it is computed once.
    const uint32_t iiw = iow * p.s0 + ikw * p.d0 - p.p0;

    for (uint32_t iz = gl_GlobalInvocationID.z; iz < p.N_OD_OH; iz += gl_NumWorkGroups.z) {
        // Split iz into batch, output depth and output row.
        const uint32_t in_ = iz / p.OD_OH;
        const uint32_t iod = (iz - in_ * p.OD_OH) / p.OH;
        const uint32_t ioh = iz % p.OH;

        const uint32_t iih = ioh * p.s1 + ikh * p.d1 - p.p1;
        const uint32_t iid = iod * p.s2 + ikd * p.d2 - p.p2;

        const uint32_t dst_idx = get_doffset()
                               + in_ * p.OD_OH_OW_IC_KD_KH_KW
                               + iod * p.OH_OW_IC_KD_KH_KW
                               + ioh * p.OW_IC_KD_KH_KW
                               + iow * p.IC_KD_KH_KW
                               + iic * p.KD_KH_KW
                               + ikd * p.KH_KW
                               + ikh * p.KW
                               + ikw;

        if (iih < p.IH && iiw < p.IW && iid < p.ID) {
            const uint32_t src_idx = get_aoffset()
                                   + (in_ * p.IC + iic) * p.nb13
                                   + iid * p.nb12
                                   + iih * p.nb11
                                   + iiw * p.nb10;
            data_d[dst_idx] = D_TYPE(data_a[src_idx]);
        } else {
            data_d[dst_idx] = D_TYPE(0.0f);
        }
    }
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,10 @@ void process_shaders() {
713713
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
714714
string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
715715

716+
string_to_spv("im2col_3d_f32", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
717+
string_to_spv("im2col_3d_f32_f16", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
718+
string_to_spv("im2col_3d_f32_f16_rte", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
719+
716720
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
717721

718722
string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});

0 commit comments

Comments
 (0)