diff --git a/ngu/hash.c b/ngu/hash.c index 829bd75..664620f 100644 --- a/ngu/hash.c +++ b/ngu/hash.c @@ -211,25 +211,14 @@ STATIC mp_obj_t hm_hash160(mp_obj_t arg) { STATIC MP_DEFINE_CONST_FUN_OBJ_1(hm_hash160_obj, hm_hash160); -// Pbkdf2 using sha512 hmac, for use in BIP39=>BIP32 seed -STATIC mp_obj_t pbkdf2_sha512(mp_obj_t pass_in, mp_obj_t salt_in, mp_obj_t rounds_in) { +STATIC mp_obj_t pbkdf2_hmac(uint32_t md_len, mp_obj_t pass_in, mp_obj_t salt_in, mp_obj_t rounds_in) { + mp_buffer_info_t pass, salt; mp_get_buffer_raise(pass_in, &pass, MP_BUFFER_READ); mp_get_buffer_raise(salt_in, &salt, MP_BUFFER_READ); - const uint32_t H_SIZE = 64; // because sha512 - -#if MICROPY_SSL_MBEDTLS - const mbedtls_md_info_t *md_algo = mbedtls_md_info_from_type(MBEDTLS_MD_SHA512); -#endif - - vstr_t key_out; - vstr_init_len(&key_out, H_SIZE); - uint32_t key_len = H_SIZE; - uint8_t *key = (uint8_t *)key_out.buf; - - // Based on https://github.com/openbsd/src/blob/master/lib/libutil/pkcs5_pbkdf2.c uint32_t rounds = mp_obj_get_int_truncated(rounds_in); + if(rounds < 1) { mp_raise_ValueError(MP_ERROR_TEXT("rounds")); } @@ -237,11 +226,44 @@ STATIC mp_obj_t pbkdf2_sha512(mp_obj_t pass_in, mp_obj_t salt_in, mp_obj_t round mp_raise_ValueError(MP_ERROR_TEXT("salt")); } - uint8_t d1[H_SIZE], d2[H_SIZE], obuf[H_SIZE]; + vstr_t key_out; + vstr_init_len(&key_out, md_len); + uint8_t *key = (uint8_t *)key_out.buf; + +#if MICROPY_SSL_MBEDTLS + const mbedtls_md_info_t *algo; + switch(md_len) { + case 64: + algo = mbedtls_md_info_from_type(MBEDTLS_MD_SHA512); + break; + case 32: + algo = mbedtls_md_info_from_type(MBEDTLS_MD_SHA256); + break; + default: + mp_raise_ValueError(NULL); + } +#else + const cf_chash *algo = NULL; + switch(md_len) { + case 64: + algo = &cf_sha512; + break; + case 32: + algo = &cf_sha256; + break; + default: + mp_raise_ValueError(NULL); + } +#endif + + // Based on https://github.com/openbsd/src/blob/master/lib/libutil/pkcs5_pbkdf2.c + + uint8_t d1[md_len], d2[md_len], obuf[md_len]; uint8_t asalt[salt.len + 4]; memcpy(asalt, salt.buf, salt.len); + uint32_t key_len = md_len; for(uint32_t count=1; key_len > 0; count++) { asalt[salt.len + 0] = (count >> 24) & 0xff; asalt[salt.len + 1] = (count >> 16) & 0xff; @@ -249,27 +271,25 @@ STATIC mp_obj_t pbkdf2_sha512(mp_obj_t pass_in, mp_obj_t salt_in, mp_obj_t round asalt[salt.len + 3] = count & 0xff; #if MICROPY_SSL_MBEDTLS - mbedtls_md_hmac(md_algo, pass.buf, pass.len, asalt, sizeof(asalt), d1); + mbedtls_md_hmac(algo, pass.buf, pass.len, asalt, sizeof(asalt), d1); #else - cf_hmac(pass.buf, pass.len, asalt, sizeof(asalt), d1, &cf_sha512); + cf_hmac(pass.buf, pass.len, asalt, sizeof(asalt), d1, algo); #endif - //hmac_sha256(asalt, salt_len + 4, pass.buf, pass.len, d1); - memcpy(obuf, d1, H_SIZE); + memcpy(obuf, d1, md_len); for(uint32_t i=1; i < rounds; i++) { - //hmac_sha1(d1, sizeof(d1), pass.buf, pass.len, d2); #if MICROPY_SSL_MBEDTLS - mbedtls_md_hmac(md_algo, pass.buf, pass.len, d1, sizeof(d1), d2); + mbedtls_md_hmac(algo, pass.buf, pass.len, d1, sizeof(d1), d2); #else - cf_hmac(pass.buf, pass.len, d1, sizeof(d1), d2, &cf_sha512); + cf_hmac(pass.buf, pass.len, d1, sizeof(d1), d2, algo); #endif memcpy(d1, d2, sizeof(d1)); for (uint32_t j = 0; j < sizeof(obuf); j++) obuf[j] ^= d1[j]; } - uint32_t r = MIN(key_len, H_SIZE); + uint32_t r = MIN(key_len, md_len); memcpy(key, obuf, r); key += r; key_len -= r; @@ -283,8 +303,19 @@ STATIC mp_obj_t pbkdf2_sha512(mp_obj_t pass_in, mp_obj_t salt_in, mp_obj_t round return mp_obj_new_str_from_vstr(&mp_type_bytes, &key_out); } + +// Pbkdf2 using sha512 hmac, for use in BIP39=>BIP32 seed +STATIC mp_obj_t pbkdf2_sha512(mp_obj_t pass_in, mp_obj_t salt_in, mp_obj_t rounds_in) { + return pbkdf2_hmac(64, pass_in, salt_in, rounds_in); +} STATIC MP_DEFINE_CONST_FUN_OBJ_3(pbkdf2_sha512_obj, pbkdf2_sha512); +// Pbkdf2 using sha256 hmac +STATIC mp_obj_t pbkdf2_sha256(mp_obj_t pass_in, mp_obj_t salt_in, mp_obj_t rounds_in) { + return pbkdf2_hmac(32, pass_in, salt_in, rounds_in); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_3(pbkdf2_sha256_obj, pbkdf2_sha256); + STATIC const mp_rom_map_elem_t mp_module_hash_globals_table[] = { { MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_hash) }, @@ -296,6 +327,7 @@ STATIC const mp_rom_map_elem_t mp_module_hash_globals_table[] = { { MP_ROM_QSTR(MP_QSTR_sha256t), MP_ROM_PTR(&hm_tagged_sha256_obj) }, { MP_ROM_QSTR(MP_QSTR_hash160), MP_ROM_PTR(&hm_hash160_obj) }, { MP_ROM_QSTR(MP_QSTR_pbkdf2_sha512), MP_ROM_PTR(&pbkdf2_sha512_obj) }, + { MP_ROM_QSTR(MP_QSTR_pbkdf2_sha256), MP_ROM_PTR(&pbkdf2_sha256_obj) }, }; diff --git a/ngu/ngu_tests/test_hash.py b/ngu/ngu_tests/test_hash.py index 9566ab1..6054452 100644 --- a/ngu/ngu_tests/test_hash.py +++ b/ngu/ngu_tests/test_hash.py @@ -42,7 +42,7 @@ def expect2(func, msg, dig): except ImportError: import hashlib from binascii import b2a_hex, a2b_hex - from hashlib import sha512, sha256 + from hashlib import sha512, sha256, pbkdf2_hmac _ripemd160 = lambda x: hashlib.new('ripemd160', x).digest() @@ -54,7 +54,6 @@ def expect2(func, msg, dig): ] # gen tests - import wallycore as w with open('test_hash_gen.py', 'wt') as fd: print("import ngu", file=fd) @@ -66,15 +65,17 @@ def expect2(func, msg, dig): n = 2000 print("assert ngu.hash.%s(bytes(%d)) == %r" % (nm, n, func(bytes(n))), file=fd) - print("F = ngu.hash.pbkdf2_sha512", file=fd) - for pw, salt, rounds in [ - (b'abc', b'def', 300), - (b'abc'*20, b'def'*20, 3000), - (b'\x01\x03\x04\x05\x06', b'\x04\x03\x02\x01\x00', 30), - (b'a', b'd', 30), - ]: - expect = w.pbkdf2_hmac_sha512(pw, salt, 0, rounds) - print("assert F(%r, %r, %d) == %r" % (pw, salt, rounds, bytes(expect)), file=fd) + for hf in ["sha512", "sha256"]: + print("F = ngu.hash.pbkdf2_%s" % hf, file=fd) + for pw, salt, rounds in [ + (b'abc', b'def', 300), + (b'abc'*20, b'def'*20, 3000), + (b'\x01\x03\x04\x05\x06', b'\x04\x03\x02\x01\x00', 30), + (b'a', b'd', 30), + (b"x" * 256, b"y"*128, 100000), + ]: + expect = pbkdf2_hmac(hf, pw, salt, rounds, dklen=32 if hf == "sha256" else 64) + print("assert F(%r, %r, %d) == %r" % (pw, salt, rounds, bytes(expect)), file=fd) print("print('PASS - %s')" % fd.name, file=fd) print("run code now in: %s" % fd.name) diff --git a/ngu/ngu_tests/test_hash_gen.py b/ngu/ngu_tests/test_hash_gen.py index cf6e7e7..5bbe5b1 100644 --- a/ngu/ngu_tests/test_hash_gen.py +++ b/ngu/ngu_tests/test_hash_gen.py @@ -95,4 +95,11 @@ assert F(b'abcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabc', b'defdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdef', 3000) == b' \xfb\x96\x1a]\xea\xcf0\xae;E\x8bZ\x80aS<^\xe0\xf2\x9e\x947;O\x84\xe5\xd4\xaa]F5\x13_D\xfa\x1b\x90\x93\x84j\x8a\x9c\xb6G\x8b\x80{\xac\xb1\x00\x06\x03\xf9(\xff\xa6\x11\x05\xc5@p\xf0%' assert F(b'\x01\x03\x04\x05\x06', b'\x04\x03\x02\x01\x00', 30) == b'\x08\xbeC\xd2\xbe\x02CO\xfe\xb9Nj\xeb%!|\xd9\xef\xb5\xad\x9cU\x03J\x18T\xa9\x00\xe0\xde\xfdFu\x8dz\xe5\xd18)\xfb\x1c\xf8\x02\xd3\x06h\xd9\xe1\xe4\x1c`\xc8\xdcM\x1b\xb6a\x8f\x0c\x9dO,(\xd2' assert F(b'a', b'd', 30) == b'U\x02\xf3b\xb1\x8e\x90\xaa\xb7\xe9\xe2\xd3\xaf\xf3\xd3\xde\xedZ\xf2&\xd7\xec\x94\x13\x19Q\xf0&\x16.\xa6\xfb\x08C\x067\xb3\xa9>$N\xfavb^\xa9\x99\xbc"\x03\xd5c\xf5\xd6!\xba \x16Z\x071_\xc1\x0f' +assert F(b'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', b'yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy', 100000) == b'R6\xf7\xf8G|\x8b\x9c\xc1\xd7O\x05\x88\xdf\xf5\xf3\x15\xc0\xf1URcARc\xdfz\xc8\xe7f\xea\x97\x12\x98M\xaf\xb6\xe2Q_\xf3"\xdf9\x10UK\xe6V\x9f\xba\x887\xd2P\x934\x81\x1aN\xff\xebL\xa7' +F = ngu.hash.pbkdf2_sha256 +assert F(b'abc', b'def', 300) == b'\x14|\x11\xe5v\xb8\x8c\x1e\xa5\x12xn[O\x15%\xac>\xfe"\x9d(\xdf\xffQ$\xd65-\xfa\xc89' +assert F(b'abcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabc', b'defdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdef', 3000) == b'\xb8\xbe=\x12\xd8\xef#\x9f\x9b\xc5\xa8z)\xab*D{\xa0_\xaeZ\x91\xb5]<\xd4IpyD[1' +assert F(b'\x01\x03\x04\x05\x06', b'\x04\x03\x02\x01\x00', 30) == b'@\x1c\xc2\x06\xd6-\xdd\xde\x11\xd54\xabd\x08\xfb\x9d\xa8ET\xc9U\t\x9a\x8e.-\xa3\xf5\x9d\x8c\x00r' +assert F(b'a', b'd', 30) == b"-\xe1(\x16)\x0f\x03\x14\xe8\xc0S\xcc\x06'\x9a\xb8\x8ea\x06s\x0f\xf0\x9c\xa4.Q\xf7\xd1\xc4\x9c@\x8a" +assert F(b'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', b'yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy', 100000) == b'\nW\xf0\xc6KD"C\xb7\xba\xd9\t\xa7"%QZ;\x8d\xc9-\x99D\x92\x8f`5\xd3\xcf\xa6\x9cA' print('PASS - test_hash_gen.py')