diff --git a/Pipfile b/Pipfile index 1a12aec..2e71f17 100644 --- a/Pipfile +++ b/Pipfile @@ -8,6 +8,8 @@ verify_ssl = true [packages] chaostoolkit = "*" chaostoolkit-aws = "*" +boto3 = "*" +ratelimiter = "*" [requires] python_version = "3.6" diff --git a/drivers/the_publisher.py b/drivers/the_publisher.py index eff31cc..50173b4 100755 --- a/drivers/the_publisher.py +++ b/drivers/the_publisher.py @@ -1,50 +1,69 @@ #!/usr/bin/env python -import boto3 -from time import time -import sys -from datetime import datetime as dt import json -from random import random, shuffle +from datetime import datetime as dt from itertools import product +from random import random, shuffle from string import ascii_uppercase +from time import time + +import boto3 +from ratelimiter import RateLimiter from aws_resource_names import S3_BUCKET_NAME -# Method 2: Client.put_object() -s3 = boto3.client('s3') +s3 = boto3.client("s3") run_flag = True obj_count = 0 iter_obj_count = 0 -obj_limit = 2 # number of objects per second to put -err_rate = 0.01 # what percentage of messages should be flawed, 0.1 == 10% of messages will have syntax errors -start_time = time () -last_print_time = time () -symbols = [''.join(i) for i in product (ascii_uppercase, repeat=4)] -shuffle (symbols) -print ("Publishing messages for {} symbols".format (len (symbols))) +obj_limit = 2 # number of objects per second to put +err_rate = 0.01 # what percentage of messages should be flawed, 0.1 == 10% of messages will have syntax errors +start_time = time() +last_print_time = time() +symbols = ["".join(i) for i in product(ascii_uppercase, repeat=4)] +shuffle(symbols) +print("Publishing messages for {} symbols".format(len(symbols))) message_id = 0 +rate_limiter = RateLimiter(max_calls=obj_limit) + try: while run_flag: - if iter_obj_count <= obj_limit * 10: - symbol = symbols.pop () + with rate_limiter: + symbol = symbols.pop() message_id += 1 - obj_name = 'data_object_msg-{}.json'.format (message_id) - data = {'symbol': symbol, 'messageId': message_id, 'value': 10, 'objectName': obj_name, 'submissionDate': dt.now().strftime ('%d-%b-%Y %H:%M:%S'), 'author': 'the_publisher.py', 'version': 1.1} - body = json.dumps (data) - if random () < err_rate: - body = body.replace ('"','',1) # if we should inject an erroneous message send malformed JSON with a syntax error - s3.put_object(Body=body, Bucket=S3_BUCKET_NAME, Key='input/{}'.format (obj_name)) + obj_name = "data_object_msg-{}.json".format(message_id) + data = { + "symbol": symbol, + "messageId": message_id, + "value": 10, + "objectName": obj_name, + "submissionDate": dt.now().strftime("%d-%b-%Y %H:%M:%S"), + "author": "the_publisher.py", + "version": 1.1, + } + body = json.dumps(data) + if random() < err_rate: + body = body.replace( + '"', "", 1 + ) # if we should inject an erroneous message send malformed JSON with a syntax error + s3.put_object( + Body=body, Bucket=S3_BUCKET_NAME, Key="input/{}".format(obj_name) + ) obj_count += 1 iter_obj_count += 1 - if (int(time () - start_time) % 10) == 0 and (time () - last_print_time) > 10: - print ("{}: Pushed {} objects for a total of {} objects".format (dt.now ().strftime ('%Y-%b-%d %H:%M:%S'), iter_obj_count, obj_count)) - last_print_time = time () + if iter_obj_count % (obj_limit * 10) == 0: + print( + "{}: Pushed {} objects for a total of {} objects".format( + dt.now().strftime("%Y-%b-%d %H:%M:%S"), iter_obj_count, obj_count + ) + ) + last_print_time = time() iter_obj_count = 0 except KeyboardInterrupt: - print ("Pushed a total of {} objects; exiting...".format (obj_count)) - sys.exit (0) + print("Pushed a total of {} objects; exiting...".format(obj_count)) +except IndexError: + print("Pushed all symbols, exiting.") diff --git a/drivers/the_subscriber.py b/drivers/the_subscriber.py index 0417ad8..9830ebb 100755 --- a/drivers/the_subscriber.py +++ b/drivers/the_subscriber.py @@ -1,38 +1,46 @@ #!/usr/bin/env python -import boto3 -from time import time -from datetime import datetime as dt -import json import sys +from datetime import datetime as dt +from time import time + +import boto3 +from ratelimiter import RateLimiter from aws_resource_names import SQS_QUEUE_NAME -sqs = boto3.client('sqs') -queue_url = sqs.get_queue_url (QueueName=SQS_QUEUE_NAME) -queue_url = queue_url['QueueUrl'] +sqs = boto3.client("sqs") +queue_url = sqs.get_queue_url(QueueName=SQS_QUEUE_NAME) +queue_url = queue_url["QueueUrl"] run_flag = True obj_count = 0 iter_obj_count = 0 -obj_limit = 2 # number of objects per second to get -start_time = time () -last_print_time = time () +obj_limit = 2 # number of objects per second to get +start_time = time() +rate_limiter = RateLimiter(max_calls=obj_limit) while run_flag: try: - if iter_obj_count <= obj_limit * 10: - resp = sqs.receive_message (QueueUrl=queue_url, WaitTimeSeconds=1, MaxNumberOfMessages=obj_limit) - if 'Messages' in resp: - for msg in resp['Messages']: - sqs.delete_message (QueueUrl=queue_url, ReceiptHandle=msg['ReceiptHandle']) + with rate_limiter: + resp = sqs.receive_message( + QueueUrl=queue_url, WaitTimeSeconds=1, MaxNumberOfMessages=obj_limit + ) + if "Messages" in resp: + for msg in resp["Messages"]: + sqs.delete_message( + QueueUrl=queue_url, ReceiptHandle=msg["ReceiptHandle"] + ) obj_count += 1 iter_obj_count += 1 - if (int(time () - start_time) % 10) == 0 and (time () - last_print_time) > 10: - print ("{}: Retrieved {} objects for a total of {} objects".format (dt.now ().strftime ('%Y-%b-%d %H:%M:%S'), iter_obj_count, obj_count)) - last_print_time = time () + if obj_count % (obj_limit * 10) == 0: + print( + "{}: Retrieved {} objects for a total of {} objects".format( + dt.now().strftime("%Y-%b-%d %H:%M:%S"), iter_obj_count, obj_count + ) + ) iter_obj_count = 0 except KeyboardInterrupt: - print ("Retrieved a total of {} objects; exiting...".format (obj_count)) - sys.exit (0) + print("Retrieved a total of {} objects; exiting...".format(obj_count)) + sys.exit(0)