@@ -369,6 +369,57 @@ void mvm_kernel_avx2_16_f16c( float const * mat, float const * vec, size_t width
369
369
_mm256_store_ps ( rdi + 8 , acc1 );
370
370
}
371
371
372
+ void mvm_kernel_avx2_24 ( float const * mat, float const * vec, size_t width, float * rdi )
373
+ {
374
+ __m256 acc0 = _mm256_setzero_ps ();
375
+ __m256 acc1 = _mm256_setzero_ps ();
376
+ __m256 acc2 = _mm256_setzero_ps ();
377
+
378
+ float const * const vecEnd = vec + width;
379
+ while ( vec < vecEnd )
380
+ {
381
+ __m256 const v = _mm256_broadcast_ss ( vec );
382
+ vec++;
383
+ acc0 = _mm256_fmadd_ps ( v, _mm256_load_ps ( mat ), acc0 );
384
+ acc1 = _mm256_fmadd_ps ( v, _mm256_load_ps ( mat + 8 ), acc1 );
385
+ acc2 = _mm256_fmadd_ps ( v, _mm256_load_ps ( mat + 16 ), acc2 );
386
+ int const distance = 32 *4 ; // 4 fastest for 2048x2048 with 64 threads // prefetching not tuned for 24
387
+ _mm_prefetch ( mat + distance, _MM_HINT_T0 ); // prefetch 16 elements
388
+ _mm_prefetch ( mat + distance + 16 , _MM_HINT_T0 ); // prefetch another 16 elements
389
+ mat += 24 ;
390
+ }
391
+
392
+ _mm256_store_ps ( rdi, acc0 );
393
+ _mm256_store_ps ( rdi + 8 , acc1 );
394
+ _mm256_store_ps ( rdi + 16 , acc2 );
395
+ }
396
+
397
+ void mvm_kernel_avx2_24_f16c ( float const * mat, float const * vec, size_t width, float * rdi )
398
+ {
399
+ __m256 acc0 = _mm256_setzero_ps ();
400
+ __m256 acc1 = _mm256_setzero_ps ();
401
+ __m256 acc2 = _mm256_setzero_ps ();
402
+
403
+ float const * const vecEnd = vec + width;
404
+ while ( vec < vecEnd )
405
+ {
406
+ int const distance = 1 *32 ; // distance doesn't seem to matter here for 3456x3456 matrix, speed up is present as soon as prefetching
407
+ _mm_prefetch ( mat + distance, _MM_HINT_T0 ); // prefetch 32 elements
408
+
409
+ __m256 const v = _mm256_broadcast_ss ( vec );
410
+ vec++;
411
+
412
+ acc0 = _mm256_fmadd_ps ( v, _mm256_cvtph_ps ( _mm_load_si128 ( reinterpret_cast <__m128i const *>( mat ) ) ), acc0 );
413
+ acc1 = _mm256_fmadd_ps ( v, _mm256_cvtph_ps ( _mm_load_si128 ( reinterpret_cast <__m128i const *>( mat + 4 ) ) ), acc1 );
414
+ acc2 = _mm256_fmadd_ps ( v, _mm256_cvtph_ps ( _mm_load_si128 ( reinterpret_cast <__m128i const *>( mat + 8 ) ) ), acc2 );
415
+ mat += 12 ;
416
+ }
417
+
418
+ _mm256_store_ps ( rdi, acc0 );
419
+ _mm256_store_ps ( rdi + 8 , acc1 );
420
+ _mm256_store_ps ( rdi + 16 , acc2 );
421
+ }
422
+
372
423
void mvm_kernel_avx2_32 ( float const * mat, float const * vec, size_t width, float * rdi )
373
424
{
374
425
__m256 acc0 = _mm256_setzero_ps ();
0 commit comments