Skip to content

Commit 7e63d3e

Browse files
committed
Fix zero-point (zp) decompression issue in fully_connected_gpu_gemv.cl: load zero-points as half instead of char, subtract them after the half conversion, and correct the zp indexing (k → gk)
1 parent d1d7c53 commit 7e63d3e

File tree

2 files changed: +314 −59 lines

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_gemv.cl

+51-53
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
// DECOMPRESSION_GROUP_SIZE - group size for weight int4 compression
1313
// FILTER_LAYOUT_OS_IS_YX_TYPE - 0: OS_IS_YX_OSV16, 1: OS_IS_YX_OSV32_ISV2, 2: OS_IS_YX_OSV64_ISV2
1414

15-
1615
#define KERNEL_LAYOUT_OS_IS_YX_OSV16 (FILTER_LAYOUT_OS_IS_YX_TYPE == 0)
1716
#define KERNEL_LAYOUT_OS_IS_YX_OSV32_ISV2 (FILTER_LAYOUT_OS_IS_YX_TYPE == 1)
1817
#define KERNEL_LAYOUT_OS_IS_YX_OSV64_ISV2 (FILTER_LAYOUT_OS_IS_YX_TYPE == 2)
@@ -69,7 +68,7 @@ KERNEL(fully_connected_gpu_gemv)(
6968
const __global half* scales,
7069
# endif
7170
# if DECOMPRESSION_ZP_TERM && !DECOMPRESSION_ZP_SCALAR
72-
const __global char* zps,
71+
const __global half* zps,
7372
# endif
7473
__global half* output,
7574
const __global uchar* weights
@@ -113,11 +112,11 @@ KERNEL(fully_connected_gpu_gemv)(
113112
float scale_1 = convert_float(scales[gk * WEIGHTS_N]);
114113

115114
# if DECOMPRESSION_ZP_TERM && !DECOMPRESSION_ZP_SCALAR
116-
char16 zpx16 = (char16)(zps[gk * WEIGHTS_N]);
115+
half16 zpx16 = (half16)(zps[gk * WEIGHTS_N]);
117116
# elif DECOMPRESSION_ZP_SCALAR
118-
char16 zpx16 = (char16)(zp_scalar_value);
117+
half16 zpx16 = (half16)(zp_scalar_value);
119118
# else
120-
char16 zpx16 = (char16)0;
119+
half16 zpx16 = (half16)0;
121120
# endif
122121
char16 mask16 = (char16)0xF;
123122

@@ -127,15 +126,15 @@ KERNEL(fully_connected_gpu_gemv)(
127126
char16 bx16 = as_char16(intel_sub_group_block_read_uc16(B));
128127

129128
#if WEI_UINT4
130-
half16 i4x16_even = convert_half16((bx16 & mask16) - zpx16);
131-
half16 i4x16_odd = convert_half16(as_char16(as_uchar16(bx16) >> 4) - zpx16);
129+
half16 i4x16_even = convert_half16((bx16 & mask16)) - zpx16;
130+
half16 i4x16_odd = convert_half16(as_char16(as_uchar16(bx16) >> 4)) - zpx16;
132131
#else
133132
char16 i4x16_even_c16 = (bx16 & (char16)0xF);
134133
char16 i4x16_odd_c16 = (as_char16(as_uchar16(bx16) >> 4));
135-
i4x16_even_c16 = select(i4x16_even_c16, i4x16_even_c16 - (char16)16, i4x16_even_c16 > (char16)7) - zpx16;
136-
i4x16_odd_c16 = select(i4x16_odd_c16, i4x16_odd_c16 - (char16)16, i4x16_odd_c16 > (char16)7) - zpx16;
137-
half16 i4x16_even = convert_half16(i4x16_even_c16);
138-
half16 i4x16_odd = convert_half16(i4x16_odd_c16);
134+
i4x16_even_c16 = select(i4x16_even_c16, i4x16_even_c16 - (char16)16, i4x16_even_c16 > (char16)7);
135+
i4x16_odd_c16 = select(i4x16_odd_c16, i4x16_odd_c16 - (char16)16, i4x16_odd_c16 > (char16)7);
136+
half16 i4x16_even = convert_half16(i4x16_even_c16) - zpx16;
137+
half16 i4x16_odd = convert_half16(i4x16_odd_c16) - zpx16;
139138
#endif
140139

141140
sum[0] += as_half(sub_group_broadcast(input_value.s0, 0)) * i4x16_even.s0 +
@@ -211,7 +210,7 @@ KERNEL(fully_connected_gpu_gemv)(
211210
const __global half* scales,
212211
# endif
213212
# if DECOMPRESSION_ZP_TERM && !DECOMPRESSION_ZP_SCALAR
214-
const __global char* zps,
213+
const __global half* zps,
215214
# endif
216215
__global half* output,
217216
const __global uchar* weights
@@ -239,7 +238,7 @@ KERNEL(fully_connected_gpu_gemv)(
239238
__local float all_sum_even[16][16]; // [wi_id, thr_id]
240239
__local float all_sum_odd[16][16];
241240

242-
// Scale layout is byfx
241+
// Scale layout is fbyx
243242
scales += (n / 32) * 32 + (n % 32) / 2;
244243
# if DECOMPRESSION_ZP_TERM && !DECOMPRESSION_ZP_SCALAR
245244
zps += (n / 32) * 32 + (n % 32) / 2;
@@ -256,14 +255,13 @@ KERNEL(fully_connected_gpu_gemv)(
256255
half scale_1 = scales[gk * WEIGHTS_N + 16];
257256

258257
# if DECOMPRESSION_ZP_TERM && !DECOMPRESSION_ZP_SCALAR
259-
char16 zpx16_0 = (char16)(zps[gk * WEIGHTS_N]);
260-
char16 zpx16_1 = (char16)(zps[gk * WEIGHTS_N + 16]);
258+
half zp0 = zps[gk * WEIGHTS_N];
259+
half zp1 = zps[gk * WEIGHTS_N + 16];
260+
half16 zpx16 = {zp0, zp1, zp0, zp1, zp0, zp1, zp0, zp1, zp0, zp1, zp0, zp1, zp0, zp1, zp0, zp1};
261261
# elif DECOMPRESSION_ZP_SCALAR
262-
char16 zpx16_0 = (char16)(zp_scalar_value);
263-
char16 zpx16_1 = (char16)(zp_scalar_value);
262+
half16 zpx16 = (half16)(zp_scalar_value);
264263
# else
265-
char16 zpx16_0 = (char16)0;
266-
char16 zpx16_1 = (char16)0;
264+
half16 zpx16 = (half16)0;
267265
# endif
268266
char16 mask16 = (char16)0xF;
269267

@@ -276,45 +274,52 @@ KERNEL(fully_connected_gpu_gemv)(
276274
char16 bx16 = as_char16(intel_sub_group_block_read_uc16(B));
277275

278276
#if WEI_UINT4
279-
half16 i4x16_even = convert_half16((bx16 & mask16) - zpx16_0);
280-
half16 i4x16_odd = convert_half16(as_char16(as_uchar16(bx16) >> 4) - zpx16_0);
277+
half16 i4x16_even = convert_half16(bx16 & mask16) - zpx16;
278+
half16 i4x16_odd = convert_half16(as_char16(as_uchar16(bx16) >> 4)) - zpx16;
281279
#else
282280
char16 i4x16_even_c16 = (bx16 & (char16)0xF);
283281
char16 i4x16_odd_c16 = (as_char16(as_uchar16(bx16) >> 4));
284-
i4x16_even_c16 = select(i4x16_even_c16, i4x16_even_c16 - (char16)16, i4x16_even_c16 > (char16)7) - zpx16_0;
285-
i4x16_odd_c16 = select(i4x16_odd_c16, i4x16_odd_c16 - (char16)16, i4x16_odd_c16 > (char16)7) - zpx16_1;
286-
half16 i4x16_even = convert_half16(i4x16_even_c16);
287-
half16 i4x16_odd = convert_half16(i4x16_odd_c16);
282+
i4x16_even_c16 = select(i4x16_even_c16, i4x16_even_c16 - (char16)16, i4x16_even_c16 > (char16)7);
283+
i4x16_odd_c16 = select(i4x16_odd_c16, i4x16_odd_c16 - (char16)16, i4x16_odd_c16 > (char16)7);
284+
half16 i4x16_even = convert_half16(i4x16_even_c16) - zpx16;
285+
half16 i4x16_odd = convert_half16(i4x16_odd_c16) - zpx16;
288286
#endif
289287

290288
sum[0] += as_half(sub_group_broadcast(input_value, 0)) * i4x16_even.s0 +
291289
as_half(sub_group_broadcast(input_value, 4)) * i4x16_even.s4 +
292290
as_half(sub_group_broadcast(input_value, 8)) * i4x16_even.s8 +
293291
as_half(sub_group_broadcast(input_value, 12)) * i4x16_even.sc;
292+
294293
sum[1] += as_half(sub_group_broadcast(input_value, 0)) * i4x16_even.s1 +
295294
as_half(sub_group_broadcast(input_value, 4)) * i4x16_even.s5 +
296295
as_half(sub_group_broadcast(input_value, 8)) * i4x16_even.s9 +
297296
as_half(sub_group_broadcast(input_value, 12)) * i4x16_even.sd;
297+
298298
sum[2] += as_half(sub_group_broadcast(input_value, 1)) * i4x16_odd.s0 +
299299
as_half(sub_group_broadcast(input_value, 5)) * i4x16_odd.s4 +
300300
as_half(sub_group_broadcast(input_value, 9)) * i4x16_odd.s8 +
301301
as_half(sub_group_broadcast(input_value, 13)) * i4x16_odd.sc;
302+
302303
sum[3] += as_half(sub_group_broadcast(input_value, 1)) * i4x16_odd.s1 +
303304
as_half(sub_group_broadcast(input_value, 5)) * i4x16_odd.s5 +
304305
as_half(sub_group_broadcast(input_value, 9)) * i4x16_odd.s9 +
305306
as_half(sub_group_broadcast(input_value, 13)) * i4x16_odd.sd;
307+
306308
sum[4] += as_half(sub_group_broadcast(input_value, 2)) * i4x16_even.s2 +
307309
as_half(sub_group_broadcast(input_value, 6)) * i4x16_even.s6 +
308310
as_half(sub_group_broadcast(input_value, 10)) * i4x16_even.sa +
309311
as_half(sub_group_broadcast(input_value, 14)) * i4x16_even.se;
312+
310313
sum[5] += as_half(sub_group_broadcast(input_value, 2)) * i4x16_even.s3 +
311314
as_half(sub_group_broadcast(input_value, 6)) * i4x16_even.s7 +
312315
as_half(sub_group_broadcast(input_value, 10)) * i4x16_even.sb +
313316
as_half(sub_group_broadcast(input_value, 14)) * i4x16_even.sf;
317+
314318
sum[6] += as_half(sub_group_broadcast(input_value, 3)) * i4x16_odd.s2 +
315319
as_half(sub_group_broadcast(input_value, 7)) * i4x16_odd.s6 +
316320
as_half(sub_group_broadcast(input_value, 11)) * i4x16_odd.sa +
317321
as_half(sub_group_broadcast(input_value, 15)) * i4x16_odd.se;
322+
318323
sum[7] += as_half(sub_group_broadcast(input_value, 3)) * i4x16_odd.s3 +
319324
as_half(sub_group_broadcast(input_value, 7)) * i4x16_odd.s7 +
320325
as_half(sub_group_broadcast(input_value, 11)) * i4x16_odd.sb +
@@ -342,7 +347,6 @@ KERNEL(fully_connected_gpu_gemv)(
342347
# if BIAS_TERM
343348
sum_value[0] += bias[cur_n];
344349
sum_value[1] += bias[cur_n + 16];
345-
// printf("osv-32: idx = %d, bias[%d] = %f, bias[%d] = %f\n", cur_n, cur_n, bias[cur_n], cur_n + 16, bias[cur_n + 16]);
346350
# endif
347351

348352
// fused_op
@@ -366,7 +370,7 @@ KERNEL(fully_connected_gpu_gemv)(
366370
const __global half* scales,
367371
# endif
368372
# if DECOMPRESSION_ZP_TERM && !DECOMPRESSION_ZP_SCALAR
369-
const __global char* zps,
373+
const __global half* zps,
370374
# endif
371375
__global half* output,
372376
const __global uchar* weights
@@ -409,21 +413,16 @@ KERNEL(fully_connected_gpu_gemv)(
409413
half scale_2 = scales[gk * WEIGHTS_N + 2 * 16];
410414
half scale_3 = scales[gk * WEIGHTS_N + 3 * 16];
411415
# if DECOMPRESSION_ZP_TERM && !DECOMPRESSION_ZP_SCALAR
412-
char16 zpx16_0 = (char16)(zps[k * WEIGHTS_N]);
413-
char16 zpx16_1 = (char16)(zps[k * WEIGHTS_N + 1 * 16]);
414-
char16 zpx16_2 = (char16)(zps[k * WEIGHTS_N + 2 * 16]);
415-
char16 zpx16_3 = (char16)(zps[k * WEIGHTS_N + 3 * 16]);
416+
half zp0 = zps[gk * WEIGHTS_N];
417+
half zp1 = zps[gk * WEIGHTS_N + 1 * 16];
418+
half zp2 = zps[gk * WEIGHTS_N + 2 * 16];
419+
half zp3 = zps[gk * WEIGHTS_N + 3 * 16];
420+
half16 zpx16 = {zp0, zp1, zp2, zp3, zp0, zp1, zp2, zp3, zp0, zp1, zp2, zp3, zp0, zp1, zp2, zp3};
416421
# elif DECOMPRESSION_ZP_SCALAR
417-
char zp_scalar_value = (char)(DECOMPRESSION_ZP_VALUE);
418-
char16 zpx16_0 = (char16)(zp_scalar_value);
419-
char16 zpx16_1 = (char16)(zp_scalar_value);
420-
char16 zpx16_2 = (char16)(zp_scalar_value);
421-
char16 zpx16_3 = (char16)(zp_scalar_value);
422+
half zp_scalar_value = (half)(DECOMPRESSION_ZP_VALUE);
423+
half16 zpx16 = (half16)(zp_scalar_value);
422424
# else
423-
char16 zpx16_0 = (char16)0;
424-
char16 zpx16_1 = (char16)0;
425-
char16 zpx16_2 = (char16)0;
426-
char16 zpx16_3 = (char16)0;
425+
half16 zpx16 = (half16)0;
427426
# endif
428427
char16 mask16 = (char16)0xF;
429428

@@ -434,25 +433,25 @@ KERNEL(fully_connected_gpu_gemv)(
434433
char16 bx16_second = as_char16(intel_sub_group_block_read_uc16(B + 16 * 16));
435434

436435
#if WEI_UINT4
437-
half16 i4x16_even = convert_half16((bx16 & mask16) - zpx16_0);
438-
half16 i4x16_odd = convert_half16(as_char16(as_uchar16(bx16) >> 4) - zpx16_1);
439-
half16 i4x16_even_second = convert_half16((bx16_second & mask16) - zpx16_2);
440-
half16 i4x16_odd_second = convert_half16(as_char16(as_uchar16(bx16_second) >> 4) - zpx16_3);
436+
half16 i4x16_even = convert_half16((bx16 & mask16)) - zpx16;
437+
half16 i4x16_odd = convert_half16(as_char16(as_uchar16(bx16) >> 4)) - zpx16;
438+
half16 i4x16_even_second = convert_half16((bx16_second & mask16)) - zpx16;
439+
half16 i4x16_odd_second = convert_half16(as_char16(as_uchar16(bx16_second) >> 4)) - zpx16;
441440
#else
442441
char16 i4x16_even_c16 = (bx16 & (char16)0xF);
443442
char16 i4x16_odd_c16 = (as_char16(as_uchar16(bx16) >> 4));
444-
i4x16_even_c16 = select(i4x16_even_c16, i4x16_even_c16 - (char16)16, i4x16_even_c16 > (char16)7) - zpx16_0;
445-
i4x16_odd_c16 = select(i4x16_odd_c16, i4x16_odd_c16 - (char16)16, i4x16_odd_c16 > (char16)7) - zpx16_1;
443+
i4x16_even_c16 = select(i4x16_even_c16, i4x16_even_c16 - (char16)16, i4x16_even_c16 > (char16)7);
444+
i4x16_odd_c16 = select(i4x16_odd_c16, i4x16_odd_c16 - (char16)16, i4x16_odd_c16 > (char16)7);
446445

447446
char16 i4x16_even_c16_second = (bx16_second & (char16)0xF);
448447
char16 i4x16_odd_c16_second = (as_char16(as_uchar16(bx16_second) >> 4));
449-
i4x16_even_c16_second = select(i4x16_even_c16_second, i4x16_even_c16_second - (char16)16, i4x16_even_c16_second > (char16)7) - zpx16_2;
450-
i4x16_odd_c16_second = select(i4x16_odd_c16_second, i4x16_odd_c16_second - (char16)16, i4x16_odd_c16_second > (char16)7) - zpx16_3;
448+
i4x16_even_c16_second = select(i4x16_even_c16_second, i4x16_even_c16_second - (char16)16, i4x16_even_c16_second > (char16)7);
449+
i4x16_odd_c16_second = select(i4x16_odd_c16_second, i4x16_odd_c16_second - (char16)16, i4x16_odd_c16_second > (char16)7);
451450

452-
half16 i4x16_even = convert_half16(i4x16_even_c16);
453-
half16 i4x16_odd = convert_half16(i4x16_odd_c16);
454-
half16 i4x16_even_second = convert_half16(i4x16_even_c16_second);
455-
half16 i4x16_odd_second = convert_half16(i4x16_odd_c16_second);
451+
half16 i4x16_even = convert_half16(i4x16_even_c16) - zpx16;
452+
half16 i4x16_odd = convert_half16(i4x16_odd_c16) - zpx16;
453+
half16 i4x16_even_second = convert_half16(i4x16_even_c16_second) - zpx16;
454+
half16 i4x16_odd_second = convert_half16(i4x16_odd_c16_second) - zpx16;
456455
#endif
457456

458457
sum[0] += as_half(sub_group_broadcast(input_value, 0)) * i4x16_even.s0 +
@@ -521,7 +520,6 @@ KERNEL(fully_connected_gpu_gemv)(
521520
as_half(sub_group_broadcast(input_value, 15)) * i4x16_odd_second.sf;
522521
}
523522

524-
// scales applied once
525523
sum_all[0] += (sum[0] + sum[4]) * scale_0;
526524
sum_all[1] += (sum[1] + sum[5]) * scale_1;
527525
sum_all[2] += (sum[2] + sum[6]) * scale_2;

0 commit comments

Comments
 (0)