Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
c: use AVX and SSE to bulk scan
Browse files Browse the repository at this point in the history
Use fancy instructions to scan 16, 32, and 64 byte ranges instead of
only inspecting a single byte at a time.

While this was a lot of fun to do, it turns out to not be as efficient
as being clever about avoiding comparisons whenever possible. That is,
reading 1/10 bytes is better than reading 10 at once even if they are
the same number of instructions. This is because there is overhead in
loading the 128 and 256 bit registers and that overhead reduces the
gains enough to give us a net speed that is slightly slower.
rmg committed Aug 6, 2019
1 parent 84ddea5 commit 873323b
Showing 1 changed file with 230 additions and 0 deletions.
230 changes: 230 additions & 0 deletions main.c
Original file line number Diff line number Diff line change
@@ -22,6 +22,9 @@ limitations under the License.
#include <stdbool.h>
#include <stdint.h>
#include <assert.h>
#if USE_SIMD && (__SSE4__ || __AVX2__)
#include <immintrin.h>
#endif

#ifndef DO_INST
#define DO_INST 0
@@ -43,6 +46,10 @@ limitations under the License.
static int_fast32_t runlens[4096] = {0};
static int_fast32_t skips[128] = {0};
static int_fast32_t remainders[64] = {0};
static int_fast32_t non_asciis32 = 0;
static int_fast32_t non_asciis16 = 0;
static int_fast32_t non_asciis8 = 0;
static int_fast32_t non_asciis4 = 0;
#endif

#ifndef NDEBUG
@@ -58,6 +65,32 @@ limitations under the License.
# define likely(x) __builtin_expect((x), 1)
# define unlikely(x) __builtin_expect((x), 0)

typedef union mask {
uint8_t bytes[32];
uint8_t u8;
uint16_t u16;
uint32_t u32;
uint64_t u64;
__uint128_t u128;
#if USE_SIMD && __SSE4__
__m128i m128i;
#endif
#if USE_SIMD && __AVX2__
__m256i m128i;
#endif
} mask;

#if USE_SIMD
const static mask non_ascii = {
.bytes = {
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
}
};
#endif

#if USE_HEX_TABLE
static const bool lhex[256] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
@@ -106,6 +139,55 @@ static void print_hit(const unsigned char *buf) {
printf("%.40s\n", buf);
}

#if USE_NON_ASCII4
// perform (0x80 & C) on next 4 bytes at once
// 32-bit operation
static bool non_ascii4(const unsigned char *b) {
uint32_t *bs = (uint32_t*)b;
const uint32_t non_ascii = 0xffffffff;
return (*bs & non_ascii) == non_ascii;
}
#endif

#if USE_SIMD
// perform (0x80 & C) on next 8 bytes at once
// 64-bit operation
static bool non_ascii8(const unsigned char *b) {
uint64_t *bs = (uint64_t*)b;
return (*bs & non_ascii.u64) == non_ascii.u64;
}
#endif

#if USE_SIMD && __SSE4__
// perform (0x80 & C) on next 16 bytes at once
// 128-bit SSE operation
static bool non_ascii16(const unsigned char *b) {
// load 16 bytes as packed 16x8 byte value
__m128i bs = _mm_loadu_si128((const __m128i *)b);
// get the high bit (0x80) from each byte and put it in
// to a the low 16 bites of a 32-bit int as a mask
int high_bits = _mm_movemask_epi8(bs) & 0xffff;
// any high bits mean a non-ascii
// we only care if they are all high.
return high_bits == 0xffff;
}
#endif

#if MAX_STEP >= 32 && USE_SIMD && __AVX2__
// perform (0x80 & C) on next 16 bytes at once 256-bit AVX operation
static bool non_ascii32(const unsigned char *b) {
// load 32 bytes
__m256i bs = _mm256_loadu_si256((const __m256i *)b);
// vpmovmskb is an awesome instruction. It gathers the MSBs from the input
// as packed bytes and returns it as a mask. That's equivalent to &0x80 on
// 32 bytes at once!
int high_bits = _mm256_movemask_epi8(bs);
// any byte with the high bit set cannot be a hex because
// it is outside of the main ascii range;
return high_bits == 0xffffffff;
}
#endif

// At the start of this function, buf is pointing at a non-hex character and the
// goal is to find the next hex character.
static const unsigned char * scan_skip(const unsigned char *buf, const unsigned char *end) {
@@ -114,13 +196,75 @@ static const unsigned char * scan_skip(const unsigned char *buf, const unsigned
#ifndef NDEBUG
const unsigned char * io = buf;
#endif

#if USE_SIMD
if (unlikely(buf + 41 >= end)) {
return buf;
}

#define MAX_STEP 8

while (skip > 0 && buf + skip + MAX_STEP < end) {
// Runs of 32+ and 16+ non-ascii bytes are not common
// enough to justify the overhead of using these
#if MAX_STEP >= 32
if (non_ascii32(buf+skip)) {
buf += skip + 32;
skip = 40;
INST(non_asciis32++);
continue;
}
#endif
#if MAX_STEP >= 16
if (non_ascii16(buf+skip)) {
buf += skip + 16;
skip = 40;
INST(non_asciis16++);
continue;
}
#endif
#if MAX_STEP >= 8
if (non_ascii8(buf+skip)) {
buf += skip + 8;
skip = 40;
INST(non_asciis8++);
continue;
}
#endif
// this works but hits so few cases that it doesn't give any benefit
#if USE_NON_ASCII4
if (non_ascii4(buf+skip)) {
buf += skip + 4;
skip = 40;
INST(non_asciis4++);
continue;
}
#endif
if (!is_lower_hex(buf+skip)) {
buf += skip;
skip = 40;
continue;
}
skip /= 2;
}

while (skip > 0 && buf + skip < end) {
if (!is_lower_hex(buf+skip)) {
buf += skip;
skip = 40;
continue;
}
skip /= 2;
}
#else
do {
while (buf + skip < end && !is_lower_hex(buf+skip)) {
buf += skip;
skip = 40;
}
skip /= 2;
} while (skip > 1 && buf + skip < end);
#endif
assert(io <= buf);
assert(buf < end);
return buf+1;
@@ -183,6 +327,11 @@ static const unsigned char * scan_hit_long(const unsigned char *buf, const unsig
// at 50 we know that the current run ends before then and that any runs
// between here and there are too short to care about.

// a sha256 would have ended at buf+24 so buf+25 wouldn't be a hex
if (!is_lower_hex(buf+25) ) {
return scan_skip(buf+25, end);
}

assert(buf +30 < end);

if (!is_lower_hex(buf+30)) {
@@ -207,6 +356,66 @@ static const unsigned char * scan_hit_long(const unsigned char *buf, const unsig
return scan_hit_short(start, end);
}

#if USE_SIMD && __AVX2__
static int is_hex64(const unsigned char *start) {
uint64_t mask, res;
int pos;

const __m256i b0 = _mm256_loadu_si256((void*)start);
const __m256i b1 = _mm256_loadu_si256((void*)(start+32));

const __m256i rr0 = _mm256_set1_epi8('0'-1);
const __m256i rr1 = _mm256_set1_epi8('9');
const __m256i rr2 = _mm256_set1_epi8('a'-1);
const __m256i rr3 = _mm256_set1_epi8('f');

// x > 0x29
__m256i gz0 = _mm256_cmpgt_epi8(b0, rr0);
__m256i gz1 = _mm256_cmpgt_epi8(b1, rr0);
// .. &! (>0x39)
__m256i le9_0 = _mm256_andnot_si256(_mm256_cmpgt_epi8(b0, rr1), gz0);
__m256i le9_1 = _mm256_andnot_si256(_mm256_cmpgt_epi8(b1, rr1), gz1);
// x > 0x60
__m256i ga0 = _mm256_cmpgt_epi8(b0, rr2);
__m256i ga1 = _mm256_cmpgt_epi8(b1, rr2);
// .. &!(>0x66)
__m256i lef0 = _mm256_andnot_si256(_mm256_cmpgt_epi8(b0, rr3), ga0);
__m256i lef1 = _mm256_andnot_si256(_mm256_cmpgt_epi8(b1, rr3), ga1);

/* Generate bit masks */
unsigned int numeric0 = _mm256_movemask_epi8(le9_0);
unsigned int numeric1 = _mm256_movemask_epi8(le9_1);
unsigned int alpha1 = _mm256_movemask_epi8(lef1);
unsigned int alpha0 = _mm256_movemask_epi8(lef0);

// x > 0x29 && !(x > 0x39) || x > 0x60 && !(x > 0x66)
uint64_t res0 = numeric0 | alpha0;
uint64_t res1 = numeric1 | alpha1;
// [0-31] | [32-63]
res = res0 | (res1 << 32);

// yay little endian! :-/
// 64.............0
// 0x00000080ffffffff
// 0x ffffffff 0-32
// 0x ff 33-40
// 0x 1 41
// 0x000001ffffffffff = mask
// 0x???????????????? & res
// 0x000000ffffffffff = hit!

// bool hit = (res & 0x000001ffffffffff) == 0x000000ffffffffff;

mask = 1;
pos = 0;
while (res & mask) {
pos++;
mask <<= 1;
}
return pos;
}
#endif

// We are at the first hex character. The goal is to determine as efficiently as
// possible if this is a 40 hex character run terminated by a non-hex, something
// shorter, or something longer.
@@ -220,6 +429,23 @@ static const unsigned char * scan_hit_short(const unsigned char *buf, const unsi
return buf;
}

// Use AVX2 instructions to check 32 bytes + 32 bytes
#if USE_SIMD && __AVX2__
if (likely(buf + 64 < end)) {
int len = is_hex64(buf);
assert(len > 0);
assert(len <= 64);
if (len == 40) {
print_hit(buf);
return scan_skip(buf+len, end);
}
if (len < 64) {
return scan_skip(buf+len, end);
}
return scan_hit_long(buf+40, end);
}
#endif

// We know offset 0 is a hex because that's why we're here.
// We know offset 40 needs to be a non-hex otherwise we're in a 41+ run.
// We know 1-39 all need to be hex characters.
@@ -372,6 +598,10 @@ int main(int argc, const char *argv[]) {
for (int i = 0; i < arr_len(runlens); i++)
if (runlens[i])
dprintf(2, " [%4d] %10d%s\n", i, runlens[i], i==40 ? " *" : "");
dprintf(2, "non-ascii32: %10d\n", non_asciis32);
dprintf(2, "non-ascii16: %10d\n", non_asciis16);
dprintf(2, "non-ascii8: %10d\n", non_asciis8);
dprintf(2, "non-ascii4: %10d\n", non_asciis4);
#endif

return nread;

0 comments on commit 873323b

Please sign in to comment.