Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit d8ae3e3

Browse files
committed May 23, 2024
Add AVX2_24 kernels.
1 parent 5060a8e commit d8ae3e3

File tree

3 files changed

+56
-1
lines changed

3 files changed

+56
-1
lines changed
 

‎include/matilda.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ namespace Matilda
3939
enum class ThreadTask : int { Initialize, ComputeMVM, SpinWait, StayBusy, Sleep, Terminate, Dead };
4040
enum class MvmKernel : int {
4141
#ifdef HAS_AVX2
42-
AVX2_16=25616, AVX2_16_A=256161, AVX2_16_B=256162, AVX2_32=25632, AVX2_40=25640, AVX2_48=25648, AVX2_56=25656, AVX2_64=25664,
42+
AVX2_16=25616, AVX2_16_A=256161, AVX2_16_B=256162, AVX2_24=25624, AVX2_32=25632, AVX2_40=25640, AVX2_48=25648, AVX2_56=25656, AVX2_64=25664,
4343
#endif
4444
#ifdef HAS_AVX512F
4545
AVX512_16=51216, AVX512_32=51232, AVX512_64=51264,

‎src/matilda.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ mvm_plan::mvm_plan( const mvm_param & p ) :
127127
m_simd_rows = 16;
128128
mvm_kernel_func = m_f16c ? &mvm_kernel_avx2_16_b_f16c : &mvm_kernel_avx2_16_b;
129129
break;
130+
case MvmKernel::AVX2_24:
131+
m_simd_rows = 24;
132+
mvm_kernel_func = m_f16c ? &mvm_kernel_avx2_24_f16c : &mvm_kernel_avx2_24;
133+
break;
130134
case MvmKernel::AVX2_32:
131135
m_simd_rows = 32;
132136
mvm_kernel_func = m_f16c ? &mvm_kernel_avx2_32_f16c : &mvm_kernel_avx2_32;

‎src/mvm_impl_avx2.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,57 @@ void mvm_kernel_avx2_16_f16c( float const * mat, float const * vec, size_t width
369369
_mm256_store_ps( rdi + 8, acc1 );
370370
}
371371

372+
void mvm_kernel_avx2_24( float const * mat, float const * vec, size_t width, float * rdi )
373+
{
374+
__m256 acc0 = _mm256_setzero_ps();
375+
__m256 acc1 = _mm256_setzero_ps();
376+
__m256 acc2 = _mm256_setzero_ps();
377+
378+
float const * const vecEnd = vec + width;
379+
while( vec < vecEnd )
380+
{
381+
__m256 const v = _mm256_broadcast_ss( vec );
382+
vec++;
383+
acc0 = _mm256_fmadd_ps( v, _mm256_load_ps( mat ), acc0 );
384+
acc1 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 8 ), acc1 );
385+
acc2 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 16 ), acc2 );
386+
int const distance = 32*4; // 4 fastest for 2048x2048 with 64 threads // prefetching not tuned for 24
387+
_mm_prefetch( mat + distance, _MM_HINT_T0 ); // prefetch 16 elements
388+
_mm_prefetch( mat + distance + 16, _MM_HINT_T0 ); // prefetch another 16 elements
389+
mat += 24;
390+
}
391+
392+
_mm256_store_ps( rdi, acc0 );
393+
_mm256_store_ps( rdi + 8, acc1 );
394+
_mm256_store_ps( rdi + 16, acc2 );
395+
}
396+
397+
void mvm_kernel_avx2_24_f16c( float const * mat, float const * vec, size_t width, float * rdi )
398+
{
399+
__m256 acc0 = _mm256_setzero_ps();
400+
__m256 acc1 = _mm256_setzero_ps();
401+
__m256 acc2 = _mm256_setzero_ps();
402+
403+
float const * const vecEnd = vec + width;
404+
while( vec < vecEnd )
405+
{
406+
int const distance = 1*32; // distance doesn't seem to matter here for 3456x3456 matrix, speed up is present as soon as prefetching
407+
_mm_prefetch( mat + distance, _MM_HINT_T0 ); // prefetch 32 elements
408+
409+
__m256 const v = _mm256_broadcast_ss( vec );
410+
vec++;
411+
412+
acc0 = _mm256_fmadd_ps( v, _mm256_cvtph_ps( _mm_load_si128( reinterpret_cast<__m128i const *>( mat ) ) ), acc0 );
413+
acc1 = _mm256_fmadd_ps( v, _mm256_cvtph_ps( _mm_load_si128( reinterpret_cast<__m128i const *>( mat + 4 ) ) ), acc1 );
414+
acc2 = _mm256_fmadd_ps( v, _mm256_cvtph_ps( _mm_load_si128( reinterpret_cast<__m128i const *>( mat + 8 ) ) ), acc2 );
415+
mat += 12;
416+
}
417+
418+
_mm256_store_ps( rdi, acc0 );
419+
_mm256_store_ps( rdi + 8, acc1 );
420+
_mm256_store_ps( rdi + 16, acc2 );
421+
}
422+
372423
void mvm_kernel_avx2_32( float const * mat, float const * vec, size_t width, float * rdi )
373424
{
374425
__m256 acc0 = _mm256_setzero_ps();

0 commit comments

Comments
 (0)
Please sign in to comment.