|
76 | 76 | },
|
77 | 77 | }
|
78 | 78 | 'sts_role_arn': 'arn:aws:iam::<xxx>:role/STSTest', # optional
|
79 |
| - 'sts_token_timeout': 900 # optional |
| 79 | + 'sts_token_timeout': 900, # optional |
| 80 | + 'sts_token_buffer_time': 60 # optional |
80 | 81 | }
|
81 | 82 |
|
82 | 83 | Note that FIFO and standard queues must be named accordingly (the name of
|
|
91 | 92 | sts_token_timeout. sts_role_arn is the assumed IAM role ARN we are trying
|
92 | 93 | to access with. sts_token_timeout is the token timeout, defaults (and minimum)
|
93 | 94 | to 900 seconds. After the mentioned period, a new token will be created.
|
| 95 | +sts_token_buffer_time (seconds) is the time by which you want to refresh your token |
| 96 | +earlier than its actual expiration time, defaults to 0 (no time buffer will be added), |
| 97 | +should be less than sts_token_timeout. |
94 | 98 |
|
95 | 99 |
|
96 | 100 |
|
|
136 | 140 | import socket
|
137 | 141 | import string
|
138 | 142 | import uuid
|
139 |
| -from datetime import datetime |
| 143 | +from datetime import datetime, timedelta |
140 | 144 | from queue import Empty
|
141 | 145 |
|
142 | 146 | from botocore.client import Config
|
@@ -765,34 +769,38 @@ def sqs(self, queue=None):
|
765 | 769 | )
|
766 | 770 | return c
|
767 | 771 |
|
| 772 | + def _refresh_sqs_client(self, queue, q): |
| 773 | + sts_creds = self.generate_sts_session_token_with_buffer( |
| 774 | + self.transport_options.get('sts_role_arn'), |
| 775 | + self.transport_options.get('sts_token_timeout', 900), |
| 776 | + self.transport_options.get('sts_token_buffer_time', 0), |
| 777 | + ) |
| 778 | + self.sts_expiration = sts_creds['Expiration'] |
| 779 | + self._predefined_queue_clients[queue] = self.new_sqs_client( |
| 780 | + region=q.get('region', self.region), |
| 781 | + access_key_id=sts_creds['AccessKeyId'], |
| 782 | + secret_access_key=sts_creds['SecretAccessKey'], |
| 783 | + session_token=sts_creds['SessionToken'], |
| 784 | + ) |
| 785 | + return self._predefined_queue_clients[queue] |
| 786 | + |
768 | 787 | def _handle_sts_session(self, queue, q):
|
769 |
| - if not hasattr(self, 'sts_expiration'): # STS token - token init |
770 |
| - sts_creds = self.generate_sts_session_token( |
771 |
| - self.transport_options.get('sts_role_arn'), |
772 |
| - self.transport_options.get('sts_token_timeout', 900)) |
773 |
| - self.sts_expiration = sts_creds['Expiration'] |
774 |
| - c = self._predefined_queue_clients[queue] = self.new_sqs_client( |
775 |
| - region=q.get('region', self.region), |
776 |
| - access_key_id=sts_creds['AccessKeyId'], |
777 |
| - secret_access_key=sts_creds['SecretAccessKey'], |
778 |
| - session_token=sts_creds['SessionToken'], |
779 |
| - ) |
780 |
| - return c |
781 |
| - # STS token - refresh if expired |
782 |
| - elif self.sts_expiration.replace(tzinfo=None) < datetime.utcnow(): |
783 |
| - sts_creds = self.generate_sts_session_token( |
784 |
| - self.transport_options.get('sts_role_arn'), |
785 |
| - self.transport_options.get('sts_token_timeout', 900)) |
786 |
| - self.sts_expiration = sts_creds['Expiration'] |
787 |
| - c = self._predefined_queue_clients[queue] = self.new_sqs_client( |
788 |
| - region=q.get('region', self.region), |
789 |
| - access_key_id=sts_creds['AccessKeyId'], |
790 |
| - secret_access_key=sts_creds['SecretAccessKey'], |
791 |
| - session_token=sts_creds['SessionToken'], |
792 |
| - ) |
793 |
| - return c |
794 |
| - else: # STS token - ruse existing |
795 |
| - return self._predefined_queue_clients[queue] |
| 788 | + """ |
| 789 | + Refreshes the SQS client with a new token on STS token initialization |
| 790 | + or expiration. Otherwise, using cached client. |
| 791 | + """ |
| 792 | + if ( |
| 793 | + not hasattr(self, 'sts_expiration') or |
| 794 | + self.sts_expiration.replace(tzinfo=None) < datetime.utcnow() |
| 795 | + ): |
| 796 | + return self._refresh_sqs_client(queue, q) |
| 797 | + return self._predefined_queue_clients[queue] |
| 798 | + |
| 799 | + def generate_sts_session_token_with_buffer(self, role_arn, token_expiry_seconds, token_buffer_seconds=0): |
| 800 | + credentials = self.generate_sts_session_token(role_arn, token_expiry_seconds) |
| 801 | + if token_buffer_seconds and token_buffer_seconds < token_expiry_seconds: |
| 802 | + credentials["Expiration"] -= timedelta(seconds=token_buffer_seconds) |
| 803 | + return credentials |
796 | 804 |
|
797 | 805 | def generate_sts_session_token(self, role_arn, token_expiry_seconds):
|
798 | 806 | sts_client = boto3.client('sts')
|
|
0 commit comments