Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 156 additions & 1 deletion ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,7 @@ struct vk_device_struct {
vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32;
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
vk_pipeline pipeline_timestep_embedding_f32;
vk_pipeline pipeline_conv_transpose_1d_f32;
vk_pipeline pipeline_pool2d_f32;
Expand Down Expand Up @@ -982,6 +983,37 @@ struct vk_op_im2col_push_constants {
int32_t d0; int32_t d1;
};

struct vk_op_im2col_3d_push_constants {
uint32_t nb10;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t s0;
uint32_t s1;
uint32_t s2;
uint32_t p0;
uint32_t p1;
uint32_t p2;
uint32_t d0;
uint32_t d1;
uint32_t d2;
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t IC;
uint32_t KW;
uint32_t OH;
uint32_t KD_KH_KW;
uint32_t KH_KW;
uint32_t IC_KD_KH_KW;
uint32_t N_OD_OH;
uint32_t OD_OH;
uint32_t OD_OH_OW_IC_KD_KH_KW;
uint32_t OH_OW_IC_KD_KH_KW;
uint32_t OW_IC_KD_KH_KW;
uint32_t misalign_offsets;
};

struct vk_op_timestep_embedding_push_constants {
uint32_t nb1;
uint32_t dim;
Expand Down Expand Up @@ -3380,10 +3412,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);

ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
} else {
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
}

ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
Expand Down Expand Up @@ -7717,6 +7752,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_im2col_f32_f16;
}
return nullptr;
case GGML_OP_IM2COL_3D:
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_im2col_3d_f32;
}
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_im2col_3d_f32_f16;
}
return nullptr;
case GGML_OP_TIMESTEP_EMBEDDING:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_timestep_embedding_f32;
Expand Down Expand Up @@ -7832,6 +7875,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_RMS_NORM:
case GGML_OP_CONV_2D_DW:
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_SET_ROWS:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
Expand Down Expand Up @@ -7890,6 +7934,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src2);
}

template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);

p.misalign_offsets = (a_offset << 16) | d_offset;

GGML_UNUSED(src0);
GGML_UNUSED(src2);
}

template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
Expand Down Expand Up @@ -8130,6 +8184,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co

elements = { OW * KW * KH, OH, batch * IC };
} break;
case GGML_OP_IM2COL_3D:
{
const uint32_t IC = ((const uint32_t *)(dst->op_params))[9];

const uint32_t N = ne13 / IC;

const uint32_t KD = ne02;
const uint32_t KH = ne01;
const uint32_t KW = ne00;

const uint32_t OD = ned3 / N;
const uint32_t OH = ned2;
const uint32_t OW = ned1;

const uint32_t IC_KD_KH_KW = IC*KD*KH*KW;
const uint32_t N_OD_OH = N*OD*OH;

elements = { IC_KD_KH_KW, OW, N_OD_OH };
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
} break;
case GGML_OP_TIMESTEP_EMBEDDING:
{
const uint32_t dim = dst->op_params[0];
Expand Down Expand Up @@ -8286,7 +8360,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_IM2COL) {
} else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
// im2col uses only src1 and dst buffers
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_COUNT_EQUAL) {
Expand Down Expand Up @@ -9147,6 +9221,66 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
}, dryrun);
}

static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
GGML_TENSOR_BINARY_OP_LOCALS

const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
const int32_t IC = ((const int32_t *)(dst->op_params))[9];

const int64_t N = ne13 / IC;
const int64_t ID = ne12;
const int64_t IH = ne11;
const int64_t IW = ne10;

const int64_t KD = ne02;
const int64_t KH = ne01;
const int64_t KW = ne00;

const int64_t OD = ne3 / N;
const int64_t OH = ne2;
const int64_t OW = ne1;

vk_op_im2col_3d_push_constants pc {};

pc.nb10 = nb10 / ggml_type_size(src1->type);
pc.nb11 = nb11 / ggml_type_size(src1->type);
pc.nb12 = nb12 / ggml_type_size(src1->type);
pc.nb13 = nb13 / ggml_type_size(src1->type);
pc.s0 = s0;
pc.s1 = s1;
pc.s2 = s2;
pc.p0 = p0;
pc.p1 = p1;
pc.p2 = p2;
pc.d0 = d0;
pc.d1 = d1;
pc.d2 = d2;
pc.IW = IW;
pc.IH = IH;
pc.ID = ID;
pc.IC = IC;
pc.KW = KW;
pc.OH = OH;
pc.KD_KH_KW = KD*KH*KW;
pc.KH_KW = KH*KW;
pc.IC_KD_KH_KW = IC*KD*KH*KW;
pc.N_OD_OH = N*OD*OH;
pc.OD_OH = OD*OH;
pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;

ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun);
}

static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const uint32_t dim = dst->op_params[0];
const uint32_t max_period = dst->op_params[1];
Expand Down Expand Up @@ -10352,6 +10486,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
Expand Down Expand Up @@ -10422,6 +10557,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
Expand Down Expand Up @@ -10717,6 +10853,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_IM2COL:
ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);

break;
case GGML_OP_IM2COL_3D:
ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node, dryrun);

break;
case GGML_OP_TIMESTEP_EMBEDDING:
ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
Expand Down Expand Up @@ -10868,6 +11008,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
Expand Down Expand Up @@ -12150,6 +12291,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_2D_DW:
case GGML_OP_POOL_2D:
Expand Down Expand Up @@ -12725,6 +12867,19 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *

const bool is_2D = tensor->op_params[6] == 1;
tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
} else if (tensor->op == GGML_OP_IM2COL_3D) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
const int32_t s1 = tensor->op_params[2];
const int32_t p0 = tensor->op_params[3];
const int32_t p1 = tensor->op_params[4];
const int32_t p1 = tensor->op_params[5];
const int32_t d0 = tensor->op_params[6];
const int32_t d1 = tensor->op_params[7];
const int32_t d1 = tensor->op_params[8];
const int32_t IC = tensor->op_params[9];

tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type);
} else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
const int32_t dim = tensor->op_params[0];
const int32_t max_period = tensor->op_params[1];
Expand Down
112 changes: 112 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#version 450

#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "rte.comp"

layout (push_constant) uniform parameter
{
uint32_t nb10;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t s0;
uint32_t s1;
uint32_t s2;
uint32_t p0;
uint32_t p1;
uint32_t p2;
uint32_t d0;
uint32_t d1;
uint32_t d2;
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t IC;
uint32_t KW;
uint32_t OH;
uint32_t KD_KH_KW;
uint32_t KH_KW;
uint32_t IC_KD_KH_KW;
uint32_t N_OD_OH;
uint32_t OD_OH;
uint32_t OD_OH_OW_IC_KD_KH_KW;
uint32_t OH_OW_IC_KD_KH_KW;
uint32_t OW_IC_KD_KH_KW;
uint32_t misalign_offsets;
} p;

#include "types.comp"

uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }

layout(constant_id = 0) const uint BLOCK_SIZE = 32;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
const uint32_t i = gl_GlobalInvocationID.x;

uint32_t nb10 = p.nb10;
uint32_t nb11 = p.nb11;
uint32_t nb12 = p.nb12;
uint32_t nb13 = p.nb13;
uint32_t s0 = p.s0;
uint32_t s1 = p.s1;
uint32_t s2 = p.s2;
uint32_t p0 = p.p0;
uint32_t p1 = p.p1;
uint32_t p2 = p.p2;
uint32_t d0 = p.d0;
uint32_t d1 = p.d1;
uint32_t d2 = p.d2;
uint32_t IW = p.IW;
uint32_t IH = p.IH;
uint32_t ID = p.ID;
uint32_t IC = p.IC;
uint32_t KW = p.KW;
uint32_t OH = p.OH;
uint32_t KD_KH_KW = p.KD_KH_KW;
uint32_t KH_KW = p.KH_KW;
uint32_t IC_KD_KH_KW = p.IC_KD_KH_KW;
uint32_t N_OD_OH = p.N_OD_OH;
uint32_t OD_OH = p.OD_OH;
uint32_t OD_OH_OW_IC_KD_KH_KW = p.OD_OH_OW_IC_KD_KH_KW;
uint32_t OH_OW_IC_KD_KH_KW = p.OH_OW_IC_KD_KH_KW;
uint32_t OW_IC_KD_KH_KW = p.OW_IC_KD_KH_KW;

if (i >= IC_KD_KH_KW) {
return;
}

const uint32_t iic = i / KD_KH_KW;
const uint32_t ikd = (i - iic * KD_KH_KW) / KH_KW;
const uint32_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
const uint32_t ikw = i % KW;

const uint32_t iow = gl_GlobalInvocationID.y;
for (uint32_t iz = gl_GlobalInvocationID.z; iz < N_OD_OH; iz += gl_NumWorkGroups.z) {
const uint32_t in_ = iz / OD_OH;
const uint32_t iod = (iz - in_*OD_OH) / OH;
const uint32_t ioh = iz % OH;

const uint32_t iiw = iow * s0 + ikw * d0 - p0;
const uint32_t iih = ioh * s1 + ikh * d1 - p1;
const uint32_t iid = iod * s2 + ikd * d2 - p2;

const uint32_t offset_dst = in_*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;

if (iih >= IH || iiw >= IW || iid >= ID) {
data_d[offset_dst + get_doffset()] = D_TYPE(0.0f);
} else {
const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]);
}
}
}
4 changes: 4 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,10 @@ void process_shaders() {
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));

string_to_spv("im2col_3d_f32", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("im2col_3d_f32_f16", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
string_to_spv("im2col_3d_f32_f16_rte", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));

string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
Expand Down
Loading
Loading