
Commit acec73a

ggml : sync latest ggml + llama.cpp updates (quantization)
1 parent 5cc1741 commit acec73a

File tree: 4 files changed, +4200 −1829 lines changed

ggml-cuda.cu (+365, new file)
@@ -0,0 +1,365 @@
#include <stdint.h>
#include <stdio.h>
#include <cuda_fp16.h>
#include <atomic>
#include "ggml-cuda.h"

typedef uint16_t ggml_fp16_t;
static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");

#define QK4_0 32
typedef struct {
    float   d;              // delta
    uint8_t qs[QK4_0 / 2];  // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");

#define QK4_1 32
typedef struct {
    float   d;              // delta
    float   m;              // min
    uint8_t qs[QK4_1 / 2];  // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");

#define QK4_2 16
typedef struct {
    __half  d;              // delta
    uint8_t qs[QK4_2 / 2];  // nibbles / quants
} block_q4_2;
static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");

#define QK5_0 32
typedef struct {
    __half  d;              // delta
    uint8_t qh[4];          // 5-th bit of quants
    uint8_t qs[QK5_0 / 2];  // nibbles / quants
} block_q5_0;
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");

#define QK5_1 32
typedef struct {
    __half   d;             // delta
    __half   m;             // min
    uint32_t qh;            // 5-th bit of quants
    uint8_t  qs[QK5_1 / 2]; // nibbles / quants
} block_q5_1;
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");

#define QK8_0 32
typedef struct {
    float  d;               // delta
    int8_t qs[QK8_0];       // quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");

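Each block_q4_0 packs 32 weights into 20 bytes (a 4-byte float delta plus 16 bytes of nibbles), i.e. 5 bits per weight versus 32 for f32. The kernels below only dequantize; for orientation, the inverse can be sketched as below — a hypothetical host-side round-to-nearest quantizer (assumes <math.h>; it mirrors the (vi - 8)*d decoding in dequantize_block_q4_0 but is not part of this commit, and the exact scale choice is an assumption):

// Hypothetical reference quantizer for q4_0 (illustrative only).
static void quantize_block_q4_0_ref(const float * x, block_q4_0 * b) {
    float amax = 0.0f; // absolute maximum of the block
    for (int l = 0; l < QK4_0; l++) {
        amax = fmaxf(amax, fabsf(x[l]));
    }
    const float d  = amax / 8.0f;                  // map [-amax, amax] onto [-8, 8]
    const float id = d != 0.0f ? 1.0f / d : 0.0f;  // inverse delta (guard div by 0)
    b->d = d;
    for (int l = 0; l < QK4_0; l += 2) {
        // offset by +8 so both quants fit an unsigned nibble, clamp to 15
        const uint8_t v0 = (uint8_t) fminf(15.0f, roundf(x[l + 0]*id) + 8.0f);
        const uint8_t v1 = (uint8_t) fminf(15.0f, roundf(x[l + 1]*id) + 8.0f);
        b->qs[l/2] = v0 | (v1 << 4);
    }
}
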
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
    const block_q4_0 * x = (const block_q4_0 *) vx;

    // one CUDA block per quantized block; blockIdx.x selects it
    const int i = blockIdx.x;

    const float d = x[i].d;

    const uint8_t * pp = x[i].qs;

    for (int l = 0; l < QK4_0; l += 2) {
        const uint8_t vi = pp[l/2];

        // each byte holds two 4-bit quants: low nibble first, high nibble second
        const int8_t vi0 = vi & 0xf;
        const int8_t vi1 = vi >> 4;

        const float v0 = (vi0 - 8)*d;
        const float v1 = (vi1 - 8)*d;

        y[i*QK4_0 + l + 0] = v0;
        y[i*QK4_0 + l + 1] = v1;
    }
}

static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
    const block_q4_1 * x = (const block_q4_1 *) vx;

    const int i = blockIdx.x;

    const float d = x[i].d;
    const float m = x[i].m;

    const uint8_t * pp = x[i].qs;

    for (int l = 0; l < QK4_1; l += 2) {
        const uint8_t vi = pp[l/2];

        const int8_t vi0 = vi & 0xf;
        const int8_t vi1 = vi >> 4;

        const float v0 = vi0*d + m;
        const float v1 = vi1*d + m;

        y[i*QK4_1 + l + 0] = v0;
        y[i*QK4_1 + l + 1] = v1;
    }
}

static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
    const block_q4_2 * x = (const block_q4_2 *) vx;

    const int i = blockIdx.x;

    const float d = x[i].d;

    const uint8_t * pp = x[i].qs;

    for (int l = 0; l < QK4_2; l += 2) {
        const uint8_t vi = pp[l/2];

        const int8_t vi0 = vi & 0xf;
        const int8_t vi1 = vi >> 4;

        const float v0 = (vi0 - 8)*d;
        const float v1 = (vi1 - 8)*d;

        y[i*QK4_2 + l + 0] = v0;
        y[i*QK4_2 + l + 1] = v1;
    }
}

static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
    const block_q5_0 * x = (const block_q5_0 *) vx;

    const int i = blockIdx.x;

    const float d = x[i].d;

    const uint8_t * pp = x[i].qs;

    // qh is stored as 4 unaligned bytes; memcpy avoids a misaligned 32-bit load
    uint32_t qh;
    memcpy(&qh, x[i].qh, sizeof(qh));

    for (int l = 0; l < QK5_0; l += 2) {
        const uint8_t vi = pp[l/2];

        // bit l of qh is the 5th (high) bit of quant l; shift it into position 4
        const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
        const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;

        const int8_t vi0 = ((vi & 0xf) | vh0);
        const int8_t vi1 = ((vi >> 4) | vh1);

        const float v0 = (vi0 - 16)*d;
        const float v1 = (vi1 - 16)*d;

        y[i*QK5_0 + l + 0] = v0;
        y[i*QK5_0 + l + 1] = v1;
    }
}

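To make the bit plumbing above concrete, here is a worked trace for one byte (illustrative values, not from the commit):

// Worked example for dequantize_block_q5_0, at l = 6 with pp[3] = 0xA7 and
// qh having bit 6 set and bit 7 clear:
//   vi  = pp[3]                   = 0xA7
//   vh0 = ((qh >> 6) & 1) << 4    = 0x10   // high bit of the even quant
//   vh1 = ((qh >> 7) & 1) << 4    = 0x00   // high bit of the odd quant
//   vi0 = (vi & 0xf) | vh0 = 0x17 = 23  ->  (23 - 16)*d =  7*d
//   vi1 = (vi >> 4)  | vh1 = 0x0A = 10  ->  (10 - 16)*d = -6*d
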
static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
    const block_q5_1 * x = (const block_q5_1 *) vx;

    const int i = blockIdx.x;

    const float d = x[i].d;
    const float m = x[i].m;

    const uint8_t * pp = x[i].qs;

    const uint32_t qh = x[i].qh;

    for (int l = 0; l < QK5_1; l += 2) {
        const uint8_t vi = pp[l/2];

        const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
        const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;

        const int8_t vi0 = (vi & 0xf) | vh0;
        const int8_t vi1 = (vi >> 4) | vh1;

        const float v0 = vi0*d + m;
        const float v1 = vi1*d + m;

        y[i*QK5_1 + l + 0] = v0;
        y[i*QK5_1 + l + 1] = v1;
    }
}

static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
    const block_q8_0 * x = (const block_q8_0 *) vx;

    const int i = blockIdx.x;

    const float d = x[i].d;

    const int8_t * pp = x[i].qs;

    for (int l = 0; l < QK8_0; l++) {
        const int8_t vi = pp[l];

        y[i*QK8_0 + l] = vi*d;
    }
}

// kernel launchers: one CUDA block (with a single thread) per quantized block
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK4_0;
    dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
}

void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK4_1;
    dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
}

void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK4_2;
    dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
}

void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK5_0;
    dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
}

void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK5_1;
    dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y);
}

void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK8_0;
    dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
}

dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:
            return dequantize_row_q4_0_cuda;
        case GGML_TYPE_Q4_1:
            return dequantize_row_q4_1_cuda;
        case GGML_TYPE_Q4_2:
            return dequantize_row_q4_2_cuda;
        case GGML_TYPE_Q5_0:
            return dequantize_row_q5_0_cuda;
        case GGML_TYPE_Q5_1:
            return dequantize_row_q5_1_cuda;
        case GGML_TYPE_Q8_0:
            return dequantize_row_q8_0_cuda;
        default:
            return nullptr;
    }
}

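A hypothetical call site (not in this commit; `type`, `dev_q`, `dev_f32`, `k`, and `stream` are assumed to be set up by the caller) resolves the launcher from the tensor type and must handle nullptr for unsupported types:

// Hypothetical dispatch usage (illustrative only):
const dequantize_row_q_cuda_t to_fp32 = ggml_get_dequantize_row_q_cuda(type);
if (to_fp32 == nullptr) {
    fprintf(stderr, "unsupported quantization type %d\n", (int) type);
} else {
    to_fp32(dev_q, dev_f32, k, stream);  // async launch; k must be a multiple of the block size
}
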
// buffer pool for cuda
#define MAX_CUDA_BUFFERS 16

struct scoped_spin_lock {
    std::atomic_flag& lock;
    scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
        while (lock.test_and_set(std::memory_order_acquire)) {
            ; // spin
        }
    }
    ~scoped_spin_lock() {
        lock.clear(std::memory_order_release);
    }
    scoped_spin_lock(const scoped_spin_lock&) = delete;
    scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
};

struct cuda_buffer {
    void * ptr = nullptr;
    size_t size = 0;
};

static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;

void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
    scoped_spin_lock lock(g_cuda_pool_lock);

    // reuse the first pooled buffer that is large enough
    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
        cuda_buffer& b = g_cuda_buffer_pool[i];
        if (b.size >= size && b.ptr != nullptr) {
            void * ptr = b.ptr;
            *actual_size = b.size;
            b.ptr = nullptr;
            b.size = 0;
            return ptr;
        }
    }
    // no suitable cached buffer - fall back to a fresh allocation
    void * ptr;
    CUDA_CHECK(cudaMalloc((void **) &ptr, size));
    *actual_size = size;
    return ptr;
}

void ggml_cuda_pool_free(void * ptr, size_t size) {
    scoped_spin_lock lock(g_cuda_pool_lock);

    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
        cuda_buffer& b = g_cuda_buffer_pool[i];
        if (b.ptr == nullptr) {
            b.ptr = ptr;
            b.size = size;
            return;
        }
    }
    fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
    CUDA_CHECK(cudaFree(ptr));
}

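A note on the contract here (illustrative usage, not from the commit): the pool may hand back a larger recycled buffer, so the caller must keep the returned actual size and pass it back to ggml_cuda_pool_free for the bookkeeping to stay consistent:

// Hypothetical usage of the pool (illustrative only):
size_t actual = 0;
void * buf = ggml_cuda_pool_malloc(1 << 20, &actual);  // may return a larger cached buffer
// ... use up to `actual` bytes on the device ...
ggml_cuda_pool_free(buf, actual);                      // return the *actual* size, not the request
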
cublasHandle_t g_cublasH = nullptr;
cudaStream_t g_cudaStream = nullptr;
cudaStream_t g_cudaStream2 = nullptr;
cudaEvent_t g_cudaEvent = nullptr;

void ggml_init_cublas() {
    if (g_cublasH == nullptr) {
        // create cublas handle, bind a stream
        CUBLAS_CHECK(cublasCreate(&g_cublasH));
        CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
        CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));

        // create additional stream and event for synchronization
        CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking));
        CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming));

        // configure logging to stdout
        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
    }
}

cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
    const uint64_t ne0 = src->ne[0];
    const uint64_t ne1 = src->ne[1];
    const uint64_t nb0 = src->nb[0];
    const uint64_t nb1 = src->nb[1];
    const uint64_t nb2 = src->nb[2];
    const uint64_t nb3 = src->nb[3];
    const enum ggml_type type = src->type;
    const size_t ts = ggml_type_size(type);
    const size_t bs = ggml_blck_size(type);

    const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
    if (nb0 == ts && nb1 == ts*ne0/bs) {
        // fully contiguous 2d slice - a single async copy
        return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
    } else if (nb0 == ts) {
        // contiguous rows with a row stride - one strided 2d copy
        return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
    } else {
        // non-contiguous elements - copy row by row
        for (uint64_t i1 = 0; i1 < ne1; i1++) {
            const void * rx = (const void *) ((const char *) x + i1*nb1);
            void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
            // pretend the row is a matrix with cols=1
            cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
            if (r != cudaSuccess) return r;
        }
        return cudaSuccess;
    }
}

// pinned (page-locked) host memory - required for truly async H2D/D2H copies
void * ggml_cuda_host_malloc(size_t size) {
    void * ptr;
    CUDA_CHECK(cudaMallocHost((void **) &ptr, size));
    return ptr;
}

void ggml_cuda_host_free(void * ptr) {
    CUDA_CHECK(cudaFreeHost(ptr));
}

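As a final illustrative sketch (not in this commit; the buffer names and sizes are assumptions), pinned memory from ggml_cuda_host_malloc is what lets the async copy overlap with, and correctly order before, the dequantize kernel on the same stream:

// Hypothetical staging path combining the pieces above (illustrative only):
ggml_init_cublas();                                  // one-time: handle, streams, event

const int k = 4096;                                  // multiple of QK4_0
const size_t q_bytes = (k/QK4_0)*sizeof(block_q4_0);

void * q_host = ggml_cuda_host_malloc(q_bytes);      // pinned staging buffer
// ... fill q_host with quantized blocks ...

void * q_dev; float * f_dev;
CUDA_CHECK(cudaMalloc(&q_dev, q_bytes));
CUDA_CHECK(cudaMalloc((void **) &f_dev, k*sizeof(float)));

CUDA_CHECK(cudaMemcpyAsync(q_dev, q_host, q_bytes, cudaMemcpyHostToDevice, g_cudaStream));
dequantize_row_q4_0_cuda(q_dev, f_dev, k, g_cudaStream);  // queued behind the copy
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));

ggml_cuda_host_free(q_host);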
