@@ -12,7 +12,6 @@
 // DECOMPRESSION_GROUP_SIZE - group size for weight int4 compression
 // FILTER_LAYOUT_OS_IS_YX_TYPE - 0: OS_IS_YX_OSV16, 1: OS_IS_YX_OSV32_ISV2, 2: OS_IS_YX_OSV64_ISV2

-
 #define KERNEL_LAYOUT_OS_IS_YX_OSV16 (FILTER_LAYOUT_OS_IS_YX_TYPE == 0)
 #define KERNEL_LAYOUT_OS_IS_YX_OSV32_ISV2 (FILTER_LAYOUT_OS_IS_YX_TYPE == 1)
 #define KERNEL_LAYOUT_OS_IS_YX_OSV64_ISV2 (FILTER_LAYOUT_OS_IS_YX_TYPE == 2)
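Throughout this kernel, two int4 weights are packed per byte: the low nibble is the "even" element and the high nibble the "odd" one, which is why the hunks below mask with 0xF and shift right by 4. A minimal scalar sketch of that decode in plain C (float stands in for OpenCL half; the function names are illustrative, not from the kernel):

    #include <stdint.h>

    /* WEI_UINT4 path: a nibble is an unsigned 4-bit value. */
    static inline float decode_u4(uint8_t byte, int odd, float zp, float scale) {
        uint8_t q = odd ? (uint8_t)(byte >> 4) : (uint8_t)(byte & 0xF);
        return ((float)q - zp) * scale;
    }

    /* Signed int4 path: raw nibbles above 7 wrap to negative; the kernel's
       select(v, v - 16, v > 7) performs the same sign extension. */
    static inline float decode_i4(uint8_t byte, int odd, float zp, float scale) {
        int8_t q = (int8_t)(odd ? (byte >> 4) : (byte & 0xF));
        if (q > 7) q -= 16;
        return ((float)q - zp) * scale;
    }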
@@ -69,7 +68,7 @@ KERNEL(fully_connected_gpu_gemv)(
     const __global half* scales,
 #endif
 #if DECOMPRESSION_ZP_TERM && !DECOMPRESSION_ZP_SCALAR
-    const __global char* zps,
+    const __global half* zps,
 #endif
     __global half* output,
     const __global uchar* weights
@@ -113,11 +112,11 @@ KERNEL(fully_connected_gpu_gemv)(
     float scale_1 = convert_float(scales[gk * WEIGHTS_N]);

 #if DECOMPRESSION_ZP_TERM && !DECOMPRESSION_ZP_SCALAR
-    char16 zpx16 = (char16)(zps[gk * WEIGHTS_N]);
+    half16 zpx16 = (half16)(zps[gk * WEIGHTS_N]);
 #elif DECOMPRESSION_ZP_SCALAR
-    char16 zpx16 = (char16)(zp_scalar_value);
+    half16 zpx16 = (half16)(zp_scalar_value);
 #else
-    char16 zpx16 = (char16)0;
+    half16 zpx16 = (half16)0;
 #endif
     char16 mask16 = (char16)0xF;

@@ -127,15 +126,15 @@ KERNEL(fully_connected_gpu_gemv)(
     char16 bx16 = as_char16(intel_sub_group_block_read_uc16(B));

 #if WEI_UINT4
-    half16 i4x16_even = convert_half16((bx16 & mask16) - zpx16);
-    half16 i4x16_odd = convert_half16(as_char16(as_uchar16(bx16) >> 4) - zpx16);
+    half16 i4x16_even = convert_half16((bx16 & mask16)) - zpx16;
+    half16 i4x16_odd = convert_half16(as_char16(as_uchar16(bx16) >> 4)) - zpx16;
 #else
     char16 i4x16_even_c16 = (bx16 & (char16)0xF);
     char16 i4x16_odd_c16 = (as_char16(as_uchar16(bx16) >> 4));
-    i4x16_even_c16 = select(i4x16_even_c16, i4x16_even_c16 - (char16)16, i4x16_even_c16 > (char16)7) - zpx16;
-    i4x16_odd_c16 = select(i4x16_odd_c16, i4x16_odd_c16 - (char16)16, i4x16_odd_c16 > (char16)7) - zpx16;
-    half16 i4x16_even = convert_half16(i4x16_even_c16);
-    half16 i4x16_odd = convert_half16(i4x16_odd_c16);
+    i4x16_even_c16 = select(i4x16_even_c16, i4x16_even_c16 - (char16)16, i4x16_even_c16 > (char16)7);
+    i4x16_odd_c16 = select(i4x16_odd_c16, i4x16_odd_c16 - (char16)16, i4x16_odd_c16 > (char16)7);
+    half16 i4x16_even = convert_half16(i4x16_even_c16) - zpx16;
+    half16 i4x16_odd = convert_half16(i4x16_odd_c16) - zpx16;
 #endif

     sum[0] += as_half(sub_group_broadcast(input_value.s0, 0)) * i4x16_even.s0 +
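The reordering in this hunk is the core of the change: with zps now stored as half, the zero point can be fractional, so it must be subtracted after the int-to-half conversion rather than inside 8-bit integer arithmetic. A minimal sketch of the two orderings in plain C (float stands in for half; names are illustrative):

    #include <stdint.h>

    /* Old ordering: zp was a char, so (q - zp) happened in 8-bit integer math. */
    static inline float dequant_char_zp(int8_t q, int8_t zp) {
        return (float)(q - zp);
    }

    /* New ordering: zp is a half, so convert first, then subtract in floating
       point. Identical when zp is integral, and also correct for fractional
       zero points. */
    static inline float dequant_half_zp(int8_t q, float zp) {
        return (float)q - zp;
    }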
@@ -211,7 +210,7 @@ KERNEL(fully_connected_gpu_gemv)(
     const __global half* scales,
 #endif
 #if DECOMPRESSION_ZP_TERM && !DECOMPRESSION_ZP_SCALAR
-    const __global char* zps,
+    const __global half* zps,
 #endif
     __global half* output,
     const __global uchar* weights
@@ -239,7 +238,7 @@ KERNEL(fully_connected_gpu_gemv)(
     __local float all_sum_even[16][16]; // [wi_id, thr_id]
     __local float all_sum_odd[16][16];

-    // Scale layout is byfx
+    // Scale layout is fbyx
     scales += (n / 32) * 32 + (n % 32) / 2;
 #if DECOMPRESSION_ZP_TERM && !DECOMPRESSION_ZP_SCALAR
     zps += (n / 32) * 32 + (n % 32) / 2;
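Both the scales and (non-scalar) zps pointers are advanced by the same fbyx offset, so each output channel's scale stays paired with its zero point. A worked evaluation of that offset in C (this only evaluates the arithmetic in the diff; the layout interpretation follows the fbyx comment above):

    /* offset(n) = (n / 32) * 32 + (n % 32) / 2, with C integer division. */
    static inline int fbyx_offset(int n) {
        return (n / 32) * 32 + (n % 32) / 2;
    }
    /* fbyx_offset(0) == 0, fbyx_offset(1) == 0, fbyx_offset(2) == 1,
       fbyx_offset(31) == 15, fbyx_offset(32) == 32, fbyx_offset(33) == 32 */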
@@ -256,14 +255,13 @@ KERNEL(fully_connected_gpu_gemv)(
     half scale_1 = scales[gk * WEIGHTS_N + 16];

 #if DECOMPRESSION_ZP_TERM && !DECOMPRESSION_ZP_SCALAR
-    char16 zpx16_0 = (char16)(zps[gk * WEIGHTS_N]);
-    char16 zpx16_1 = (char16)(zps[gk * WEIGHTS_N + 16]);
+    half zp0 = zps[gk * WEIGHTS_N];
+    half zp1 = zps[gk * WEIGHTS_N + 16];
+    half16 zpx16 = {zp0, zp1, zp0, zp1, zp0, zp1, zp0, zp1, zp0, zp1, zp0, zp1, zp0, zp1, zp0, zp1};
 #elif DECOMPRESSION_ZP_SCALAR
-    char16 zpx16_0 = (char16)(zp_scalar_value);
-    char16 zpx16_1 = (char16)(zp_scalar_value);
+    half16 zpx16 = (half16)(zp_scalar_value);
 #else
-    char16 zpx16_0 = (char16)0;
-    char16 zpx16_1 = (char16)0;
+    half16 zpx16 = (half16)0;
 #endif
     char16 mask16 = (char16)0xF;

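In the OSV32_ISV2 variant the two zero points are interleaved into a single half16, so one vector subtraction covers both output channels packed 16 apart; the old code kept two separate char16 vectors. A plain-C sketch of the interleave pattern (float stands in for half):

    /* Interleave two per-channel zero points across 16 lanes, mirroring the
       half16 initializer in the diff. */
    static void build_zpx16(float zp0, float zp1, float zpx16[16]) {
        for (int i = 0; i < 16; ++i)
            zpx16[i] = (i % 2 == 0) ? zp0 : zp1;  /* {zp0, zp1, zp0, zp1, ...} */
    }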
@@ -276,45 +274,52 @@ KERNEL(fully_connected_gpu_gemv)(
     char16 bx16 = as_char16(intel_sub_group_block_read_uc16(B));

 #if WEI_UINT4
-    half16 i4x16_even = convert_half16((bx16 & mask16) - zpx16_0);
-    half16 i4x16_odd = convert_half16(as_char16(as_uchar16(bx16) >> 4) - zpx16_0);
+    half16 i4x16_even = convert_half16(bx16 & mask16) - zpx16;
+    half16 i4x16_odd = convert_half16(as_char16(as_uchar16(bx16) >> 4)) - zpx16;
 #else
     char16 i4x16_even_c16 = (bx16 & (char16)0xF);
     char16 i4x16_odd_c16 = (as_char16(as_uchar16(bx16) >> 4));
-    i4x16_even_c16 = select(i4x16_even_c16, i4x16_even_c16 - (char16)16, i4x16_even_c16 > (char16)7) - zpx16_0;
-    i4x16_odd_c16 = select(i4x16_odd_c16, i4x16_odd_c16 - (char16)16, i4x16_odd_c16 > (char16)7) - zpx16_1;
-    half16 i4x16_even = convert_half16(i4x16_even_c16);
-    half16 i4x16_odd = convert_half16(i4x16_odd_c16);
+    i4x16_even_c16 = select(i4x16_even_c16, i4x16_even_c16 - (char16)16, i4x16_even_c16 > (char16)7);
+    i4x16_odd_c16 = select(i4x16_odd_c16, i4x16_odd_c16 - (char16)16, i4x16_odd_c16 > (char16)7);
+    half16 i4x16_even = convert_half16(i4x16_even_c16) - zpx16;
+    half16 i4x16_odd = convert_half16(i4x16_odd_c16) - zpx16;
 #endif

     sum[0] += as_half(sub_group_broadcast(input_value, 0)) * i4x16_even.s0 +
               as_half(sub_group_broadcast(input_value, 4)) * i4x16_even.s4 +
               as_half(sub_group_broadcast(input_value, 8)) * i4x16_even.s8 +
               as_half(sub_group_broadcast(input_value, 12)) * i4x16_even.sc;
+
     sum[1] += as_half(sub_group_broadcast(input_value, 0)) * i4x16_even.s1 +
               as_half(sub_group_broadcast(input_value, 4)) * i4x16_even.s5 +
               as_half(sub_group_broadcast(input_value, 8)) * i4x16_even.s9 +
               as_half(sub_group_broadcast(input_value, 12)) * i4x16_even.sd;
+
     sum[2] += as_half(sub_group_broadcast(input_value, 1)) * i4x16_odd.s0 +
               as_half(sub_group_broadcast(input_value, 5)) * i4x16_odd.s4 +
               as_half(sub_group_broadcast(input_value, 9)) * i4x16_odd.s8 +
               as_half(sub_group_broadcast(input_value, 13)) * i4x16_odd.sc;
+
     sum[3] += as_half(sub_group_broadcast(input_value, 1)) * i4x16_odd.s1 +
               as_half(sub_group_broadcast(input_value, 5)) * i4x16_odd.s5 +
               as_half(sub_group_broadcast(input_value, 9)) * i4x16_odd.s9 +
               as_half(sub_group_broadcast(input_value, 13)) * i4x16_odd.sd;
+
     sum[4] += as_half(sub_group_broadcast(input_value, 2)) * i4x16_even.s2 +
               as_half(sub_group_broadcast(input_value, 6)) * i4x16_even.s6 +
               as_half(sub_group_broadcast(input_value, 10)) * i4x16_even.sa +
               as_half(sub_group_broadcast(input_value, 14)) * i4x16_even.se;
+
     sum[5] += as_half(sub_group_broadcast(input_value, 2)) * i4x16_even.s3 +
               as_half(sub_group_broadcast(input_value, 6)) * i4x16_even.s7 +
               as_half(sub_group_broadcast(input_value, 10)) * i4x16_even.sb +
               as_half(sub_group_broadcast(input_value, 14)) * i4x16_even.sf;
+
     sum[6] += as_half(sub_group_broadcast(input_value, 3)) * i4x16_odd.s2 +
               as_half(sub_group_broadcast(input_value, 7)) * i4x16_odd.s6 +
               as_half(sub_group_broadcast(input_value, 11)) * i4x16_odd.sa +
               as_half(sub_group_broadcast(input_value, 15)) * i4x16_odd.se;
+
     sum[7] += as_half(sub_group_broadcast(input_value, 3)) * i4x16_odd.s3 +
               as_half(sub_group_broadcast(input_value, 7)) * i4x16_odd.s7 +
               as_half(sub_group_broadcast(input_value, 11)) * i4x16_odd.sb +
@@ -342,7 +347,6 @@ KERNEL(fully_connected_gpu_gemv)(
 #if BIAS_TERM
     sum_value[0] += bias[cur_n];
     sum_value[1] += bias[cur_n + 16];
-    // printf("osv-32: idx = %d, bias[%d] = %f, bias[%d] = %f\n", cur_n, cur_n, bias[cur_n], cur_n + 16, bias[cur_n + 16]);
 #endif

     // fused_op
@@ -366,7 +370,7 @@ KERNEL(fully_connected_gpu_gemv)(
     const __global half* scales,
 #endif
 #if DECOMPRESSION_ZP_TERM && !DECOMPRESSION_ZP_SCALAR
-    const __global char* zps,
+    const __global half* zps,
 #endif
     __global half* output,
     const __global uchar* weights
@@ -409,21 +413,16 @@ KERNEL(fully_connected_gpu_gemv)(
     half scale_2 = scales[gk * WEIGHTS_N + 2 * 16];
     half scale_3 = scales[gk * WEIGHTS_N + 3 * 16];
 #if DECOMPRESSION_ZP_TERM && !DECOMPRESSION_ZP_SCALAR
-    char16 zpx16_0 = (char16)(zps[k * WEIGHTS_N]);
-    char16 zpx16_1 = (char16)(zps[k * WEIGHTS_N + 1 * 16]);
-    char16 zpx16_2 = (char16)(zps[k * WEIGHTS_N + 2 * 16]);
-    char16 zpx16_3 = (char16)(zps[k * WEIGHTS_N + 3 * 16]);
+    half zp0 = zps[gk * WEIGHTS_N];
+    half zp1 = zps[gk * WEIGHTS_N + 1 * 16];
+    half zp2 = zps[gk * WEIGHTS_N + 2 * 16];
+    half zp3 = zps[gk * WEIGHTS_N + 3 * 16];
+    half16 zpx16 = {zp0, zp1, zp2, zp3, zp0, zp1, zp2, zp3, zp0, zp1, zp2, zp3, zp0, zp1, zp2, zp3};
 #elif DECOMPRESSION_ZP_SCALAR
-    char zp_scalar_value = (char)(DECOMPRESSION_ZP_VALUE);
-    char16 zpx16_0 = (char16)(zp_scalar_value);
-    char16 zpx16_1 = (char16)(zp_scalar_value);
-    char16 zpx16_2 = (char16)(zp_scalar_value);
-    char16 zpx16_3 = (char16)(zp_scalar_value);
+    half zp_scalar_value = (half)(DECOMPRESSION_ZP_VALUE);
+    half16 zpx16 = (half16)(zp_scalar_value);
 #else
-    char16 zpx16_0 = (char16)0;
-    char16 zpx16_1 = (char16)0;
-    char16 zpx16_2 = (char16)0;
-    char16 zpx16_3 = (char16)0;
+    half16 zpx16 = (half16)0;
 #endif
     char16 mask16 = (char16)0xF;

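The OSV64_ISV2 variant extends the same idea to four output channels spaced 16 apart (matching the scale_0..scale_3 loads above), repeating {zp0, zp1, zp2, zp3} across the 16 lanes; note the hunk also fixes the zps load index from k to gk so the zero points come from the current group. A plain-C sketch of the repeat pattern (float stands in for half):

    /* Repeat four per-channel zero points across 16 lanes, mirroring the
       half16 initializer in the diff. */
    static void build_zpx16_osv64(const float zp[4], float zpx16[16]) {
        for (int i = 0; i < 16; ++i)
            zpx16[i] = zp[i % 4];  /* {zp0, zp1, zp2, zp3, zp0, ...} */
    }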
@@ -434,25 +433,25 @@ KERNEL(fully_connected_gpu_gemv)(
     char16 bx16_second = as_char16(intel_sub_group_block_read_uc16(B + 16 * 16));

 #if WEI_UINT4
-    half16 i4x16_even = convert_half16((bx16 & mask16) - zpx16_0);
-    half16 i4x16_odd = convert_half16(as_char16(as_uchar16(bx16) >> 4) - zpx16_1);
-    half16 i4x16_even_second = convert_half16((bx16_second & mask16) - zpx16_2);
-    half16 i4x16_odd_second = convert_half16(as_char16(as_uchar16(bx16_second) >> 4) - zpx16_3);
+    half16 i4x16_even = convert_half16((bx16 & mask16)) - zpx16;
+    half16 i4x16_odd = convert_half16(as_char16(as_uchar16(bx16) >> 4)) - zpx16;
+    half16 i4x16_even_second = convert_half16((bx16_second & mask16)) - zpx16;
+    half16 i4x16_odd_second = convert_half16(as_char16(as_uchar16(bx16_second) >> 4)) - zpx16;
 #else
     char16 i4x16_even_c16 = (bx16 & (char16)0xF);
     char16 i4x16_odd_c16 = (as_char16(as_uchar16(bx16) >> 4));
-    i4x16_even_c16 = select(i4x16_even_c16, i4x16_even_c16 - (char16)16, i4x16_even_c16 > (char16)7) - zpx16_0;
-    i4x16_odd_c16 = select(i4x16_odd_c16, i4x16_odd_c16 - (char16)16, i4x16_odd_c16 > (char16)7) - zpx16_1;
+    i4x16_even_c16 = select(i4x16_even_c16, i4x16_even_c16 - (char16)16, i4x16_even_c16 > (char16)7);
+    i4x16_odd_c16 = select(i4x16_odd_c16, i4x16_odd_c16 - (char16)16, i4x16_odd_c16 > (char16)7);

     char16 i4x16_even_c16_second = (bx16_second & (char16)0xF);
     char16 i4x16_odd_c16_second = (as_char16(as_uchar16(bx16_second) >> 4));
-    i4x16_even_c16_second = select(i4x16_even_c16_second, i4x16_even_c16_second - (char16)16, i4x16_even_c16_second > (char16)7) - zpx16_2;
-    i4x16_odd_c16_second = select(i4x16_odd_c16_second, i4x16_odd_c16_second - (char16)16, i4x16_odd_c16_second > (char16)7) - zpx16_3;
+    i4x16_even_c16_second = select(i4x16_even_c16_second, i4x16_even_c16_second - (char16)16, i4x16_even_c16_second > (char16)7);
+    i4x16_odd_c16_second = select(i4x16_odd_c16_second, i4x16_odd_c16_second - (char16)16, i4x16_odd_c16_second > (char16)7);

-    half16 i4x16_even = convert_half16(i4x16_even_c16);
-    half16 i4x16_odd = convert_half16(i4x16_odd_c16);
-    half16 i4x16_even_second = convert_half16(i4x16_even_c16_second);
-    half16 i4x16_odd_second = convert_half16(i4x16_odd_c16_second);
+    half16 i4x16_even = convert_half16(i4x16_even_c16) - zpx16;
+    half16 i4x16_odd = convert_half16(i4x16_odd_c16) - zpx16;
+    half16 i4x16_even_second = convert_half16(i4x16_even_c16_second) - zpx16;
+    half16 i4x16_odd_second = convert_half16(i4x16_odd_c16_second) - zpx16;
 #endif

     sum[0] += as_half(sub_group_broadcast(input_value, 0)) * i4x16_even.s0 +
@@ -521,7 +520,6 @@ KERNEL(fully_connected_gpu_gemv)(
               as_half(sub_group_broadcast(input_value, 15)) * i4x16_odd_second.sf;
     }

-    // scales applied once
     sum_all[0] += (sum[0] + sum[4]) * scale_0;
     sum_all[1] += (sum[1] + sum[5]) * scale_1;
     sum_all[2] += (sum[2] + sum[6]) * scale_2;