rotaryembed simd x86 #6427
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
futz12 wants to merge 4 commits into Tencent:master from futz12:RotaryEmbed_x86 (base: master)
+339 −0
rotaryembed_x86.cpp (new file, +316 lines)

```cpp
// Copyright 2025 pchar.cn
// SPDX-License-Identifier: BSD-3-Clause

#include "rotaryembed_x86.h"

#if __SSE2__
#include <emmintrin.h>
#if __AVX__
#include <immintrin.h>
#endif // __AVX__
#if __AVX512F__
#include <immintrin.h>
#endif // __AVX512F__
#endif // __SSE2__

namespace ncnn {

RotaryEmbed_x86::RotaryEmbed_x86()
{
}

int RotaryEmbed_x86::forward(const std::vector<Mat>& bottom_blobs,
                             std::vector<Mat>& top_blobs,
                             const Option& opt) const
{
    const Mat& bottom_blob = bottom_blobs[0];
    const Mat& cos_cache = bottom_blobs[1];
    const Mat& sin_cache = bottom_blobs[2];

    const int embed_dim = bottom_blob.w;
    const int seqlen = bottom_blob.h;
    const int num_heads = bottom_blob.c;

    Mat& top_blob = top_blobs[0];
    top_blob.create_like(bottom_blob, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < num_heads; q++)
    {
        const Mat head = bottom_blob.channel(q);
        Mat out_head = top_blob.channel(q);

        for (int i = 0; i < seqlen; i++)
        {
            if (interleaved)
            {
                // interleaved layout: adjacent pairs (x0, x1) are rotated,
                // out0 = x0*cos - x1*sin, out1 = x0*sin + x1*cos
                const float* ptr = head.row(i);
                const float* cos_ptr = cos_cache.row(i);
                const float* sin_ptr = sin_cache.row(i);
                float* outptr = out_head.row(i);

                int j = 0;

#if __SSE2__
#if __AVX512F__
                {
                    // flips the sign of the even lanes, turning swap(a)*s
                    // into (-x1*s0, +x0*s0, -x3*s1, +x2*s1, ...)
                    const __m512 signmask512 = _mm512_castsi512_ps(_mm512_set_epi32(
                        0, (int)0x80000000, 0, (int)0x80000000, 0, (int)0x80000000, 0, (int)0x80000000,
                        0, (int)0x80000000, 0, (int)0x80000000, 0, (int)0x80000000, 0, (int)0x80000000));

                    // duplicates each cos/sin value into adjacent lanes: [v0,v0,v1,v1,...]
                    const __m512i dupidx = _mm512_set_epi32(
                        7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0);

                    // 8 pairs (16 floats) per iteration
                    for (; j + 7 < embed_dim / 2; j += 8)
                    {
                        __m512 a = _mm512_loadu_ps(ptr);

                        __m256 c8 = _mm256_loadu_ps(cos_ptr);
                        __m256 s8 = _mm256_loadu_ps(sin_ptr);

                        __m512 csrc = _mm512_castps256_ps512(c8);
                        __m512 ssrc = _mm512_castps256_ps512(s8);

                        __m512 c = _mm512_permutexvar_ps(dupidx, csrc);
                        __m512 s = _mm512_permutexvar_ps(dupidx, ssrc);

                        __m512 ac = _mm512_mul_ps(a, c);

                        // [x1,x0,x3,x2,...]
                        __m512 swap = _mm512_shuffle_ps(a, a, _MM_SHUFFLE(2, 3, 0, 1));
                        __m512 ss = _mm512_mul_ps(swap, s);

                        ss = _mm512_xor_ps(ss, signmask512);

                        __m512 y = _mm512_add_ps(ac, ss);
                        _mm512_storeu_ps(outptr, y);

                        ptr += 16;
                        outptr += 16;
                        cos_ptr += 8;
                        sin_ptr += 8;
                    }
                }
#endif // __AVX512F__

#if __AVX2__
                {
                    const __m256 signmask256 = _mm256_castsi256_ps(_mm256_set_epi32(
                        0, (int)0x80000000, 0, (int)0x80000000, 0, (int)0x80000000, 0, (int)0x80000000));

                    const __m256i dupidx256 = _mm256_set_epi32(3, 3, 2, 2, 1, 1, 0, 0);

                    // 4 pairs (8 floats) per iteration
                    for (; j + 3 < embed_dim / 2; j += 4)
                    {
                        __m256 a = _mm256_loadu_ps(ptr);

                        __m128 c4 = _mm_loadu_ps(cos_ptr);
                        __m128 s4 = _mm_loadu_ps(sin_ptr);

                        __m256 csrc = _mm256_castps128_ps256(c4);
                        __m256 ssrc = _mm256_castps128_ps256(s4);

                        __m256 c = _mm256_permutevar8x32_ps(csrc, dupidx256);
                        __m256 s = _mm256_permutevar8x32_ps(ssrc, dupidx256);

                        __m256 ac = _mm256_mul_ps(a, c);

                        __m256 swap = _mm256_shuffle_ps(a, a, _MM_SHUFFLE(2, 3, 0, 1));
                        __m256 ss = _mm256_mul_ps(swap, s);

                        ss = _mm256_xor_ps(ss, signmask256);

                        __m256 y = _mm256_add_ps(ac, ss);
                        _mm256_storeu_ps(outptr, y);

                        ptr += 8;
                        outptr += 8;
                        cos_ptr += 4;
                        sin_ptr += 4;
                    }
                }
#elif __AVX__
                {
                    const __m256 signmask256 = _mm256_castsi256_ps(_mm256_set_epi32(
                        0, (int)0x80000000, 0, (int)0x80000000, 0, (int)0x80000000, 0, (int)0x80000000));

                    // no permutevar8x32 before AVX2, so duplicate via unpack + insertf128
                    for (; j + 3 < embed_dim / 2; j += 4)
                    {
                        __m256 a = _mm256_loadu_ps(ptr);

                        __m128 c4 = _mm_loadu_ps(cos_ptr);
                        __m128 s4 = _mm_loadu_ps(sin_ptr);

                        __m128 clo = _mm_unpacklo_ps(c4, c4); // [c0,c0,c1,c1]
                        __m128 chi = _mm_unpackhi_ps(c4, c4); // [c2,c2,c3,c3]
                        __m128 slo = _mm_unpacklo_ps(s4, s4); // [s0,s0,s1,s1]
                        __m128 shi = _mm_unpackhi_ps(s4, s4); // [s2,s2,s3,s3]

                        __m256 c = _mm256_castps128_ps256(clo);
                        c = _mm256_insertf128_ps(c, chi, 1);

                        __m256 s = _mm256_castps128_ps256(slo);
                        s = _mm256_insertf128_ps(s, shi, 1);

                        __m256 ac = _mm256_mul_ps(a, c);

                        __m256 swap = _mm256_shuffle_ps(a, a, _MM_SHUFFLE(2, 3, 0, 1));
                        __m256 ss = _mm256_mul_ps(swap, s);

                        ss = _mm256_xor_ps(ss, signmask256);

                        __m256 y = _mm256_add_ps(ac, ss);
                        _mm256_storeu_ps(outptr, y);

                        ptr += 8;
                        outptr += 8;
                        cos_ptr += 4;
                        sin_ptr += 4;
                    }
                }
#endif // __AVX__

                {
                    const __m128 signmask128 = _mm_castsi128_ps(_mm_set_epi32(
                        0, (int)0x80000000, 0, (int)0x80000000));

                    // 2 pairs (4 floats) per iteration
                    for (; j + 1 < embed_dim / 2; j += 2)
                    {
                        __m128 a = _mm_loadu_ps(ptr);

                        __m128 c01 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)cos_ptr));
                        __m128 s01 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)sin_ptr));

                        __m128 c = _mm_unpacklo_ps(c01, c01); // [c0,c0,c1,c1]
                        __m128 s = _mm_unpacklo_ps(s01, s01); // [s0,s0,s1,s1]

                        __m128 ac = _mm_mul_ps(a, c);

                        __m128 swap = _mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 3, 0, 1));
                        __m128 ss = _mm_mul_ps(swap, s);

                        ss = _mm_xor_ps(ss, signmask128);

                        __m128 y = _mm_add_ps(ac, ss);
                        _mm_storeu_ps(outptr, y);

                        ptr += 4;
                        outptr += 4;
                        cos_ptr += 2;
                        sin_ptr += 2;
                    }
                }
#endif // __SSE2__

                // scalar tail
                for (; j < embed_dim / 2; j++)
                {
                    const float x0 = ptr[0];
                    const float x1 = ptr[1];
                    const float cos_val = *cos_ptr++;
                    const float sin_val = *sin_ptr++;

                    outptr[0] = x0 * cos_val - x1 * sin_val;
                    outptr[1] = x0 * sin_val + x1 * cos_val;

                    ptr += 2;
                    outptr += 2;
                }
            }
            else
            {
                // half-split layout: the rotated pair is (x[j], x[j + embed_dim/2]),
                // so cos/sin apply with plain element-wise vector math
                const float* ptr0 = head.row(i);
                const float* ptr1 = ptr0 + embed_dim / 2;
                const float* cos_ptr = cos_cache.row(i);
                const float* sin_ptr = sin_cache.row(i);

                float* outptr0 = out_head.row(i);
                float* outptr1 = outptr0 + embed_dim / 2;

                int j = 0;

#if __SSE2__
#if __AVX512F__
                for (; j + 15 < embed_dim / 2; j += 16)
                {
                    __m512 x0 = _mm512_loadu_ps(ptr0);
                    __m512 x1 = _mm512_loadu_ps(ptr1);
                    __m512 c = _mm512_loadu_ps(cos_ptr);
                    __m512 s = _mm512_loadu_ps(sin_ptr);

                    __m512 y0 = _mm512_sub_ps(_mm512_mul_ps(x0, c), _mm512_mul_ps(x1, s));
                    __m512 y1 = _mm512_add_ps(_mm512_mul_ps(x0, s), _mm512_mul_ps(x1, c));

                    _mm512_storeu_ps(outptr0, y0);
                    _mm512_storeu_ps(outptr1, y1);

                    ptr0 += 16;
                    ptr1 += 16;
                    cos_ptr += 16;
                    sin_ptr += 16;
                    outptr0 += 16;
                    outptr1 += 16;
                }
#elif __AVX__
                for (; j + 7 < embed_dim / 2; j += 8)
                {
                    __m256 x0 = _mm256_loadu_ps(ptr0);
                    __m256 x1 = _mm256_loadu_ps(ptr1);
                    __m256 c = _mm256_loadu_ps(cos_ptr);
                    __m256 s = _mm256_loadu_ps(sin_ptr);

                    __m256 y0 = _mm256_sub_ps(_mm256_mul_ps(x0, c), _mm256_mul_ps(x1, s));
                    __m256 y1 = _mm256_add_ps(_mm256_mul_ps(x0, s), _mm256_mul_ps(x1, c));

                    _mm256_storeu_ps(outptr0, y0);
                    _mm256_storeu_ps(outptr1, y1);

                    ptr0 += 8;
                    ptr1 += 8;
                    cos_ptr += 8;
                    sin_ptr += 8;
                    outptr0 += 8;
                    outptr1 += 8;
                }
#endif // __AVX__

                for (; j + 3 < embed_dim / 2; j += 4)
                {
                    __m128 x0 = _mm_loadu_ps(ptr0);
                    __m128 x1 = _mm_loadu_ps(ptr1);
                    __m128 c = _mm_loadu_ps(cos_ptr);
                    __m128 s = _mm_loadu_ps(sin_ptr);

                    __m128 y0 = _mm_sub_ps(_mm_mul_ps(x0, c), _mm_mul_ps(x1, s));
                    __m128 y1 = _mm_add_ps(_mm_mul_ps(x0, s), _mm_mul_ps(x1, c));

                    _mm_storeu_ps(outptr0, y0);
                    _mm_storeu_ps(outptr1, y1);

                    ptr0 += 4;
                    ptr1 += 4;
                    cos_ptr += 4;
                    sin_ptr += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                }
#endif // __SSE2__

                // scalar tail
                for (; j < embed_dim / 2; j++)
                {
                    const float x0 = *ptr0++;
                    const float x1 = *ptr1++;
                    const float cos_val = *cos_ptr++;
                    const float sin_val = *sin_ptr++;

                    *outptr0++ = x0 * cos_val - x1 * sin_val;
                    *outptr1++ = x0 * sin_val + x1 * cos_val;
                }
            }
        }
    }

    return 0;
}

} // namespace ncnn
```
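All of the interleaved SIMD paths above implement the same complex rotation with one multiply, one swapped multiply, a sign-mask XOR, and an add. As a minimal standalone sketch (not part of the PR), this is what a single pair of lanes evaluates, checked against the scalar reference formula:

```cpp
#include <cstdio>

// One interleaved pair, the way the SIMD lanes compute it:
//   lanes [x0, x1] * [c, c]  +  signmask([x1, x0] * [s, s])
// where the sign mask negates the even lane, giving
//   y0 = x0*c - x1*s
//   y1 = x1*c + x0*s
int main()
{
    const float x0 = 1.f, x1 = 2.f, c = 0.8f, s = 0.6f;

    // SIMD formulation: a*c + xor(swap(a)*s, signmask)
    const float y0 = x0 * c + (-(x1 * s)); // even lane: sign flipped by the mask
    const float y1 = x1 * c + (x0 * s);    // odd lane: sign kept

    // reference rotation (x0 + i*x1) * (c + i*s)
    const float r0 = x0 * c - x1 * s;
    const float r1 = x0 * s + x1 * c;

    printf("simd: %f %f  ref: %f %f\n", y0, y1, r0, r1); // identical pairs
    return 0;
}
```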
rotaryembed_x86.h (new file, +23 lines)

```cpp
// Copyright 2025 pchar.cn
// SPDX-License-Identifier: BSD-3-Clause

#ifndef LAYER_ROTARYEMBED_X86_H
#define LAYER_ROTARYEMBED_X86_H

#include "rotaryembed.h"

namespace ncnn {

class RotaryEmbed_x86 : public RotaryEmbed
{
public:
    RotaryEmbed_x86();

    virtual int forward(const std::vector<Mat>& bottom_blobs,
                        std::vector<Mat>& top_blobs,
                        const Option& opt) const;
};

} // namespace ncnn

#endif // LAYER_ROTARYEMBED_X86_H
```
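For context, a hedged sketch of how the layer could be exercised directly through ncnn's generic layer API. The type string "RotaryEmbed", the `rotate_qk` helper name, and the param id for `interleaved` are assumptions for illustration; the three-input convention and blob shapes follow the `forward()` above (activations `(embed_dim, seqlen, num_heads)`, cos/sin caches consuming `embed_dim/2` values per row):

```cpp
#include "layer.h" // ncnn::Layer, ncnn::ParamDict, ncnn::Option, ncnn::Mat

// Illustrative only: assumes the layer is registered as "RotaryEmbed"
// and takes (x, cos_cache, sin_cache) as its three bottom blobs.
int rotate_qk(const ncnn::Mat& x, const ncnn::Mat& cos_cache,
              const ncnn::Mat& sin_cache, ncnn::Mat& y)
{
    ncnn::Layer* op = ncnn::create_layer("RotaryEmbed"); // hypothetical type name
    if (!op)
        return -1;

    ncnn::ParamDict pd;
    // pd.set(0, 1); // e.g. interleaved=1, if the base layer exposes it (assumption)
    op->load_param(pd);

    ncnn::Option opt;
    op->create_pipeline(opt);

    std::vector<ncnn::Mat> bottoms(3);
    bottoms[0] = x;         // (embed_dim, seqlen, num_heads)
    bottoms[1] = cos_cache; // (embed_dim/2, seqlen)
    bottoms[2] = sin_cache; // (embed_dim/2, seqlen)

    std::vector<ncnn::Mat> tops(1);
    int ret = op->forward(bottoms, tops, opt);

    y = tops[0];
    op->destroy_pipeline(opt);
    delete op;
    return ret;
}
```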