diff --git a/main.c b/main.c index 615081f..ad725f6 100644 --- a/main.c +++ b/main.c @@ -22,6 +22,9 @@ limitations under the License. #include #include #include +#if USE_SIMD && (__SSE4__ || __AVX2__) +#include +#endif #ifndef DO_INST #define DO_INST 0 @@ -43,6 +46,10 @@ limitations under the License. static int_fast32_t runlens[4096] = {0}; static int_fast32_t skips[128] = {0}; static int_fast32_t remainders[64] = {0}; + static int_fast32_t non_asciis32 = 0; + static int_fast32_t non_asciis16 = 0; + static int_fast32_t non_asciis8 = 0; + static int_fast32_t non_asciis4 = 0; #endif #ifndef NDEBUG @@ -58,6 +65,32 @@ limitations under the License. # define likely(x) __builtin_expect((x), 1) # define unlikely(x) __builtin_expect((x), 0) +typedef union mask { + uint8_t bytes[32]; + uint8_t u8; + uint16_t u16; + uint32_t u32; + uint64_t u64; + __uint128_t u128; +#if USE_SIMD && __SSE4__ + __m128i m128i; +#endif +#if USE_SIMD && __AVX2__ + __m256i m128i; +#endif +} mask; + +#if USE_SIMD +const static mask non_ascii = { + .bytes = { + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + } +}; +#endif + #if USE_HEX_TABLE static const bool lhex[256] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -106,6 +139,55 @@ static void print_hit(const unsigned char *buf) { printf("%.40s\n", buf); } +#if USE_NON_ASCII4 +// perform (0x80 & C) on next 4 bytes at once +// 32-bit operation +static bool non_ascii4(const unsigned char *b) { + uint32_t *bs = (uint32_t*)b; + const uint32_t non_ascii = 0xffffffff; + return (*bs & non_ascii) == non_ascii; +} +#endif + +#if USE_SIMD +// perform (0x80 & C) on next 8 bytes at once +// 64-bit operation +static bool non_ascii8(const unsigned char *b) { + uint64_t *bs = (uint64_t*)b; + return (*bs & non_ascii.u64) == non_ascii.u64; +} +#endif + +#if USE_SIMD && __SSE4__ +// perform (0x80 & C) on next 16 bytes at once +// 128-bit SSE operation +static bool non_ascii16(const unsigned char *b) { + // load 16 bytes as packed 16x8 byte value + __m128i bs = _mm_loadu_si128((const __m128i *)b); + // get the high bit (0x80) from each byte and put it in + // to a the low 16 bites of a 32-bit int as a mask + int high_bits = _mm_movemask_epi8(bs) & 0xffff; + // any high bits mean a non-ascii + // we only care if they are all high. + return high_bits == 0xffff; +} +#endif + +#if MAX_STEP >= 32 && USE_SIMD && __AVX2__ +// perform (0x80 & C) on next 16 bytes at once 256-bit AVX operation +static bool non_ascii32(const unsigned char *b) { + // load 32 bytes + __m256i bs = _mm256_loadu_si256((const __m256i *)b); + // vpmovmskb is an awesome instruction. It gathers the MSBs from the input + // as packed bytes and returns it as a mask. That's equivalent to &0x80 on + // 32 bytes at once! + int high_bits = _mm256_movemask_epi8(bs); + // any byte with the high bit set cannot be a hex because + // it is outside of the main ascii range; + return high_bits == 0xffffffff; +} +#endif + // At the start of this function, buf is pointing at a non-hex character and the // goal is to find the next hex character. static const unsigned char * scan_skip(const unsigned char *buf, const unsigned char *end) { @@ -114,6 +196,67 @@ static const unsigned char * scan_skip(const unsigned char *buf, const unsigned #ifndef NDEBUG const unsigned char * io = buf; #endif + +#if USE_SIMD + if (unlikely(buf + 41 >= end)) { + return buf; + } + +#define MAX_STEP 8 + + while (skip > 0 && buf + skip + MAX_STEP < end) { + // Runs of 32+ and 16+ non-ascii bytes are not common + // enough to justify the overhead of using these +#if MAX_STEP >= 32 + if (non_ascii32(buf+skip)) { + buf += skip + 32; + skip = 40; + INST(non_asciis32++); + continue; + } +#endif +#if MAX_STEP >= 16 + if (non_ascii16(buf+skip)) { + buf += skip + 16; + skip = 40; + INST(non_asciis16++); + continue; + } +#endif +#if MAX_STEP >= 8 + if (non_ascii8(buf+skip)) { + buf += skip + 8; + skip = 40; + INST(non_asciis8++); + continue; + } +#endif + // this works but hits so few cases that it doesn't give any benefit +#if USE_NON_ASCII4 + if (non_ascii4(buf+skip)) { + buf += skip + 4; + skip = 40; + INST(non_asciis4++); + continue; + } +#endif + if (!is_lower_hex(buf+skip)) { + buf += skip; + skip = 40; + continue; + } + skip /= 2; + } + + while (skip > 0 && buf + skip < end) { + if (!is_lower_hex(buf+skip)) { + buf += skip; + skip = 40; + continue; + } + skip /= 2; + } +#else do { while (buf + skip < end && !is_lower_hex(buf+skip)) { buf += skip; @@ -121,6 +264,7 @@ static const unsigned char * scan_skip(const unsigned char *buf, const unsigned } skip /= 2; } while (skip > 1 && buf + skip < end); +#endif assert(io <= buf); assert(buf < end); return buf+1; @@ -183,6 +327,11 @@ static const unsigned char * scan_hit_long(const unsigned char *buf, const unsig // at 50 we know that the current run ends before then and that any runs // between here and there are too short to care about. + // a sha256 would have ended at buf+24 so buf+25 wouldn't be a hex + if (!is_lower_hex(buf+25) ) { + return scan_skip(buf+25, end); + } + assert(buf +30 < end); if (!is_lower_hex(buf+30)) { @@ -207,6 +356,66 @@ static const unsigned char * scan_hit_long(const unsigned char *buf, const unsig return scan_hit_short(start, end); } +#if USE_SIMD && __AVX2__ +static int is_hex64(const unsigned char *start) { + uint64_t mask, res; + int pos; + + const __m256i b0 = _mm256_loadu_si256((void*)start); + const __m256i b1 = _mm256_loadu_si256((void*)(start+32)); + + const __m256i rr0 = _mm256_set1_epi8('0'-1); + const __m256i rr1 = _mm256_set1_epi8('9'); + const __m256i rr2 = _mm256_set1_epi8('a'-1); + const __m256i rr3 = _mm256_set1_epi8('f'); + + // x > 0x29 + __m256i gz0 = _mm256_cmpgt_epi8(b0, rr0); + __m256i gz1 = _mm256_cmpgt_epi8(b1, rr0); + // .. &! (>0x39) + __m256i le9_0 = _mm256_andnot_si256(_mm256_cmpgt_epi8(b0, rr1), gz0); + __m256i le9_1 = _mm256_andnot_si256(_mm256_cmpgt_epi8(b1, rr1), gz1); + // x > 0x60 + __m256i ga0 = _mm256_cmpgt_epi8(b0, rr2); + __m256i ga1 = _mm256_cmpgt_epi8(b1, rr2); + // .. &!(>0x66) + __m256i lef0 = _mm256_andnot_si256(_mm256_cmpgt_epi8(b0, rr3), ga0); + __m256i lef1 = _mm256_andnot_si256(_mm256_cmpgt_epi8(b1, rr3), ga1); + + /* Generate bit masks */ + unsigned int numeric0 = _mm256_movemask_epi8(le9_0); + unsigned int numeric1 = _mm256_movemask_epi8(le9_1); + unsigned int alpha1 = _mm256_movemask_epi8(lef1); + unsigned int alpha0 = _mm256_movemask_epi8(lef0); + + // x > 0x29 && !(x > 0x39) || x > 0x60 && !(x > 0x66) + uint64_t res0 = numeric0 | alpha0; + uint64_t res1 = numeric1 | alpha1; + // [0-31] | [32-63] + res = res0 | (res1 << 32); + + // yay little endian! :-/ + // 64.............0 + // 0x00000080ffffffff + // 0x ffffffff 0-32 + // 0x ff 33-40 + // 0x 1 41 + // 0x000001ffffffffff = mask + // 0x???????????????? & res + // 0x000000ffffffffff = hit! + + // bool hit = (res & 0x000001ffffffffff) == 0x000000ffffffffff; + + mask = 1; + pos = 0; + while (res & mask) { + pos++; + mask <<= 1; + } + return pos; +} +#endif + // We are at the first hex character. The goal is to determine as efficiently as // possible if this is a 40 hex character run terminated by a non-hex, something // shorter, or something longer. @@ -220,6 +429,23 @@ static const unsigned char * scan_hit_short(const unsigned char *buf, const unsi return buf; } + // Use AVX2 instructions to check 32 bytes + 32 bytes +#if USE_SIMD && __AVX2__ + if (likely(buf + 64 < end)) { + int len = is_hex64(buf); + assert(len > 0); + assert(len <= 64); + if (len == 40) { + print_hit(buf); + return scan_skip(buf+len, end); + } + if (len < 64) { + return scan_skip(buf+len, end); + } + return scan_hit_long(buf+40, end); + } +#endif + // We know offset 0 is a hex because that's why we're here. // We know offset 40 needs to be a non-hex otherwise we're in a 41+ run. // We know 1-39 all need to be hex characters. @@ -372,6 +598,10 @@ int main(int argc, const char *argv[]) { for (int i = 0; i < arr_len(runlens); i++) if (runlens[i]) dprintf(2, " [%4d] %10d%s\n", i, runlens[i], i==40 ? " *" : ""); + dprintf(2, "non-ascii32: %10d\n", non_asciis32); + dprintf(2, "non-ascii16: %10d\n", non_asciis16); + dprintf(2, "non-ascii8: %10d\n", non_asciis8); + dprintf(2, "non-ascii4: %10d\n", non_asciis4); #endif return nread;