
Commit bfc73f1

sync : ggml (CUDA faster rope)

1 parent: f00c9bb

File tree: 1 file changed (+14, -18 lines)

ggml-cuda.cu (+14, -18)
@@ -4086,7 +4086,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
     dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
+static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p0,
+                                    const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int half_n_dims = ncols/4;
 
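Note: the new signature matches rope_f32_cuda / rope_neox_f32_cuda. Rather than receiving one precomputed position per launch, the kernel now gets the starting position p0, the step p_delta, the row grouping p_delta_rows and the training context length n_ctx, and derives each row's position on the device.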
@@ -4098,8 +4099,9 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
     const int i = row*ncols + col;
 
     const float col_theta_scale = powf(theta_scale, col);
+    const float p = p0 + p_delta*(row/p_delta_rows);
 
-    const float theta = p*col_theta_scale;
+    const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale;
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
 
@@ -4109,7 +4111,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
     dst[i + 0]           = x0*cos_theta - x1*sin_theta;
     dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
 
-    const float block_theta = block_p*col_theta_scale;
+    const float block_theta = max(p - p_delta*(n_ctx - 2), 0.f)*col_theta_scale;
     const float sin_block_theta = sinf(block_theta);
     const float cos_block_theta = cosf(block_theta);
 
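Together these two hunks move the ChatGLM position split into the kernel: the first rotation clamps the position at p_delta*(n_ctx - 2), and the block rotation uses only the overflow past that point. This reproduces the id_p/block_p values previously computed on the host (see the ggml_cuda_op_rope hunk below, and the sketch after it).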
@@ -4984,12 +4986,13 @@ static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, co
     rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }
 
-static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
-    GGML_ASSERT(nrows % 4 == 0);
-    const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1);
-    const int num_blocks_x = (ncols + 4*CUDA_ROPE_BLOCK_SIZE - 1) / (4*CUDA_ROPE_BLOCK_SIZE);
+static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
+                              const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
+    GGML_ASSERT(ncols % 4 == 0);
+    const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
+    const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
     const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
+    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale, n_ctx);
 }
 
 static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
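The launch geometry also shrinks. Since half_n_dims = ncols/4, each thread presumably handles one column of the first quarter-row (it already writes dst at i and i + half_n_dims, plus the block positions). Blocks of CUDA_ROPE_BLOCK_SIZE/4 threads covering ncols/4 columns replace the old 4*CUDA_ROPE_BLOCK_SIZE-thread blocks, of which only a quarter-row's worth of threads did any work: assuming CUDA_ROPE_BLOCK_SIZE is 256 and taking ncols = 128, that is one 64-thread block (32 active) instead of one 1024-thread block (32 active). The assert moves from nrows % 4 to ncols % 4 accordingly.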
@@ -5723,22 +5726,18 @@ inline void ggml_cuda_op_rope(
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
 
     const bool is_neox = mode & 2;
     const bool is_glm = mode & 4;
 
     // compute
     if (is_glm) {
-        const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
-        const float id_p = min(p, n_ctx - 2.f);
-        const float block_p = max(p - (n_ctx - 2.f), 0.f);
-        rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
+        rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, n_ctx, cudaStream_main);
     } else if (is_neox) {
         GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
-        const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
         rope_neox_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
     } else {
-        const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
     }
 
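For intuition, here is a minimal host-side sketch (not part of the commit) comparing the old per-slice id_p/block_p computation with the per-row arithmetic the kernel now performs. The concrete n_past, n_ctx, freq_scale and row-grouping values are illustrative assumptions; note the clamp threshold changed from n_ctx - 2 to p_delta*(n_ctx - 2), so the two paths coincide when freq_scale == 1.

#include <math.h>
#include <stdio.h>

int main(void) {
    const int   n_past       = 3;     // assumed number of cached tokens
    const int   n_ctx        = 8;     // assumed training context size
    const float freq_scale   = 1.0f;  // becomes p_delta on the device
    const int   p_delta_rows = 4;     // assumed ne01 (rows per i02 slice)

    for (int row = 0; row < 12; ++row) {
        const int i02 = row / p_delta_rows;

        // old path: one (id_p, block_p) pair per i02 slice, computed on the host
        const float p_old   = (n_past + i02) * freq_scale;
        const float id_p    = fminf(p_old, n_ctx - 2.f);
        const float block_p = fmaxf(p_old - (n_ctx - 2.f), 0.f);

        // new path: rope_glm_f32 derives the same values from p0/p_delta per row
        const float p0      = n_past * freq_scale;
        const float p_new   = p0 + freq_scale*(row/p_delta_rows);
        const float clamped = fminf(p_new, freq_scale*(n_ctx - 2));
        const float block   = fmaxf(p_new - freq_scale*(n_ctx - 2), 0.f);

        printf("row=%2d old=(%4.1f, %4.1f) new=(%4.1f, %4.1f)\n",
               row, id_p, block_p, clamped, block);
    }
    return 0;
}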
@@ -6400,10 +6399,7 @@ void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
 
-    const int mode = ((int32_t *) dst->op_params)[2];
-    const bool is_glm = mode & 4;
-
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, true);
 }
 
 void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
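Since rope_glm_f32 now derives the position from the flattened row index (row/p_delta_rows recovers the old i02), the GLM path no longer needs the unflattened per-slice launches, and ggml_cuda_op can flatten for all rope modes; the mode probe in ggml_cuda_rope goes away with it.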
