Skip to content

Commit b45c6c8

Browse files
committed
implement sts_token_buffer_time attribute for transport_options to update token earlier than expiration time
1 parent 4c64cdd commit b45c6c8

File tree

1 file changed

+37
-29
lines changed

1 file changed

+37
-29
lines changed

kombu/transport/SQS.py

+37-29
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@
7676
},
7777
}
7878
'sts_role_arn': 'arn:aws:iam::<xxx>:role/STSTest', # optional
79-
'sts_token_timeout': 900 # optional
79+
'sts_token_timeout': 900, # optional
80+
'sts_token_buffer_time': 60 # optional
8081
}
8182
8283
Note that FIFO and standard queues must be named accordingly (the name of
@@ -91,6 +92,9 @@
9192
sts_token_timeout. sts_role_arn is the assumed IAM role ARN we are trying
9293
to access with. sts_token_timeout is the token timeout, defaults (and minimum)
9394
to 900 seconds. After the mentioned period, a new token will be created.
95+
sts_token_buffer_time (seconds) is the time by which you want to refresh your token
96+
earlier than its actual expiration time, defaults to 0 (no time buffer will be added),
97+
should be less than sts_token_timeout.
9498
9599
96100
@@ -136,7 +140,7 @@
136140
import socket
137141
import string
138142
import uuid
139-
from datetime import datetime
143+
from datetime import datetime, timedelta
140144
from queue import Empty
141145

142146
from botocore.client import Config
@@ -765,34 +769,38 @@ def sqs(self, queue=None):
765769
)
766770
return c
767771

772+
def _refresh_sqs_client(self, queue, q):
773+
sts_creds = self.generate_sts_session_token_with_buffer(
774+
self.transport_options.get('sts_role_arn'),
775+
self.transport_options.get('sts_token_timeout', 900),
776+
self.transport_options.get('sts_token_buffer_time', 0),
777+
)
778+
self.sts_expiration = sts_creds['Expiration']
779+
self._predefined_queue_clients[queue] = self.new_sqs_client(
780+
region=q.get('region', self.region),
781+
access_key_id=sts_creds['AccessKeyId'],
782+
secret_access_key=sts_creds['SecretAccessKey'],
783+
session_token=sts_creds['SessionToken'],
784+
)
785+
return self._predefined_queue_clients[queue]
786+
768787
def _handle_sts_session(self, queue, q):
769-
if not hasattr(self, 'sts_expiration'): # STS token - token init
770-
sts_creds = self.generate_sts_session_token(
771-
self.transport_options.get('sts_role_arn'),
772-
self.transport_options.get('sts_token_timeout', 900))
773-
self.sts_expiration = sts_creds['Expiration']
774-
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
775-
region=q.get('region', self.region),
776-
access_key_id=sts_creds['AccessKeyId'],
777-
secret_access_key=sts_creds['SecretAccessKey'],
778-
session_token=sts_creds['SessionToken'],
779-
)
780-
return c
781-
# STS token - refresh if expired
782-
elif self.sts_expiration.replace(tzinfo=None) < datetime.utcnow():
783-
sts_creds = self.generate_sts_session_token(
784-
self.transport_options.get('sts_role_arn'),
785-
self.transport_options.get('sts_token_timeout', 900))
786-
self.sts_expiration = sts_creds['Expiration']
787-
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
788-
region=q.get('region', self.region),
789-
access_key_id=sts_creds['AccessKeyId'],
790-
secret_access_key=sts_creds['SecretAccessKey'],
791-
session_token=sts_creds['SessionToken'],
792-
)
793-
return c
794-
else: # STS token - ruse existing
795-
return self._predefined_queue_clients[queue]
788+
"""
789+
Refreshes the SQS client with a new token on STS token initialization
790+
or expiration. Otherwise, using cached client.
791+
"""
792+
if (
793+
not hasattr(self, 'sts_expiration') or
794+
self.sts_expiration.replace(tzinfo=None) < datetime.utcnow()
795+
):
796+
return self._refresh_sqs_client(queue, q)
797+
return self._predefined_queue_clients[queue]
798+
799+
def generate_sts_session_token_with_buffer(self, role_arn, token_expiry_seconds, token_buffer_seconds=0):
800+
credentials = self.generate_sts_session_token(role_arn, token_expiry_seconds)
801+
if token_buffer_seconds and token_buffer_seconds < token_expiry_seconds:
802+
credentials["Expiration"] -= timedelta(seconds=token_buffer_seconds)
803+
return credentials
796804

797805
def generate_sts_session_token(self, role_arn, token_expiry_seconds):
798806
sts_client = boto3.client('sts')

0 commit comments

Comments
 (0)