Skip to content

Commit c610261

Browse files
committed
perf: do not search and dlopen on every sql cmd
This extracts a handle of zen-internal FFI to a thread safe singleton. Ensuring the DLL is searched for and loaded only once. Previously, we would run multiple syscalls to get the location of the dll/so and then built the python datastructures (CDLL object) that has to be later GCed. I haven't microbenchmarked the improvements rigorously, as I feel GC savings may contribute considerably in real-world scenario. Although, poetry run pytest -rP ./aikido_zen/vulnerabilities/sql_injection/ improves by over 4.3% on my set-up.
1 parent c9493d6 commit c610261

File tree

2 files changed

+51
-30
lines changed

2 files changed

+51
-30
lines changed

aikido_zen/vulnerabilities/sql_injection/__init__.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33
"""
44

55
import re
6-
import ctypes
76
from aikido_zen.helpers.logging import logger
8-
from .map_dialect_to_rust_int import map_dialect_to_rust_int
9-
from .get_lib_path import get_binary_path
10-
from ...helpers.encode_safely import encode_safely
7+
from .zen_internal_ffi import ZenInternal
118

129

1310
def detect_sql_injection(query, user_input, dialect):
@@ -20,32 +17,7 @@ def detect_sql_injection(query, user_input, dialect):
2017
if should_return_early(query_l, userinput_l):
2118
return False
2219

23-
internals_lib = ctypes.CDLL(get_binary_path())
24-
internals_lib.detect_sql_injection.argtypes = [
25-
ctypes.POINTER(ctypes.c_uint8),
26-
ctypes.c_size_t,
27-
ctypes.POINTER(ctypes.c_uint8),
28-
ctypes.c_size_t,
29-
ctypes.c_int,
30-
]
31-
internals_lib.detect_sql_injection.restype = ctypes.c_int
32-
33-
# Parse input variables for rust function
34-
query_bytes = encode_safely(query_l)
35-
userinput_bytes = encode_safely(userinput_l)
36-
query_buffer = (ctypes.c_uint8 * len(query_bytes)).from_buffer_copy(query_bytes)
37-
userinput_buffer = (ctypes.c_uint8 * len(userinput_bytes)).from_buffer_copy(
38-
userinput_bytes
39-
)
40-
dialect_int = map_dialect_to_rust_int(dialect)
41-
42-
c_int_res = internals_lib.detect_sql_injection(
43-
query_buffer,
44-
len(query_bytes),
45-
userinput_buffer,
46-
len(userinput_bytes),
47-
dialect_int,
48-
)
20+
c_int_res = ZenInternal().detect_sql_injection(query_l, userinput_l, dialect)
4921

5022
# This means that an error occurred in the library
5123
if c_int_res == 2:
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""
2+
Interface for calling zen-internal shared library
3+
"""
4+
5+
import ctypes
6+
import threading
7+
from .get_lib_path import get_binary_path
8+
from .map_dialect_to_rust_int import map_dialect_to_rust_int
9+
from ...helpers.encode_safely import encode_safely
10+
11+
12+
class __Singleton(type):
13+
_instances = {}
14+
_lock = threading.Lock() # Ensures thread safety
15+
16+
def __call__(cls, *args, **kwargs):
17+
with cls._lock: # Lock to make the check-and-create operation atomic
18+
if cls not in cls._instances:
19+
cls._instances[cls] = super().__call__(*args, **kwargs)
20+
return cls._instances[cls]
21+
22+
23+
class ZenInternal(metaclass=__Singleton):
24+
def __init__(self):
25+
self._lib = ctypes.CDLL(get_binary_path())
26+
self._lib.detect_sql_injection.argtypes = [
27+
ctypes.POINTER(ctypes.c_uint8),
28+
ctypes.c_size_t,
29+
ctypes.POINTER(ctypes.c_uint8),
30+
ctypes.c_size_t,
31+
ctypes.c_int,
32+
]
33+
self._lib.detect_sql_injection.restype = ctypes.c_int
34+
35+
def detect_sql_injection(self, query, user_input, dialect):
36+
query_bytes = encode_safely(query)
37+
userinput_bytes = encode_safely(user_input)
38+
query_buffer = (ctypes.c_uint8 * len(query_bytes)).from_buffer_copy(query_bytes)
39+
userinput_buffer = (ctypes.c_uint8 * len(userinput_bytes)).from_buffer_copy(
40+
userinput_bytes
41+
)
42+
dialect_int = map_dialect_to_rust_int(dialect)
43+
return self._lib.detect_sql_injection(
44+
query_buffer,
45+
len(query_bytes),
46+
userinput_buffer,
47+
len(userinput_bytes),
48+
dialect_int,
49+
)

0 commit comments

Comments
 (0)