diff --git a/aikido_zen/vulnerabilities/sql_injection/__init__.py b/aikido_zen/vulnerabilities/sql_injection/__init__.py index 745c8b5b..745241fc 100644 --- a/aikido_zen/vulnerabilities/sql_injection/__init__.py +++ b/aikido_zen/vulnerabilities/sql_injection/__init__.py @@ -3,11 +3,8 @@ """ import re -import ctypes from aikido_zen.helpers.logging import logger -from .map_dialect_to_rust_int import map_dialect_to_rust_int -from .get_lib_path import get_binary_path -from ...helpers.encode_safely import encode_safely +from .zen_internal_ffi import ZenInternal def detect_sql_injection(query, user_input, dialect): @@ -20,32 +17,7 @@ def detect_sql_injection(query, user_input, dialect): if should_return_early(query_l, userinput_l): return False - internals_lib = ctypes.CDLL(get_binary_path()) - internals_lib.detect_sql_injection.argtypes = [ - ctypes.POINTER(ctypes.c_uint8), - ctypes.c_size_t, - ctypes.POINTER(ctypes.c_uint8), - ctypes.c_size_t, - ctypes.c_int, - ] - internals_lib.detect_sql_injection.restype = ctypes.c_int - - # Parse input variables for rust function - query_bytes = encode_safely(query_l) - userinput_bytes = encode_safely(userinput_l) - query_buffer = (ctypes.c_uint8 * len(query_bytes)).from_buffer_copy(query_bytes) - userinput_buffer = (ctypes.c_uint8 * len(userinput_bytes)).from_buffer_copy( - userinput_bytes - ) - dialect_int = map_dialect_to_rust_int(dialect) - - c_int_res = internals_lib.detect_sql_injection( - query_buffer, - len(query_bytes), - userinput_buffer, - len(userinput_bytes), - dialect_int, - ) + c_int_res = ZenInternal().detect_sql_injection(query_l, userinput_l, dialect) # This means that an error occurred in the library if c_int_res == 2: diff --git a/aikido_zen/vulnerabilities/sql_injection/map_dialect_to_rust_int.py b/aikido_zen/vulnerabilities/sql_injection/map_dialect_to_rust_int.py deleted file mode 100644 index b18b96b6..00000000 --- a/aikido_zen/vulnerabilities/sql_injection/map_dialect_to_rust_int.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -Exports map_dialect_to_rust_int -""" - -DIALECTS = { - "generic": 0, - "clickhouse": 3, - "mysql": 8, - "postgres": 9, - "sqlite": 12, -} - - -def map_dialect_to_rust_int(dialect): - """ - This takes the string dialect as input and maps it to a rust integer - Reference : [rust lib]/src/sql_injection/helpers/select_dialect_based_on_enum.rs - """ - return DIALECTS[dialect] diff --git a/aikido_zen/vulnerabilities/sql_injection/zen_internal_ffi.py b/aikido_zen/vulnerabilities/sql_injection/zen_internal_ffi.py new file mode 100644 index 00000000..b046c299 --- /dev/null +++ b/aikido_zen/vulnerabilities/sql_injection/zen_internal_ffi.py @@ -0,0 +1,57 @@ +""" +Interface for calling zen-internal shared library +""" + +import ctypes +import threading +from .get_lib_path import get_binary_path +from ...helpers.encode_safely import encode_safely + + +class __Singleton(type): + _instances = {} + _lock = threading.Lock() # Ensures thread safety + + def __call__(cls, *args, **kwargs): + with cls._lock: # Lock to make the check-and-create operation atomic + if cls not in cls._instances: + cls._instances[cls] = super().__call__(*args, **kwargs) + return cls._instances[cls] + + +class ZenInternal(metaclass=__Singleton): + # Reference : [rust lib]/src/sql_injection/helpers/select_dialect_based_on_enum.rs + SQL_DIALECTS = { + "generic": 0, + "clickhouse": 3, + "mysql": 8, + "postgres": 9, + "sqlite": 12, + } + + def __init__(self): + self._lib = ctypes.CDLL(get_binary_path()) + self._lib.detect_sql_injection.argtypes = [ + ctypes.POINTER(ctypes.c_uint8), + ctypes.c_size_t, + ctypes.POINTER(ctypes.c_uint8), + ctypes.c_size_t, + ctypes.c_int, + ] + self._lib.detect_sql_injection.restype = ctypes.c_int + + def detect_sql_injection(self, query, user_input, dialect): + query_bytes = encode_safely(query) + userinput_bytes = encode_safely(user_input) + query_buffer = (ctypes.c_uint8 * len(query_bytes)).from_buffer_copy(query_bytes) + userinput_buffer = (ctypes.c_uint8 * len(userinput_bytes)).from_buffer_copy( + userinput_bytes + ) + dialect_int = self.SQL_DIALECTS[dialect] + return self._lib.detect_sql_injection( + query_buffer, + len(query_bytes), + userinput_buffer, + len(userinput_bytes), + dialect_int, + )