Skip to content

Commit 201780f

Browse files
committed
c: use AVX and SSE to bulk scan
Use fancy instructions to scan 16, 32, and 64 byte ranges instead of only inspecting a single byte at a time. While this was a lot of fun to do, it turns out to not be as efficient as being clever about avoiding comparisons whenever possible. That is, reading 1/10 bytes is better than reading 10 at once even if they are the same number of instructions. This is because there is overhead in loading the 128 and 256 bit registers and that overhead reduces the gains enough to give us a net speed that is slightly slower.
1 parent b37c491 commit 201780f

File tree

1 file changed

+230
-0
lines changed

1 file changed

+230
-0
lines changed

main.c

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
#include <stdbool.h>
77
#include <stdint.h>
88
#include <assert.h>
9+
#if USE_SIMD && (__SSE4__ || __AVX2__)
10+
#include <immintrin.h>
11+
#endif
912

1013
#ifndef DO_INST
1114
#define DO_INST 0
@@ -27,6 +30,10 @@
2730
static int_fast32_t runlens[4096] = {0};
2831
static int_fast32_t skips[128] = {0};
2932
static int_fast32_t remainders[64] = {0};
33+
static int_fast32_t non_asciis32 = 0;
34+
static int_fast32_t non_asciis16 = 0;
35+
static int_fast32_t non_asciis8 = 0;
36+
static int_fast32_t non_asciis4 = 0;
3037
#endif
3138

3239
#ifndef NDEBUG
@@ -42,6 +49,32 @@
4249
# define likely(x) __builtin_expect((x), 1)
4350
# define unlikely(x) __builtin_expect((x), 0)
4451

52+
typedef union mask {
53+
uint8_t bytes[32];
54+
uint8_t u8;
55+
uint16_t u16;
56+
uint32_t u32;
57+
uint64_t u64;
58+
__uint128_t u128;
59+
#if USE_SIMD && __SSE4__
60+
__m128i m128i;
61+
#endif
62+
#if USE_SIMD && __AVX2__
63+
__m256i m128i;
64+
#endif
65+
} mask;
66+
67+
#if USE_SIMD
68+
const static mask non_ascii = {
69+
.bytes = {
70+
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
71+
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
72+
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
73+
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
74+
}
75+
};
76+
#endif
77+
4578
#if USE_HEX_TABLE
4679
static const bool lhex[256] = {
4780
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
@@ -89,6 +122,55 @@ static void print_hit(const unsigned char *buf) {
89122
printf("%.40s\n", buf);
90123
}
91124

125+
#if USE_NON_ASCII4
126+
// perform (0x80 & C) on next 4 bytes at once
127+
// 32-bit operation
128+
static bool non_ascii4(const unsigned char *b) {
129+
uint32_t *bs = (uint32_t*)b;
130+
const uint32_t non_ascii = 0xffffffff;
131+
return (*bs & non_ascii) == non_ascii;
132+
}
133+
#endif
134+
135+
#if USE_SIMD
136+
// perform (0x80 & C) on next 8 bytes at once
137+
// 64-bit operation
138+
static bool non_ascii8(const unsigned char *b) {
139+
uint64_t *bs = (uint64_t*)b;
140+
return (*bs & non_ascii.u64) == non_ascii.u64;
141+
}
142+
#endif
143+
144+
#if USE_SIMD && __SSE4__
145+
// perform (0x80 & C) on next 16 bytes at once
146+
// 128-bit SSE operation
147+
static bool non_ascii16(const unsigned char *b) {
148+
// load 16 bytes as packed 16x8 byte value
149+
__m128i bs = _mm_loadu_si128((const __m128i *)b);
150+
// get the high bit (0x80) from each byte and put it in
151+
// to a the low 16 bites of a 32-bit int as a mask
152+
int high_bits = _mm_movemask_epi8(bs) & 0xffff;
153+
// any high bits mean a non-ascii
154+
// we only care if they are all high.
155+
return high_bits == 0xffff;
156+
}
157+
#endif
158+
159+
#if MAX_STEP >= 32 && USE_SIMD && __AVX2__
160+
// perform (0x80 & C) on next 16 bytes at once 256-bit AVX operation
161+
static bool non_ascii32(const unsigned char *b) {
162+
// load 32 bytes
163+
__m256i bs = _mm256_loadu_si256((const __m256i *)b);
164+
// vpmovmskb is an awesome instruction. It gathers the MSBs from the input
165+
// as packed bytes and returns it as a mask. That's equivalent to &0x80 on
166+
// 32 bytes at once!
167+
int high_bits = _mm256_movemask_epi8(bs);
168+
// any byte with the high bit set cannot be a hex because
169+
// it is outside of the main ascii range;
170+
return high_bits == 0xffffffff;
171+
}
172+
#endif
173+
92174
// At the start of this function, buf is pointing at a non-hex character and the
93175
// goal is to find the next hex character.
94176
static const unsigned char * scan_skip(const unsigned char *buf, const unsigned char *end) {
@@ -97,13 +179,75 @@ static const unsigned char * scan_skip(const unsigned char *buf, const unsigned
97179
#ifndef NDEBUG
98180
const unsigned char * io = buf;
99181
#endif
182+
183+
#if USE_SIMD
184+
if (unlikely(buf + 41 >= end)) {
185+
return buf;
186+
}
187+
188+
#define MAX_STEP 8
189+
190+
while (skip > 0 && buf + skip + MAX_STEP < end) {
191+
// Runs of 32+ and 16+ non-ascii bytes are not common
192+
// enough to justify the overhead of using these
193+
#if MAX_STEP >= 32
194+
if (non_ascii32(buf+skip)) {
195+
buf += skip + 32;
196+
skip = 40;
197+
INST(non_asciis32++);
198+
continue;
199+
}
200+
#endif
201+
#if MAX_STEP >= 16
202+
if (non_ascii16(buf+skip)) {
203+
buf += skip + 16;
204+
skip = 40;
205+
INST(non_asciis16++);
206+
continue;
207+
}
208+
#endif
209+
#if MAX_STEP >= 8
210+
if (non_ascii8(buf+skip)) {
211+
buf += skip + 8;
212+
skip = 40;
213+
INST(non_asciis8++);
214+
continue;
215+
}
216+
#endif
217+
// this works but hits so few cases that it doesn't give any benefit
218+
#if USE_NON_ASCII4
219+
if (non_ascii4(buf+skip)) {
220+
buf += skip + 4;
221+
skip = 40;
222+
INST(non_asciis4++);
223+
continue;
224+
}
225+
#endif
226+
if (!is_lower_hex(buf+skip)) {
227+
buf += skip;
228+
skip = 40;
229+
continue;
230+
}
231+
skip /= 2;
232+
}
233+
234+
while (skip > 0 && buf + skip < end) {
235+
if (!is_lower_hex(buf+skip)) {
236+
buf += skip;
237+
skip = 40;
238+
continue;
239+
}
240+
skip /= 2;
241+
}
242+
#else
100243
do {
101244
while (buf + skip < end && !is_lower_hex(buf+skip)) {
102245
buf += skip;
103246
skip = 40;
104247
}
105248
skip /= 2;
106249
} while (skip > 1 && buf + skip < end);
250+
#endif
107251
assert(io <= buf);
108252
assert(buf < end);
109253
return buf+1;
@@ -166,6 +310,11 @@ static const unsigned char * scan_hit_long(const unsigned char *buf, const unsig
166310
// at 50 we know that the current run ends before then and that any runs
167311
// between here and there are too short to care about.
168312

313+
// a sha256 would have ended at buf+24 so buf+25 wouldn't be a hex
314+
if (!is_lower_hex(buf+25) ) {
315+
return scan_skip(buf+25, end);
316+
}
317+
169318
assert(buf +30 < end);
170319

171320
if (!is_lower_hex(buf+30)) {
@@ -190,6 +339,66 @@ static const unsigned char * scan_hit_long(const unsigned char *buf, const unsig
190339
return scan_hit_short(start, end);
191340
}
192341

342+
#if USE_SIMD && __AVX2__
343+
static int is_hex64(const unsigned char *start) {
344+
uint64_t mask, res;
345+
int pos;
346+
347+
const __m256i b0 = _mm256_loadu_si256((void*)start);
348+
const __m256i b1 = _mm256_loadu_si256((void*)(start+32));
349+
350+
const __m256i rr0 = _mm256_set1_epi8('0'-1);
351+
const __m256i rr1 = _mm256_set1_epi8('9');
352+
const __m256i rr2 = _mm256_set1_epi8('a'-1);
353+
const __m256i rr3 = _mm256_set1_epi8('f');
354+
355+
// x > 0x29
356+
__m256i gz0 = _mm256_cmpgt_epi8(b0, rr0);
357+
__m256i gz1 = _mm256_cmpgt_epi8(b1, rr0);
358+
// .. &! (>0x39)
359+
__m256i le9_0 = _mm256_andnot_si256(_mm256_cmpgt_epi8(b0, rr1), gz0);
360+
__m256i le9_1 = _mm256_andnot_si256(_mm256_cmpgt_epi8(b1, rr1), gz1);
361+
// x > 0x60
362+
__m256i ga0 = _mm256_cmpgt_epi8(b0, rr2);
363+
__m256i ga1 = _mm256_cmpgt_epi8(b1, rr2);
364+
// .. &!(>0x66)
365+
__m256i lef0 = _mm256_andnot_si256(_mm256_cmpgt_epi8(b0, rr3), ga0);
366+
__m256i lef1 = _mm256_andnot_si256(_mm256_cmpgt_epi8(b1, rr3), ga1);
367+
368+
/* Generate bit masks */
369+
unsigned int numeric0 = _mm256_movemask_epi8(le9_0);
370+
unsigned int numeric1 = _mm256_movemask_epi8(le9_1);
371+
unsigned int alpha1 = _mm256_movemask_epi8(lef1);
372+
unsigned int alpha0 = _mm256_movemask_epi8(lef0);
373+
374+
// x > 0x29 && !(x > 0x39) || x > 0x60 && !(x > 0x66)
375+
uint64_t res0 = numeric0 | alpha0;
376+
uint64_t res1 = numeric1 | alpha1;
377+
// [0-31] | [32-63]
378+
res = res0 | (res1 << 32);
379+
380+
// yay little endian! :-/
381+
// 64.............0
382+
// 0x00000080ffffffff
383+
// 0x ffffffff 0-32
384+
// 0x ff 33-40
385+
// 0x 1 41
386+
// 0x000001ffffffffff = mask
387+
// 0x???????????????? & res
388+
// 0x000000ffffffffff = hit!
389+
390+
// bool hit = (res & 0x000001ffffffffff) == 0x000000ffffffffff;
391+
392+
mask = 1;
393+
pos = 0;
394+
while (res & mask) {
395+
pos++;
396+
mask <<= 1;
397+
}
398+
return pos;
399+
}
400+
#endif
401+
193402
// We are at the first hex character. The goal is to determine as efficiently as
194403
// possible if this is a 40 hex character run terminated by a non-hex, something
195404
// shorter, or something longer.
@@ -205,6 +414,23 @@ static const unsigned char * scan_hit_short(const unsigned char *buf, const unsi
205414
return buf;
206415
}
207416

417+
// Use AVX2 instructions to check 32 bytes + 32 bytes
418+
#if USE_SIMD && __AVX2__
419+
if (likely(buf + 64 < end)) {
420+
int len = is_hex64(buf);
421+
assert(len > 0);
422+
assert(len <= 64);
423+
if (len == 40) {
424+
print_hit(buf);
425+
return scan_skip(buf+len, end);
426+
}
427+
if (len < 64) {
428+
return scan_skip(buf+len, end);
429+
}
430+
return scan_hit_long(buf+40, end);
431+
}
432+
#endif
433+
208434
// Unrolled checking the next 40 bytes (must be terminated). We know the
209435
// most frequent lengths of short hex strings, so we for those first by
210436
// looking at N+1
@@ -353,6 +579,10 @@ int main(int argc, const char *argv[]) {
353579
for (int i = 0; i < arr_len(runlens); i++)
354580
if (runlens[i])
355581
dprintf(2, " [%4d] %10d%s\n", i, runlens[i], i==40 ? " *" : "");
582+
dprintf(2, "non-ascii32: %10d\n", non_asciis32);
583+
dprintf(2, "non-ascii16: %10d\n", non_asciis16);
584+
dprintf(2, "non-ascii8: %10d\n", non_asciis8);
585+
dprintf(2, "non-ascii4: %10d\n", non_asciis4);
356586
#endif
357587

358588
return nread;

0 commit comments

Comments
 (0)