Merge pull request #1023 from Yelp/u/kkasp/TRON-2342-exponential-backoff-dynamo-get

Add DynamoDB retry config for throttling and other errors. Add exponential backoff and jitter for unprocessed keys. Fix an edge case where we successfully process keys on our last attempt but still fail.
KaspariK authored Feb 12, 2025
2 parents 98c1879 + e0f2cce commit ed5c7a7
Showing 2 changed files with 150 additions and 81 deletions.
107 changes: 65 additions & 42 deletions tests/serialize/runstate/dynamodb_state_store_test.py
@@ -9,6 +9,7 @@

from testifycompat import assert_equal
from tron.serialize.runstate.dynamodb_state_store import DynamoDBStateStore
from tron.serialize.runstate.dynamodb_state_store import MAX_UNPROCESSED_KEYS_RETRIES


def mock_transact_write_items(self):
@@ -294,58 +295,80 @@ def test_delete_item_with_json_partitions(self, store, small_object, large_object):
vals = store.restore([key])
assert key not in vals

def test_retry_saving(self, store, small_object, large_object):
with mock.patch(
"moto.dynamodb2.responses.DynamoHandler.transact_write_items",
side_effect=KeyError("foo"),
) as mock_failed_write:
keys = [store.build_key("job_state", i) for i in range(1)]
value = small_object
pairs = zip(keys, (value for i in range(len(keys))))
try:
store.save(pairs)
except Exception:
assert_equal(mock_failed_write.call_count, 3)

def test_retry_reading(self, store, small_object, large_object):
@pytest.mark.parametrize(
"test_object, side_effects, expected_save_errors, expected_queue_length",
[
# All attempts fail
("small_object", [KeyError("foo")] * 3, 3, 1),
("large_object", [KeyError("foo")] * 3, 3, 1),
# Failure followed by success
("small_object", [KeyError("foo"), {}], 0, 0),
("large_object", [KeyError("foo"), {}], 0, 0),
],
)
def test_retry_saving(
self, test_object, side_effects, expected_save_errors, expected_queue_length, store, small_object, large_object
):
object_mapping = {
"small_object": small_object,
"large_object": large_object,
}
value = object_mapping[test_object]

with mock.patch.object(
store.client,
"transact_write_items",
side_effect=side_effects,
) as mock_transact_write:
keys = [store.build_key("job_state", 0)]
pairs = zip(keys, [value])
store.save(pairs)

for _ in side_effects:
store._consume_save_queue()

assert mock_transact_write.call_count == len(side_effects)
assert store.save_errors == expected_save_errors
assert len(store.save_queue) == expected_queue_length
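
Reading the parametrized cases together: a failure followed by a success ends with save_errors == 0 and an empty save_queue, so the error counter evidently tracks the current backlog rather than a lifetime total. A rough trace of the [KeyError, {}] case (method semantics inferred from the test, not from implementation internals):

store.save(pairs)            # enqueues the (key, value) pair; no DynamoDB call yet
store._consume_save_queue()  # 1st drain: transact_write_items raises KeyError -> pair re-queued
store._consume_save_queue()  # 2nd drain: write succeeds -> queue empties, save_errors back to 0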

@pytest.mark.parametrize(
"attempt, expected_delay",
[
(1, 1),
(2, 2),
(3, 4),
(4, 8),
(5, 10),
(6, 10),
(7, 10),
],
)
def test_calculate_backoff_delay(self, store, attempt, expected_delay):
delay = store._calculate_backoff_delay(attempt)
assert_equal(delay, expected_delay)

def test_retry_reading(self, store):
unprocessed_value = {
"Responses": {
store.name: [
{
"index": {"N": "0"},
"key": {"S": "job_state 0"},
},
],
},
"Responses": {},
"UnprocessedKeys": {
store.name: {
"Keys": [{"key": {"S": store.build_key("job_state", 0)}, "index": {"N": "0"}}],
"ConsistentRead": True,
"Keys": [
{
"index": {"N": "0"},
"key": {"S": "job_state 0"},
}
],
},
}
},
"ResponseMetadata": {},
}
keys = [store.build_key("job_state", i) for i in range(1)]
value = small_object
pairs = zip(keys, (value for i in range(len(keys))))
store.save(pairs)

keys = [store.build_key("job_state", 0)]

with mock.patch.object(
store.client,
"batch_get_item",
return_value=unprocessed_value,
) as mock_failed_read:
try:
with mock.patch("tron.config.static_config.load_yaml_file", autospec=True), mock.patch(
"tron.config.static_config.build_configuration_watcher", autospec=True
):
store.restore(keys)
except Exception:
assert_equal(mock_failed_read.call_count, 11)
) as mock_batch_get_item, mock.patch("time.sleep") as mock_sleep, pytest.raises(Exception) as exec_info:
store.restore(keys)
assert "failed to retrieve items with keys" in str(exec_info.value)
assert mock_batch_get_item.call_count == MAX_UNPROCESSED_KEYS_RETRIES
assert mock_sleep.call_count == MAX_UNPROCESSED_KEYS_RETRIES

def test_restore_exception_propagation(self, store, small_object):
# This test is to ensure that restore propagates exceptions upwards: see DAR-2328
124 changes: 85 additions & 39 deletions tron/serialize/runstate/dynamodb_state_store.py
@@ -20,6 +20,8 @@
from typing import TypeVar

import boto3 # type: ignore
import botocore # type: ignore
from botocore.config import Config # type: ignore

import tron.prom_metrics as prom_metrics
from tron.core.job import Job
@@ -35,16 +37,34 @@
# to contain other attributes like object name and number of partitions.
OBJECT_SIZE = 200_000 # TODO: TRON-2240 - consider swapping back to 400_000 now that we've removed pickles
MAX_SAVE_QUEUE = 500
MAX_ATTEMPTS = 10
# This is distinct from the number of retries in the retry_config as this is used for handling unprocessed
# keys outside the bounds of something like retrying on a ThrottlingException. We need this limit to avoid
# infinite loops in the case where a key is truly unprocessable. We allow for more retries than it should
# ever take to avoid failing restores due to transient issues.
MAX_UNPROCESSED_KEYS_RETRIES = 30
MAX_TRANSACT_WRITE_ITEMS = 100
log = logging.getLogger(__name__)
T = TypeVar("T")
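
Why unprocessed keys need their own cap: BatchGetItem can succeed as an API call yet leave some keys unread, returning them under UnprocessedKeys, and the boto3 retry_config never fires in that case because no exception is raised. A minimal sketch of the response shape the retry loop below consumes (table name and key values here are illustrative, not Tron's actual schema):

import boto3

client = boto3.client("dynamodb", region_name="us-west-2")
response = client.batch_get_item(
    RequestItems={
        "example-table": {
            "Keys": [{"key": {"S": "job_state example"}, "index": {"N": "0"}}],
            "ConsistentRead": True,
        }
    }
)
items = response.get("Responses", {}).get("example-table", [])
# Keys DynamoDB declined to read this time; the caller must retry them itself.
pending = response.get("UnprocessedKeys", {}).get("example-table", {}).get("Keys", [])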


class DynamoDBStateStore:
def __init__(self, name, dynamodb_region, stopping=False) -> None:
self.dynamodb = boto3.resource("dynamodb", region_name=dynamodb_region)
self.client = boto3.client("dynamodb", region_name=dynamodb_region)
# Standard mode includes an exponential backoff by a base factor of 2 for a
# maximum backoff time of 20 seconds (min(b*r^i, MAX_BACKOFF) where b is a
# random number between 0 and 1 and r is the base factor of 2). This might
# look like:
#
# seconds_to_sleep = min(1 × 2^1, 20) = min(2, 20) = 2 seconds
#
# By our 5th retry (2^5 is 32) we will be sleeping *up to* 20 seconds, depending
# on the random jitter.
#
# It handles transient errors like RequestTimeout and ConnectionError, as well
# as Service-side errors like Throttling, SlowDown, and LimitExceeded.
retry_config = Config(retries={"max_attempts": 5, "mode": "standard"})

self.dynamodb = boto3.resource("dynamodb", region_name=dynamodb_region, config=retry_config)
self.client = boto3.client("dynamodb", region_name=dynamodb_region, config=retry_config)
self.name = name
self.dynamodb_region = dynamodb_region
self.table = self.dynamodb.Table(name)
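
A rough numeric illustration of the standard-mode formula described in the __init__ comment above (this mimics botocore's documented behavior; it is not botocore's actual code):

import random

def standard_mode_sleep(attempt: int, base: int = 2, max_backoff: float = 20.0) -> float:
    b = random.random()  # random jitter in [0, 1)
    return min(b * base**attempt, max_backoff)

# Ceilings per attempt: 2, 4, 8, 16, then clamped at 20 seconds from attempt 5 on.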
Expand All @@ -63,11 +83,11 @@ def build_key(self, type, iden) -> str:

def restore(self, keys, read_json: bool = False) -> dict:
"""
Fetch all under the same parition key(s).
Fetch all under the same partition key(s).
ret: <dict of key to states>
"""
# format of the keys always passed here is
# job_state job_name --> high level info about the job: enabled, run_nums
# job_run_state job_run_name --> high level info about the job run
first_items = self._get_first_partitions(keys)
remaining_items = self._get_remaining_partitions(first_items, read_json)
@@ -83,12 +103,22 @@ def chunk_keys(self, keys: Sequence[T]) -> List[Sequence[T]]:
cand_keys_chunks.append(keys[i : min(len(keys), i + 100)])
return cand_keys_chunks
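
Assuming the loop header elided just above this hunk steps i by 100, chunk_keys windows the key list to respect BatchGetItem's 100-keys-per-request limit. For example:

keys = [store.build_key("job_state", n) for n in range(250)]
chunks = store.chunk_keys(keys)
assert [len(c) for c in chunks] == [100, 100, 50]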

def _calculate_backoff_delay(self, attempt: int) -> int:
# Clamp attempt to 1 to avoid negative or zero exponent
safe_attempt = max(attempt, 1)
base_delay_seconds = 1
max_delay_seconds = 10
delay: int = min(base_delay_seconds * (2 ** (safe_attempt - 1)), max_delay_seconds)
return delay
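
The resulting schedule, which the parametrized test in the accompanying test file pins down, doubles from 1 second and caps at 10:

# attempt:     1  2  3  4  5   6   7
# delay (s):   1  2  4  8  10  10  10
assert [min(2 ** (a - 1), 10) for a in range(1, 8)] == [1, 2, 4, 8, 10, 10, 10]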

def _get_items(self, table_keys: list) -> object:
items = []
# let's avoid potentially mutating our input :)
cand_keys_list = copy.copy(table_keys)
attempts_to_retrieve_keys = 0
while len(cand_keys_list) != 0:
attempts = 0

# TODO: TRON-2363 - We should refactor this to not consume attempts when we are still making progress
while len(cand_keys_list) != 0 and attempts < MAX_UNPROCESSED_KEYS_RETRIES:
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
responses = [
executor.submit(
Expand All @@ -106,20 +136,35 @@ def _get_items(self, table_keys: list) -> object:
cand_keys_list = []
for resp in concurrent.futures.as_completed(responses):
try:
items.extend(resp.result()["Responses"][self.name])
# add any potential unprocessed keys to the thread pool
if resp.result()["UnprocessedKeys"].get(self.name) and attempts_to_retrieve_keys < MAX_ATTEMPTS:
cand_keys_list.extend(resp.result()["UnprocessedKeys"][self.name]["Keys"])
elif attempts_to_retrieve_keys >= MAX_ATTEMPTS:
failed_keys = resp.result()["UnprocessedKeys"][self.name]["Keys"]
error = Exception(
f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{failed_keys}\n from dynamodb\n{resp.result()}"
)
raise error
except Exception as e:
result = resp.result()
items.extend(result.get("Responses", {}).get(self.name, []))

# If DynamoDB returns unprocessed keys, we need to collect them and retry
unprocessed_keys = result.get("UnprocessedKeys", {}).get(self.name, {}).get("Keys", [])
if unprocessed_keys:
cand_keys_list.extend(unprocessed_keys)
except botocore.exceptions.ClientError as e:
log.exception(f"ClientError during batch_get_item: {e.response}")
raise
except Exception:
log.exception("Encountered issues retrieving data from DynamoDB")
raise e
attempts_to_retrieve_keys += 1
raise
if cand_keys_list:
# We use _calculate_backoff_delay to get a delay that increases exponentially
# with each retry. These retry attempts are distinct from the boto3 retry_config
# and are used specifically to handle unprocessed keys.
attempts += 1
delay = self._calculate_backoff_delay(attempts)
log.warning(
f"Attempt {attempts}/{MAX_UNPROCESSED_KEYS_RETRIES} - "
f"Retrying {len(cand_keys_list)} unprocessed keys after {delay}s delay."
)
time.sleep(delay)
if cand_keys_list:
msg = f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{cand_keys_list}\n from dynamodb after {MAX_UNPROCESSED_KEYS_RETRIES} retries."
log.error(msg)

raise KeyError(msg)
return items

def _get_first_partitions(self, keys: list):
@@ -291,12 +336,17 @@ def _save_loop(self):
def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None:
"""
Partition the item and write up to MAX_TRANSACT_WRITE_ITEMS
partitions atomically. Retry up to 3 times on failure.
partitions atomically using TransactWriteItems.
The function examines the size of pickled_val and json_val,
splitting them into multiple segments based on OBJECT_SIZE,
storing each segment under the same partition key.
Examine the size of `pickled_val` and `json_val`, and
splice them into different parts based on `OBJECT_SIZE`
with different sort keys, and save them under the same
partition key built.
It relies on the boto3/botocore retry_config to handle
certain errors (e.g. throttling). If an error is not
addressed by boto3's internal logic, the transaction fails
and raises an exception. It is the caller's responsibility
to implement further retries.
"""
start = time.time()

@@ -337,25 +387,21 @@ def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None:
"N": str(num_json_val_partitions),
}

count = 0
items.append(item)

while len(items) == MAX_TRANSACT_WRITE_ITEMS or index == max_partitions - 1:
# We want to write the items when we've either reached the max number of items
# for a transaction, or when we're done processing all partitions
if len(items) == MAX_TRANSACT_WRITE_ITEMS or index == max_partitions - 1:
try:
self.client.transact_write_items(TransactItems=items)
items = []
break # exit the while loop on successful writing
except Exception as e:
count += 1
if count > 3:
timer(
name="tron.dynamodb.setitem",
delta=time.time() - start,
)
log.error(f"Failed to save partition for key: {key}, error: {repr(e)}")
raise e
else:
log.warning(f"Got error while saving {key}, trying again: {repr(e)}")
except Exception:
timer(
name="tron.dynamodb.setitem",
delta=time.time() - start,
)
log.exception(f"Failed to save partition for key: {key}")
raise
timer(
name="tron.dynamodb.setitem",
delta=time.time() - start,
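
The partitioning the docstring describes amounts to slicing each serialized blob into OBJECT_SIZE-byte segments that share one partition key; a minimal sketch (partition_blob is a hypothetical helper, not a method on the store):

OBJECT_SIZE = 200_000  # mirrors the constant defined at module top

def partition_blob(blob: bytes, size: int = OBJECT_SIZE) -> list:
    # One DynamoDB item per segment, all stored under the same partition key.
    return [blob[i : i + size] for i in range(0, max(len(blob), 1), size)]

# A 450 KB value yields 3 partitions, written in batches of up to
# MAX_TRANSACT_WRITE_ITEMS (100) items per transact_write_items call.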
