Commit ba15300
perf: Adding thread-local S3 cache DB connections to fix a performance bottleneck during bundle submission (#896)
perf: Enabling WAL mode by default for improved concurrency.
perf: Adding a get_connection_entry method which accepts a db connection parameter to allow for greater read concurrency; get_connection_entry doesn't take the main db_lock which protects the primary db_connection.
perf: Adding a get_local_connection method which creates (if necessary) a thread-local db connection suitable for reading entries via get_connection_entry.

Signed-off-by: Brian Axelson <86568017+baxeaz@users.noreply.github.com>
1 parent 36ba412 commit ba15300
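
For context, here is a minimal standalone sketch of the two techniques this commit combines. This is not the committed code: the database path and the kv schema are placeholders. WAL journaling lets readers proceed concurrently with a writer, and a per-thread sqlite3 connection lets each reader query without serializing on a shared lock.

import sqlite3
import threading

DB_PATH = "example_cache.db"  # placeholder; the real caches live in the job attachments cache dir

_local = threading.local()       # per-thread storage for connections
_local_connections: set = set()  # tracked so they can all be closed later


def get_local_connection() -> sqlite3.Connection:
    # Lazily create one connection per thread; sqlite3 connections are not
    # safe to share across threads without external locking.
    if not hasattr(_local, "connection"):
        _local.connection = sqlite3.connect(DB_PATH, check_same_thread=False)
        _local_connections.add(_local.connection)
    return _local.connection


# Writer connection: WAL mode is persistent for the database file, so
# concurrent readers no longer block behind a rollback journal.
writer = sqlite3.connect(DB_PATH, check_same_thread=False)
writer.execute("PRAGMA journal_mode=WAL")
writer.execute("CREATE TABLE IF NOT EXISTS kv (k TEXT PRIMARY KEY, v TEXT)")
writer.execute("INSERT OR REPLACE INTO kv VALUES ('key', 'value')")
writer.commit()


def reader() -> None:
    # Read on this thread's own connection; no shared lock is taken.
    row = get_local_connection().execute("SELECT v FROM kv WHERE k=?", ["key"]).fetchone()
    assert row == ("value",)


threads = [threading.Thread(target=reader) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

for conn in _local_connections:
    conn.close()
writer.close()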

File tree

6 files changed: +368 −42 lines changed

src/deadline/job_attachments/caches/cache_db.py

Lines changed: 42 additions & 0 deletions
@@ -6,6 +6,7 @@

 import logging
 import os
+import threading as _threading
 from abc import ABC
 from threading import Lock
 from typing import Optional
@@ -34,6 +35,8 @@ def __init__(
         self.cache_name: str = cache_name
         self.table_name: str = table_name
         self.create_query: str = create_query
+        self._local = _threading.local()
+        self._local_connections: set = set()

         try:
             # SQLite is included in Python installers, but might not exist if building python from source.
@@ -64,6 +67,7 @@ def __enter__(self):
                 self.db_connection: sqlite3.Connection = sqlite3.connect(
                     self.cache_dir, check_same_thread=False
                 )
+                self.db_connection.execute("PRAGMA journal_mode=WAL")
             except sqlite3.OperationalError as oe:
                 raise JobAttachmentsError(
                     f"Could not access cache file in {self.cache_dir}"
@@ -81,8 +85,35 @@ def __enter__(self):

     def __exit__(self, exc_type, exc_value, exc_traceback):
         """Called when exiting the context manager."""
+
         if self.enabled:
+            import sqlite3
+
             self.db_connection.close()
+            for conn in self._local_connections:
+                try:
+                    conn.close()
+                except sqlite3.Error as e:
+                    logger.warning(f"SQLite connection failed to close with error {e}")
+
+            self._local_connections.clear()
+
+    def get_local_connection(self):
+        """Creates (if necessary) and returns a thread-local connection to the SQLite database."""
+        if not self.enabled:
+            return None
+        import sqlite3
+
+        if not hasattr(self._local, "connection"):
+            try:
+                self._local.connection = sqlite3.connect(self.cache_dir, check_same_thread=False)
+                self._local_connections.add(self._local.connection)
+            except sqlite3.OperationalError as oe:
+                raise JobAttachmentsError(
+                    f"Could not create connection to cache in {self.cache_dir}"
+                ) from oe
+
+        return self._local.connection

     @classmethod
     def get_default_cache_db_file_dir(cls) -> Optional[str]:
@@ -99,12 +130,23 @@ def remove_cache(self) -> None:
         """
         Removes the underlying cache contents from the file system.
         """
+
         if self.enabled:
+            import sqlite3
+
             self.db_connection.close()
+            conn_list = list(self._local_connections)
+            for conn in conn_list:
+                try:
+                    conn.close()
+                    self._local_connections.remove(conn)
+                except sqlite3.Error as e:
+                    logger.warning(f"SQLite connection failed to close with error {e}")

         logger.debug(f"The cache {self.cache_dir} will be removed")
         try:
             os.remove(self.cache_dir)
         except Exception as e:
             logger.error(f"Error occurred while removing the cache file {self.cache_dir}: {e}")
+
             raise e
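
A lifecycle note that is easy to miss: get_local_connection() lazily creates one connection per thread, records it in _local_connections, and __exit__ closes the whole set, so thread-local connections are only valid inside the cache's with-block. A hedged sketch of the intended call pattern (the import path and the S3 key are assumptions, not taken from this diff):

import threading

from deadline.job_attachments.caches import S3CheckCache  # assumed import path

with S3CheckCache() as s3_cache:

    def worker() -> None:
        # First call on a thread creates its connection; later calls reuse it.
        conn = s3_cache.get_local_connection()
        entry = s3_cache.get_connection_entry("my-bucket/Data/abc123.xxh128", conn)
        print(entry)

    threads = [threading.Thread(target=worker) for _ in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
# On exit, __exit__ closes db_connection plus every tracked thread-local connection.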

src/deadline/job_attachments/caches/hash_cache.py

Lines changed: 30 additions & 16 deletions
@@ -59,6 +59,35 @@ def __init__(self, cache_dir: Optional[str] = None) -> None:
             cache_dir=cache_dir,
         )

+    def get_connection_entry(
+        self, file_path_key: str, hash_algorithm: HashAlgorithm, connection
+    ) -> Optional[HashCacheEntry]:
+        """
+        Returns an entry from the hash cache, if it exists. This is the "lockless" version of
+        get_entry: it doesn't take the main db_lock protecting db_connection, and instead reads
+        from the DB through the given connection, which can generally be the thread-local
+        connection returned by get_local_connection().
+        """
+        if not self.enabled:
+            return None
+
+        entry_vals = connection.execute(
+            f"SELECT * FROM {self.table_name} WHERE file_path=? AND hash_algorithm=?",
+            [
+                file_path_key.encode(encoding="utf-8", errors="surrogatepass"),
+                hash_algorithm.value,
+            ],
+        ).fetchone()
+        if entry_vals:
+            return HashCacheEntry(
+                file_path=str(entry_vals[0], encoding="utf-8", errors="surrogatepass"),
+                hash_algorithm=HashAlgorithm(entry_vals[1]),
+                file_hash=entry_vals[2],
+                last_modified_time=str(entry_vals[3]),
+            )
+        else:
+            return None

     def get_entry(
         self, file_path_key: str, hash_algorithm: HashAlgorithm
     ) -> Optional[HashCacheEntry]:
@@ -69,22 +98,7 @@ def get_entry(
             return None

         with self.db_lock, self.db_connection:
-            entry_vals = self.db_connection.execute(
-                f"SELECT * FROM {self.table_name} WHERE file_path=? AND hash_algorithm=?",
-                [
-                    file_path_key.encode(encoding="utf-8", errors="surrogatepass"),
-                    hash_algorithm.value,
-                ],
-            ).fetchone()
-            if entry_vals:
-                return HashCacheEntry(
-                    file_path=str(entry_vals[0], encoding="utf-8", errors="surrogatepass"),
-                    hash_algorithm=HashAlgorithm(entry_vals[1]),
-                    file_hash=entry_vals[2],
-                    last_modified_time=str(entry_vals[3]),
-                )
-            else:
-                return None
+            return self.get_connection_entry(file_path_key, hash_algorithm, self.db_connection)

     def put_entry(self, entry: HashCacheEntry) -> None:
         """Inserts or replaces an entry into the hash cache database after acquiring the lock."""

src/deadline/job_attachments/caches/s3_check_cache.py

Lines changed: 27 additions & 17 deletions
@@ -56,6 +56,32 @@ def __init__(self, cache_dir: Optional[str] = None) -> None:
             cache_dir=cache_dir,
         )

+    def get_connection_entry(self, s3_key: str, connection) -> Optional[S3CheckCacheEntry]:
+        """
+        Returns an entry from the S3 check cache, if it exists. This is the "lockless" version of
+        get_entry: it doesn't take the main db_lock protecting db_connection, and instead reads
+        from the DB through the given connection, which can generally be the thread-local
+        connection returned by get_local_connection().
+        """
+
+        entry_vals = connection.execute(
+            f"SELECT * FROM {self.table_name} WHERE s3_key=?",
+            [s3_key],
+        ).fetchone()
+        if entry_vals:
+            entry = S3CheckCacheEntry(
+                s3_key=entry_vals[0],
+                last_seen_time=str(entry_vals[1]),
+            )
+            try:
+                last_seen = datetime.fromtimestamp(float(entry.last_seen_time))
+                if (datetime.now() - last_seen).days < self.ENTRY_EXPIRY_DAYS:
+                    return entry
+            except ValueError:
+                logger.warning(f"Timestamp for S3 key {s3_key} is not valid. Ignoring.")
+
+        return None

     def get_entry(self, s3_key: str) -> Optional[S3CheckCacheEntry]:
         """
         Checks if an entry exists in the cache, and returns it if it hasn't expired.
@@ -64,23 +90,7 @@ def get_entry(self, s3_key: str) -> Optional[S3CheckCacheEntry]:
             return None

         with self.db_lock, self.db_connection:
-            entry_vals = self.db_connection.execute(
-                f"SELECT * FROM {self.table_name} WHERE s3_key=?",
-                [s3_key],
-            ).fetchone()
-            if entry_vals:
-                entry = S3CheckCacheEntry(
-                    s3_key=entry_vals[0],
-                    last_seen_time=str(entry_vals[1]),
-                )
-                try:
-                    last_seen = datetime.fromtimestamp(float(entry.last_seen_time))
-                    if (datetime.now() - last_seen).days < self.ENTRY_EXPIRY_DAYS:
-                        return entry
-                except ValueError:
-                    logger.warning(f"Timestamp for S3 key {s3_key} is not valid. Ignoring.")
-
-            return None
+            return self.get_connection_entry(s3_key, self.db_connection)

     def put_entry(self, entry: S3CheckCacheEntry) -> None:
         """Inserts or replaces an entry into the cache database."""

src/deadline/job_attachments/upload.py

Lines changed: 10 additions & 3 deletions
@@ -593,8 +593,11 @@ def verify_hash_cache_integrity(
         random.shuffle(s3_upload_keys)
         sampled_cache_entries: List[S3CheckCacheEntry] = []
         with S3CheckCache(s3_check_cache_dir) as s3_cache:
+            local_connection = s3_cache.get_local_connection()
             for upload_key in s3_upload_keys:
-                this_entry = s3_cache.get_entry(s3_key=f"{s3_bucket}/{upload_key}")
+                this_entry = s3_cache.get_connection_entry(
+                    s3_key=f"{s3_bucket}/{upload_key}", connection=local_connection
+                )
                 if this_entry is not None:
                     sampled_cache_entries.append(this_entry)
                     if len(sampled_cache_entries) >= 30:
@@ -651,7 +654,9 @@ def upload_object_to_cas(
         s3_upload_key = self._generate_s3_upload_key(file, hash_algorithm, s3_cas_prefix)
         is_uploaded = False

-        if s3_check_cache.get_entry(s3_key=f"{s3_bucket}/{s3_upload_key}"):
+        if s3_check_cache.get_connection_entry(
+            s3_key=f"{s3_bucket}/{s3_upload_key}", connection=s3_check_cache.get_local_connection()
+        ):
             logger.debug(
                 f"skipping {local_path} because {s3_bucket}/{s3_upload_key} exists in the cache"
             )
@@ -1068,7 +1073,9 @@ def _process_input_path(
         file_status: FileStatus = FileStatus.UNCHANGED
         actual_modified_time = str(datetime.fromtimestamp(path.stat().st_mtime))

-        entry: Optional[HashCacheEntry] = hash_cache.get_entry(full_path, hash_alg)
+        entry: Optional[HashCacheEntry] = hash_cache.get_connection_entry(
+            full_path, hash_alg, connection=hash_cache.get_local_connection()
+        )
         if entry is not None:
             # If the file was modified, we need to rehash it
             if actual_modified_time != entry.last_modified_time:
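
Two call styles appear in this file: verify_hash_cache_integrity hoists the connection out of its sampling loop, while upload_object_to_cas and _process_input_path fetch it inline at each call site. For a given thread the two are equivalent, since after the first call get_local_connection() reduces to a hasattr lookup on the thread-local. A hedged sketch (the import path and keys are assumptions):

from deadline.job_attachments.caches import S3CheckCache  # assumed import path

keys = ["my-bucket/Data/aaa.xxh128", "my-bucket/Data/bbb.xxh128"]  # illustrative

with S3CheckCache() as cache:
    # Hoisted style (verify_hash_cache_integrity): fetch once, reuse in the loop.
    conn = cache.get_local_connection()
    for key in keys:
        cache.get_connection_entry(s3_key=key, connection=conn)

    # Inline style (upload_object_to_cas): each call re-resolves the same
    # thread-local connection, which is cheap after the first creation.
    for key in keys:
        cache.get_connection_entry(s3_key=key, connection=cache.get_local_connection())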
