@@ -554,6 +554,7 @@ struct vk_device_struct {
554
554
vk_pipeline pipeline_argmax_f32;
555
555
vk_pipeline pipeline_count_equal_i32;
556
556
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
557
+ vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
557
558
vk_pipeline pipeline_timestep_embedding_f32;
558
559
vk_pipeline pipeline_conv_transpose_1d_f32;
559
560
vk_pipeline pipeline_pool2d_f32;
@@ -931,6 +932,37 @@ struct vk_op_im2col_push_constants {
931
932
int32_t d0; int32_t d1;
932
933
};
933
934
935
// Push constants handed to the im2col_3d compute pipelines (the pipelines are created with
// sizeof(vk_op_im2col_3d_push_constants) as their push-constant size). Every member is one
// 32-bit word.
// NOTE(review): the member order presumably has to match the push-constant block declared in
// the im2col_3d shader, which is not visible here — confirm before reordering or inserting.
struct vk_op_im2col_3d_push_constants {
    // src1 strides in elements: the host divides each byte stride by the src1 element size.
    uint32_t nb10;
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;

    // Copied verbatim from dst->op_params[0..8].
    // Presumably stride (s), padding (p) and dilation (d) per axis — TODO confirm against the op definition.
    uint32_t s0;
    uint32_t s1;
    uint32_t s2;
    uint32_t p0;
    uint32_t p1;
    uint32_t p2;
    uint32_t d0;
    uint32_t d1;
    uint32_t d2;

    // src1 extents ne10 / ne11 / ne12, followed by IC taken from dst->op_params[9].
    uint32_t IW;
    uint32_t IH;
    uint32_t ID;
    uint32_t IC;

    // KW is src0's ne00; OH is dst's ne2.
    uint32_t KW;
    uint32_t OH;

    // Products precomputed on the host; each name spells out its factors
    // (K* come from src0, O* from dst, N = ne13 / IC, OD = ne3 / N).
    uint32_t KD_KH_KW;
    uint32_t KH_KW;
    uint32_t IC_KD_KH_KW;
    uint32_t N_OD_OH;
    uint32_t OD_OH;
    uint32_t OD_OH_OW_IC_KD_KH_KW;
    uint32_t OH_OW_IC_KD_KH_KW;
    uint32_t OW_IC_KD_KH_KW;

    // (src1 misalignment << 16) | dst misalignment, both in elements;
    // written by the matching init_pushconst_tensor_offsets specialization.
    uint32_t misalign_offsets;
};

// 28 words = 112 bytes. Vulkan only guarantees 128 bytes of push-constant space
// (the required minimum for maxPushConstantsSize), so refuse to build if this grows past it.
static_assert(sizeof(struct vk_op_im2col_3d_push_constants) <= 128, "vk_op_im2col_3d_push_constants exceeds the guaranteed push constant size");
965
+
934
966
struct vk_op_timestep_embedding_push_constants {
935
967
uint32_t nb1;
936
968
uint32_t dim;
@@ -3329,10 +3361,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
3329
3361
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
3330
3362
3331
3363
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3364
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3332
3365
if (device->float_controls_rte_fp16) {
3333
3366
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3367
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3334
3368
} else {
3335
3369
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3370
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3336
3371
}
3337
3372
3338
3373
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
@@ -7666,6 +7701,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
7666
7701
return ctx->device->pipeline_im2col_f32_f16;
7667
7702
}
7668
7703
return nullptr;
7704
+ case GGML_OP_IM2COL_3D:
7705
+ if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
7706
+ return ctx->device->pipeline_im2col_3d_f32;
7707
+ }
7708
+ if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
7709
+ return ctx->device->pipeline_im2col_3d_f32_f16;
7710
+ }
7711
+ return nullptr;
7669
7712
case GGML_OP_TIMESTEP_EMBEDDING:
7670
7713
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
7671
7714
return ctx->device->pipeline_timestep_embedding_f32;
@@ -7781,6 +7824,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
7781
7824
case GGML_OP_RMS_NORM:
7782
7825
case GGML_OP_CONV_2D_DW:
7783
7826
case GGML_OP_IM2COL:
7827
+ case GGML_OP_IM2COL_3D:
7784
7828
case GGML_OP_SET_ROWS:
7785
7829
case GGML_OP_SUM:
7786
7830
case GGML_OP_SUM_ROWS:
@@ -7829,6 +7873,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
7829
7873
GGML_UNUSED(src2);
7830
7874
}
7831
7875
7876
+ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
7877
+ const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
7878
+ const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
7879
+
7880
+ p.misalign_offsets = (a_offset << 16) | d_offset;
7881
+
7882
+ GGML_UNUSED(src0);
7883
+ GGML_UNUSED(src2);
7884
+ }
7885
+
7832
7886
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
7833
7887
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
7834
7888
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
@@ -8069,6 +8123,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
8069
8123
8070
8124
elements = { OW * KW * KH, OH, batch * IC };
8071
8125
} break;
8126
+ case GGML_OP_IM2COL_3D:
8127
+ {
8128
+ const uint32_t IC = ((const uint32_t *)(dst->op_params))[9];
8129
+
8130
+ const uint32_t N = ne13 / IC;
8131
+
8132
+ const uint32_t KD = ne02;
8133
+ const uint32_t KH = ne01;
8134
+ const uint32_t KW = ne00;
8135
+
8136
+ const uint32_t OD = ned3 / N;
8137
+ const uint32_t OH = ned2;
8138
+ const uint32_t OW = ned1;
8139
+
8140
+ const uint32_t IC_KD_KH_KW = IC*KD*KH*KW;
8141
+ const uint32_t N_OD_OH = N*OD*OH;
8142
+
8143
+ elements = { IC_KD_KH_KW, OW, N_OD_OH };
8144
+ elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
8145
+ } break;
8072
8146
case GGML_OP_TIMESTEP_EMBEDDING:
8073
8147
{
8074
8148
const uint32_t dim = dst->op_params[0];
@@ -8225,7 +8299,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
8225
8299
}
8226
8300
8227
8301
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
8228
- } else if (op == GGML_OP_IM2COL) {
8302
+ } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D ) {
8229
8303
// im2col uses only src1 and dst buffers
8230
8304
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
8231
8305
} else if (op == GGML_OP_COUNT_EQUAL) {
@@ -9086,6 +9160,66 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
9086
9160
}, dryrun);
9087
9161
}
9088
9162
9163
+ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
9164
+ GGML_TENSOR_BINARY_OP_LOCALS
9165
+
9166
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
9167
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
9168
+ const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
9169
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
9170
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
9171
+ const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
9172
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
9173
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
9174
+ const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
9175
+ const int32_t IC = ((const int32_t *)(dst->op_params))[9];
9176
+
9177
+ const int64_t N = ne13 / IC;
9178
+ const int64_t ID = ne12;
9179
+ const int64_t IH = ne11;
9180
+ const int64_t IW = ne10;
9181
+
9182
+ const int64_t KD = ne02;
9183
+ const int64_t KH = ne01;
9184
+ const int64_t KW = ne00;
9185
+
9186
+ const int64_t OD = ne3 / N;
9187
+ const int64_t OH = ne2;
9188
+ const int64_t OW = ne1;
9189
+
9190
+ vk_op_im2col_3d_push_constants pc {};
9191
+
9192
+ pc.nb10 = nb10 / ggml_type_size(src1->type);
9193
+ pc.nb11 = nb11 / ggml_type_size(src1->type);
9194
+ pc.nb12 = nb12 / ggml_type_size(src1->type);
9195
+ pc.nb13 = nb13 / ggml_type_size(src1->type);
9196
+ pc.s0 = s0;
9197
+ pc.s1 = s1;
9198
+ pc.s2 = s2;
9199
+ pc.p0 = p0;
9200
+ pc.p1 = p1;
9201
+ pc.p2 = p2;
9202
+ pc.d0 = d0;
9203
+ pc.d1 = d1;
9204
+ pc.d2 = d2;
9205
+ pc.IW = IW;
9206
+ pc.IH = IH;
9207
+ pc.ID = ID;
9208
+ pc.IC = IC;
9209
+ pc.KW = KW;
9210
+ pc.OH = OH;
9211
+ pc.KD_KH_KW = KD*KH*KW;
9212
+ pc.KH_KW = KH*KW;
9213
+ pc.IC_KD_KH_KW = IC*KD*KH*KW;
9214
+ pc.N_OD_OH = N*OD*OH;
9215
+ pc.OD_OH = OD*OH;
9216
+ pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
9217
+ pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
9218
+ pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
9219
+
9220
+ ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun);
9221
+ }
9222
+
9089
9223
static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
9090
9224
const uint32_t dim = dst->op_params[0];
9091
9225
const uint32_t max_period = dst->op_params[1];
@@ -10291,6 +10425,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
10291
10425
case GGML_OP_ARGMAX:
10292
10426
case GGML_OP_COUNT_EQUAL:
10293
10427
case GGML_OP_IM2COL:
10428
+ case GGML_OP_IM2COL_3D:
10294
10429
case GGML_OP_TIMESTEP_EMBEDDING:
10295
10430
case GGML_OP_CONV_TRANSPOSE_1D:
10296
10431
case GGML_OP_POOL_2D:
@@ -10361,6 +10496,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
10361
10496
case GGML_OP_ARGMAX:
10362
10497
case GGML_OP_COUNT_EQUAL:
10363
10498
case GGML_OP_IM2COL:
10499
+ case GGML_OP_IM2COL_3D:
10364
10500
case GGML_OP_TIMESTEP_EMBEDDING:
10365
10501
case GGML_OP_CONV_TRANSPOSE_1D:
10366
10502
case GGML_OP_POOL_2D:
@@ -10656,6 +10792,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
10656
10792
case GGML_OP_IM2COL:
10657
10793
ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);
10658
10794
10795
+ break;
10796
+ case GGML_OP_IM2COL_3D:
10797
+ ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node, dryrun);
10798
+
10659
10799
break;
10660
10800
case GGML_OP_TIMESTEP_EMBEDDING:
10661
10801
ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
@@ -10807,6 +10947,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
10807
10947
case GGML_OP_ARGMAX:
10808
10948
case GGML_OP_COUNT_EQUAL:
10809
10949
case GGML_OP_IM2COL:
10950
+ case GGML_OP_IM2COL_3D:
10810
10951
case GGML_OP_TIMESTEP_EMBEDDING:
10811
10952
case GGML_OP_CONV_TRANSPOSE_1D:
10812
10953
case GGML_OP_POOL_2D:
@@ -12092,6 +12233,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
12092
12233
case GGML_OP_ARGMAX:
12093
12234
case GGML_OP_COUNT_EQUAL:
12094
12235
case GGML_OP_IM2COL:
12236
+ case GGML_OP_IM2COL_3D:
12095
12237
case GGML_OP_TIMESTEP_EMBEDDING:
12096
12238
case GGML_OP_CONV_2D_DW:
12097
12239
case GGML_OP_POOL_2D:
@@ -12666,6 +12808,19 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
12666
12808
12667
12809
const bool is_2D = tensor->op_params[6] == 1;
12668
12810
tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
12811
} else if (tensor->op == GGML_OP_IM2COL_3D) {
    // op_params layout read here: [0..2] = s0,s1,s2   [3..5] = p0,p1,p2   [6..8] = d0,d1,d2   [9] = IC.
    // Each of the nine s/p/d values needs its own name: the third entry of every group was
    // previously declared under the second entry's name, which redefined s1/p1/d1 and left
    // s2/p2/d2 undeclared.
    const int32_t s0 = tensor->op_params[0];
    const int32_t s1 = tensor->op_params[1];
    const int32_t s2 = tensor->op_params[2];
    const int32_t p0 = tensor->op_params[3];
    const int32_t p1 = tensor->op_params[4];
    const int32_t p2 = tensor->op_params[5];
    const int32_t d0 = tensor->op_params[6];
    const int32_t d1 = tensor->op_params[7];
    const int32_t d2 = tensor->op_params[8];
    const int32_t IC = tensor->op_params[9];

    // Build the reference node with the 3D builder; ggml_im2col is the 2D variant and takes a
    // different argument list (no IC, two values per group, plus an is_2D flag).
    tensor_clone = ggml_im2col_3d(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type);
12669
12824
} else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
12670
12825
const int32_t dim = tensor->op_params[0];
12671
12826
const int32_t max_period = tensor->op_params[1];
0 commit comments