Skip to content

Commit

Permalink
implement sts_token_buffer_time attribute for transport_options to up…
Browse files Browse the repository at this point in the history
…date token earlier than expiration time
  • Loading branch information
eisichenko committed Jan 14, 2025
1 parent a0175b0 commit e00da33
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 4 deletions.
20 changes: 16 additions & 4 deletions kombu/transport/SQS.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@
},
}
'sts_role_arn': 'arn:aws:iam::<xxx>:role/STSTest', # optional
'sts_token_timeout': 900 # optional
'sts_token_timeout': 900, # optional
'sts_token_buffer_time': 0 # optional
}
Note that FIFO and standard queues must be named accordingly (the name of
Expand All @@ -91,6 +92,9 @@
sts_token_timeout. sts_role_arn is the assumed IAM role ARN we are trying
to access with. sts_token_timeout is the token timeout, defaults (and minimum)
to 900 seconds. After the mentioned period, a new token will be created.
sts_token_buffer_time (seconds) is the time by which you want to refresh your token
earlier than its actual expiration time, defaults to 0 (no time buffer will be added),
should be less than sts_token_timeout.
Expand Down Expand Up @@ -136,7 +140,7 @@
import socket
import string
import uuid
from datetime import datetime
from datetime import datetime, timedelta
from queue import Empty

from botocore.client import Config
Expand Down Expand Up @@ -777,10 +781,18 @@ def _handle_sts_session(self, queue, q):
return self._new_predefined_queue_client_with_sts_session(queue, region)
return self._predefined_queue_clients[queue]

def generate_sts_session_token_with_buffer(self, role_arn, token_expiry_seconds, token_buffer_seconds=0):
credentials = self.generate_sts_session_token(role_arn, token_expiry_seconds)
if token_buffer_seconds and token_buffer_seconds < token_expiry_seconds:
credentials["Expiration"] -= timedelta(seconds=token_buffer_seconds)
return credentials

def _new_predefined_queue_client_with_sts_session(self, queue, region):
sts_creds = self.generate_sts_session_token(
sts_creds = self.generate_sts_session_token_with_buffer(
self.transport_options.get('sts_role_arn'),
self.transport_options.get('sts_token_timeout', 900))
self.transport_options.get('sts_token_timeout', 900),
self.transport_options.get('sts_token_buffer_time', 0),
)
self.sts_expiration = sts_creds['Expiration']
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
region=region,
Expand Down
75 changes: 75 additions & 0 deletions t/unit/transport/test_SQS.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,43 @@ def test_sts_new_session(self):
# Assert
mock_generate_sts_session_token.assert_called_once()

def test_sts_new_session_with_buffer_time(self):
# Arrange
sts_token_timeout = 900
sts_token_buffer_time = 60
connection = Connection(transport=SQS.Transport, transport_options={
'predefined_queues': example_predefined_queues,
'sts_role_arn': 'test::arn',
'sts_token_timeout': sts_token_timeout,
'sts_token_buffer_time': sts_token_buffer_time,
})
channel = connection.channel()
sqs = SQS_Channel_sqs.__get__(channel, SQS.Channel)
queue_name = 'queue-1'

mock_generate_sts_session_token = Mock()
mock_new_sqs_client = Mock()
channel.new_sqs_client = mock_new_sqs_client

expiration_time = datetime.utcnow() + timedelta(seconds=sts_token_timeout)

mock_generate_sts_session_token.side_effect = [
{
'Expiration': expiration_time,
'SessionToken': 123,
'AccessKeyId': 123,
'SecretAccessKey': 123
}
]
channel.generate_sts_session_token = mock_generate_sts_session_token

# Act
sqs(queue=queue_name)

# Assert
mock_generate_sts_session_token.assert_called_once()
assert channel.sts_expiration == expiration_time - timedelta(seconds=sts_token_buffer_time)

def test_sts_session_expired(self):
# Arrange
connection = Connection(transport=SQS.Transport, transport_options={
Expand Down Expand Up @@ -966,6 +1003,44 @@ def test_sts_session_expired(self):
# Assert
mock_generate_sts_session_token.assert_called_once()

def test_sts_session_expired_with_buffer_time(self):
# Arrange
sts_token_timeout = 900
sts_token_buffer_time = 60
connection = Connection(transport=SQS.Transport, transport_options={
'predefined_queues': example_predefined_queues,
'sts_role_arn': 'test::arn',
'sts_token_timeout': sts_token_timeout,
'sts_token_buffer_time': sts_token_buffer_time,
})
channel = connection.channel()
sqs = SQS_Channel_sqs.__get__(channel, SQS.Channel)
channel.sts_expiration = datetime.utcnow() - timedelta(days=1)
queue_name = 'queue-1'

mock_generate_sts_session_token = Mock()
mock_new_sqs_client = Mock()
channel.new_sqs_client = mock_new_sqs_client

expiration_time = datetime.utcnow() + timedelta(seconds=sts_token_timeout)

mock_generate_sts_session_token.side_effect = [
{
'Expiration': expiration_time,
'SessionToken': 123,
'AccessKeyId': 123,
'SecretAccessKey': 123
}
]
channel.generate_sts_session_token = mock_generate_sts_session_token

# Act
sqs(queue=queue_name)

# Assert
mock_generate_sts_session_token.assert_called_once()
assert channel.sts_expiration == expiration_time - timedelta(seconds=sts_token_buffer_time)

def test_sts_session_not_expired(self):
# Arrange
connection = Connection(transport=SQS.Transport, transport_options={
Expand Down

0 comments on commit e00da33

Please sign in to comment.