Skip to content

Commit

Permalink
Add NULL check for common
Browse files Browse the repository at this point in the history
Signed-off-by: Songling Han <[email protected]>
  • Loading branch information
songlingatpan committed Sep 23, 2024
1 parent 067bdaf commit 6707331
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 66 deletions.
29 changes: 15 additions & 14 deletions src/common/common.c
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ OQS_API int OQS_MEM_secure_bcmp(const void *a, const void *b, size_t len) {
}

OQS_API void OQS_MEM_cleanse(void *ptr, size_t len) {
if (ptr == NULL) {
return;
}
#if defined(OQS_USE_OPENSSL)
OSSL_FUNC(OPENSSL_cleanse)(ptr, len);
#elif defined(_WIN32)
Expand All @@ -267,20 +270,19 @@ OQS_API void OQS_MEM_cleanse(void *ptr, size_t len) {
explicit_memset(ptr, 0, len);
#elif defined(__STDC_LIB_EXT1__) || defined(OQS_HAVE_MEMSET_S)
if (0U < len && memset_s(ptr, (rsize_t)len, 0, (rsize_t)len) != 0) {
abort();
return NULL; //abort();
}
#else
typedef void *(*memset_t)(void *, int, size_t);
static volatile memset_t memset_func = memset;
memset_func(ptr, 0, len);
#endif
}

void *OQS_MEM_checked_malloc(size_t len) {
void *ptr = OQS_MEM_malloc(len);
if (ptr == NULL) {
fprintf(stderr, "Memory allocation failed\n");
abort();
return NULL; //abort();
}

return ptr;
Expand All @@ -290,7 +292,7 @@ void *OQS_MEM_checked_aligned_alloc(size_t alignment, size_t size) {
void *ptr = OQS_MEM_aligned_alloc(alignment, size);
if (ptr == NULL) {
fprintf(stderr, "Memory allocation failed\n");
abort();
return NULL; //abort();
}

return ptr;
Expand Down Expand Up @@ -391,24 +393,23 @@ void *OQS_MEM_aligned_alloc(size_t alignment, size_t size) {
}

void OQS_MEM_aligned_free(void *ptr) {
if (ptr == NULL) {
return;
}
#if defined(OQS_USE_OPENSSL)
// Use OpenSSL's free function
if (ptr) {
uint8_t *u8ptr = ptr;
OPENSSL_free(u8ptr - u8ptr[-1]);
}
uint8_t *u8ptr = ptr;
OPENSSL_free(u8ptr - u8ptr[-1]);
#elif defined(OQS_HAVE_ALIGNED_ALLOC) || defined(OQS_HAVE_POSIX_MEMALIGN) || defined(OQS_HAVE_MEMALIGN)
free(ptr); // IGNORE free-check
#elif defined(__MINGW32__) || defined(__MINGW64__)
__mingw_aligned_free(ptr);
#elif defined(_MSC_VER)
_aligned_free(ptr);
#else
if (ptr) {
// Reconstruct the pointer returned from malloc using the difference
// stored one byte ahead of ptr.
uint8_t *u8ptr = ptr;
free(u8ptr - u8ptr[-1]); // IGNORE free-check
}
// Reconstruct the pointer returned from malloc using the difference
// stored one byte ahead of ptr.
uint8_t *u8ptr = ptr;
free(u8ptr - u8ptr[-1]); // IGNORE free-check
#endif
}
129 changes: 77 additions & 52 deletions src/common/ossl_helpers.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,57 @@ static EVP_MD *sha256_ptr, *sha384_ptr, *sha512_ptr,

static EVP_CIPHER *aes128_ecb_ptr, *aes128_ctr_ptr, *aes256_ecb_ptr, *aes256_ctr_ptr;

static void free_ossl_objects(void) {
if (sha256_ptr) {
OSSL_FUNC(EVP_MD_free)(sha256_ptr);
sha256_ptr = NULL;
}
if (sha384_ptr) {
OSSL_FUNC(EVP_MD_free)(sha384_ptr);
sha384_ptr = NULL;
}
if (sha512_ptr) {
OSSL_FUNC(EVP_MD_free)(sha512_ptr);
sha512_ptr = NULL;
}
if (sha3_256_ptr) {
OSSL_FUNC(EVP_MD_free)(sha3_256_ptr);
sha3_256_ptr = NULL;
}
if (sha3_384_ptr) {
OSSL_FUNC(EVP_MD_free)(sha3_384_ptr);
sha3_384_ptr = NULL;
}
if (sha3_512_ptr) {
OSSL_FUNC(EVP_MD_free)(sha3_512_ptr);
sha3_512_ptr = NULL;
}
if (shake128_ptr) {
OSSL_FUNC(EVP_MD_free)(shake128_ptr);
shake128_ptr = NULL;
}
if (shake256_ptr) {
OSSL_FUNC(EVP_MD_free)(shake256_ptr);
shake256_ptr = NULL;
}
if (aes128_ecb_ptr) {
OSSL_FUNC(EVP_CIPHER_free)(aes128_ecb_ptr);
aes128_ecb_ptr = NULL;
}
if (aes128_ctr_ptr) {
OSSL_FUNC(EVP_CIPHER_free)(aes128_ctr_ptr);
aes128_ctr_ptr = NULL;
}
if (aes256_ecb_ptr) {
OSSL_FUNC(EVP_CIPHER_free)(aes256_ecb_ptr);
aes256_ecb_ptr = NULL;
}
if (aes256_ctr_ptr) {
OSSL_FUNC(EVP_CIPHER_free)(aes256_ctr_ptr);
aes256_ctr_ptr = NULL;
}
}

static void fetch_ossl_objects(void) {
sha256_ptr = OSSL_FUNC(EVP_MD_fetch)(NULL, "SHA256", NULL);
sha384_ptr = OSSL_FUNC(EVP_MD_fetch)(NULL, "SHA384", NULL);
Expand All @@ -40,47 +91,18 @@ static void fetch_ossl_objects(void) {
!sha3_384_ptr || !sha3_512_ptr || !shake128_ptr || !shake256_ptr ||
!aes128_ecb_ptr || !aes128_ctr_ptr || !aes256_ecb_ptr || !aes256_ctr_ptr) {
fprintf(stderr, "liboqs warning: OpenSSL initialization failure. Is provider for SHA, SHAKE, AES enabled?\n");
free_ossl_objects();
}
}

static void free_ossl_objects(void) {
OSSL_FUNC(EVP_MD_free)(sha256_ptr);
sha256_ptr = NULL;
OSSL_FUNC(EVP_MD_free)(sha384_ptr);
sha384_ptr = NULL;
OSSL_FUNC(EVP_MD_free)(sha512_ptr);
sha512_ptr = NULL;
OSSL_FUNC(EVP_MD_free)(sha3_256_ptr);
sha3_256_ptr = NULL;
OSSL_FUNC(EVP_MD_free)(sha3_384_ptr);
sha3_384_ptr = NULL;
OSSL_FUNC(EVP_MD_free)(sha3_512_ptr);
sha3_512_ptr = NULL;
OSSL_FUNC(EVP_MD_free)(shake128_ptr);
shake128_ptr = NULL;
OSSL_FUNC(EVP_MD_free)(shake256_ptr);
shake256_ptr = NULL;
OSSL_FUNC(EVP_CIPHER_free)(aes128_ecb_ptr);
aes128_ecb_ptr = NULL;
OSSL_FUNC(EVP_CIPHER_free)(aes128_ctr_ptr);
aes128_ctr_ptr = NULL;
OSSL_FUNC(EVP_CIPHER_free)(aes256_ecb_ptr);
aes256_ecb_ptr = NULL;
OSSL_FUNC(EVP_CIPHER_free)(aes256_ctr_ptr);
aes256_ctr_ptr = NULL;
}
#endif // OPENSSL_VERSION_NUMBER >= 0x30000000L

void oqs_ossl_destroy(void) {
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
#if defined(OQS_USE_PTHREADS)
pthread_once(&free_once_control, free_ossl_objects);
#else
if (sha256_ptr || sha384_ptr || sha512_ptr || sha3_256_ptr ||
sha3_384_ptr || sha3_512_ptr || shake128_ptr || shake256_ptr ||
aes128_ecb_ptr || aes128_ctr_ptr || aes256_ecb_ptr || aes256_ctr_ptr) {
free_ossl_objects();
}
free_ossl_objects();
#endif
#endif
}
Expand Down Expand Up @@ -237,7 +259,6 @@ const EVP_CIPHER *oqs_aes_128_ecb(void) {
return OSSL_FUNC(EVP_aes_128_ecb)();
#endif
}

const EVP_CIPHER *oqs_aes_128_ctr(void) {
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
#if defined(OQS_USE_PTHREADS)
Expand Down Expand Up @@ -301,19 +322,19 @@ static pthread_once_t dlopen_once_control = PTHREAD_ONCE_INIT;
#define ENSURE_LIBRARY pthread_once(&dlopen_once_control, ensure_library)
#else
#define ENSURE_LIBRARY do { \
if (!libcrypto_dlhandle) { \
ensure_library(); \
} \
} while (0)
if (!libcrypto_dlhandle) { \
ensure_library(); \
} \
} while (0)
#endif // OQS_USE_PTHREADS

/* Define redirection symbols */
#if (2 <= __GNUC__ || (4 <= __clang_major__))
#define FUNC(ret, name, args, cargs) \
static __typeof__(name)(*_oqs_ossl_sym_##name);
static __typeof__(name)(*_oqs_ossl_sym_##name);
#else
#define FUNC(ret, name, args, cargs) \
static ret(*_oqs_ossl_sym_##name)args;
static ret(*_oqs_ossl_sym_##name)args;
#endif
#define VOID_FUNC FUNC
#include "ossl_functions.h"
Expand All @@ -322,19 +343,23 @@ static pthread_once_t dlopen_once_control = PTHREAD_ONCE_INIT;

/* Define redirection wrapper functions */
#define FUNC(ret, name, args, cargs) \
ret _oqs_ossl_##name args \
{ \
ENSURE_LIBRARY; \
assert(_oqs_ossl_sym_##name); \
return _oqs_ossl_sym_##name cargs; \
}
ret _oqs_ossl_##name args \
{ \
ENSURE_LIBRARY; \
if (!_oqs_ossl_sym_##name) { \
return (ret)0; \
} \
return _oqs_ossl_sym_##name cargs; \
}
#define VOID_FUNC(ret, name, args, cargs) \
ret _oqs_ossl_##name args \
{ \
ENSURE_LIBRARY; \
assert(_oqs_ossl_sym_##name); \
_oqs_ossl_sym_##name cargs; \
}
ret _oqs_ossl_##name args \
{ \
ENSURE_LIBRARY; \
if (!_oqs_ossl_sym_##name) { \
return; \
} \
_oqs_ossl_sym_##name cargs; \
}
#include "ossl_functions.h"
#undef VOID_FUNC
#undef FUNC
Expand All @@ -359,9 +384,9 @@ static void ensure_library(void) {
}

#define ENSURE_SYMBOL(name) \
ensure_symbol(#name, (void **)&_oqs_ossl_sym_##name)
ensure_symbol(#name, (void **)&_oqs_ossl_sym_##name)
#define FUNC(ret, name, args, cargs) \
ENSURE_SYMBOL(name);
ENSURE_SYMBOL(name);
#define VOID_FUNC FUNC
#include "ossl_functions.h"
#undef VOID_FUNC
Expand Down

0 comments on commit 6707331

Please sign in to comment.