|
17 | 17 |
|
18 | 18 | #include "intsimdmatrix.h" |
19 | 19 |
|
20 | | -#if !defined(__AVX2__) |
| 20 | +#if !defined(__AVX512VNNI__) || !defined(__AVX512VL__) |
21 | 21 | # if defined(__i686__) || defined(__x86_64__) |
22 | | -# error Implementation only for AVX2 capable architectures |
| 22 | +# error Implementation only for AVX512VNNI / AVX512VL capable architectures |
23 | 23 | # endif |
24 | 24 | #else |
25 | 25 | # include <immintrin.h> |
@@ -73,16 +73,12 @@ static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones, |
73 | 73 | // Normalize the signs on rep_input, weights, so weights is always +ve. |
74 | 74 | reps = _mm256_sign_epi8(rep_input, weights); |
75 | 75 | weights = _mm256_sign_epi8(weights, weights); |
76 | | - // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results, |
77 | | - // with adjacent pairs added. |
78 | | - weights = _mm256_maddubs_epi16(weights, reps); |
79 | | - // Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results, |
80 | | - // with adjacent pairs added. What we really want is a horizontal add of |
81 | | - // 16+16=32 bit result, but there is no such instruction, so multiply by |
82 | | - // 16-bit ones instead. It is probably faster than all the sign-extending, |
83 | | - // permuting and adding that would otherwise be required. |
84 | | - weights = _mm256_madd_epi16(weights, ones); |
85 | | - result = _mm256_add_epi32(result, weights); |
| 76 | + |
| 77 | + // A single VNNI instruction replaces the following 3 AVX2 instructions: |
| 78 | + //weights = _mm256_maddubs_epi16(weights, reps); |
| 79 | + //weights = _mm256_madd_epi16(weights, ones); |
| 80 | + //result = _mm256_add_epi32(result, weights); |
| 81 | + result = _mm256_dpbusd_epi32(result, weights, reps); |
86 | 82 | } |
87 | 83 |
|
88 | 84 | // Load 64 bits into the bottom of a 128bit register. |
|
0 commit comments