class TokenBucket:
    """Token bucket for rate limiting with configurable fill rate.

    TokenBucket provides a collection of arbitrary tokens while managing issuance
    and refilling over time. This is controlled by a fill rate that can be variably
    adjusted. When tokens aren't available, the bucket will enforce a delay before
    attempting to reacquire tokens until one is available or the defined timeout is
    reached.
    """

    MIN_FILL_RATE = 0.5  # Minimum allowed fill rate (0.5 tokens/second)
    MIN_CAPACITY = 1.0  # Minimum allowed bucket capacity (1.0 tokens)
    DEFAULT_TIMEOUT = 30.0  # Default timeout for token acquisition (30.0 seconds)

    def __init__(
        self,
        *,
        curr_capacity: float = MIN_CAPACITY,
        fill_rate: float = MIN_FILL_RATE,
        timeout: float = DEFAULT_TIMEOUT,
    ):
        """Initialize a new TokenBucket.

        :param curr_capacity: Initial number of tokens in the bucket. Values below
            MIN_CAPACITY are raised to MIN_CAPACITY. The initial value also becomes
            the bucket's maximum capacity.
        :param fill_rate: Rate at which tokens are added to the bucket (tokens/second).
            Values below MIN_FILL_RATE are raised to MIN_FILL_RATE.
        :param timeout: Maximum time to wait for token acquisition before
            raising TimeoutError.
        """
        self._curr_capacity: float = max(curr_capacity, self.MIN_CAPACITY)
        self._max_capacity: float = self._curr_capacity
        self._fill_rate: float = max(fill_rate, self.MIN_FILL_RATE)
        self._timeout = timeout
        self._last_timestamp: float = time.monotonic()
        self._lock = asyncio.Lock()

    async def acquire(self, amount: float) -> None:
        """Acquire tokens from the bucket.

        If sufficient tokens are available, they are immediately consumed and the
        method returns. If insufficient tokens are available, the method will wait
        until enough tokens have been refilled or the timeout is reached. Each wait
        is limited to the time remaining before the timeout, so the method does not
        sleep past its deadline.

        :param amount: Number of tokens to acquire.

        :raises TimeoutError: Acquisition took longer than the configured timeout.
        """
        start_time = time.monotonic()
        while True:
            async with self._lock:
                self._refill()
                if self._curr_capacity >= amount:
                    self._curr_capacity -= amount
                    return
                # Compute the deficit while still holding the lock so it is based
                # on the capacity and fill rate that were just refreshed.
                wait_time = (amount - self._curr_capacity) / self._fill_rate

            elapsed = time.monotonic() - start_time
            if elapsed > self._timeout:
                # This will be caught in retry strategy and used as part of the retry count
                raise TimeoutError(
                    f"Failed to acquire {amount} tokens within {self._timeout}s"
                )
            # Never sleep beyond the deadline: a large deficit or a slow fill rate
            # would otherwise delay the TimeoutError well past the timeout.
            await asyncio.sleep(min(wait_time, self._timeout - elapsed))

    def _refill(self) -> None:
        """Add the tokens accrued since the last refill, capped at max capacity.

        ``acquire`` and ``update_rate`` call this while holding ``self._lock``.
        """
        curr_time = time.monotonic()
        elapsed = curr_time - self._last_timestamp
        refill_amount = elapsed * self._fill_rate
        self._curr_capacity = min(
            self._max_capacity, self._curr_capacity + refill_amount
        )
        self._last_timestamp = curr_time

    async def update_rate(self, rate: float) -> None:
        """Update the bucket's fill rate, maximum capacity and current capacity (if
        it's greater than maximum capacity).

        :param rate: New fill rate (tokens/second). It won't be less than MIN_FILL_RATE.
            The maximum capacity becomes ``rate`` but never less than MIN_CAPACITY.
            Current capacity will be reduced if it exceeds the new maximum capacity.
        """
        async with self._lock:
            # Credit tokens earned at the old rate before switching to the new one.
            self._refill()
            self._fill_rate = max(rate, self.MIN_FILL_RATE)
            self._max_capacity = max(rate, self.MIN_CAPACITY)
            self._curr_capacity = min(self._curr_capacity, self._max_capacity)

    @property
    def current_capacity(self) -> float:
        """Get the current number of tokens in the bucket.

        :return: The current token count as of the last refill operation.
        """
        return self._curr_capacity

    @property
    def max_capacity(self) -> float:
        """Get the maximum capacity of the bucket.

        :return: The maximum number of tokens the bucket can hold.
        """
        return self._max_capacity

    @property
    def fill_rate(self) -> float:
        """Get the current fill rate of the bucket.

        :return: The rate at which tokens are added to the bucket (tokens/second).
        """
        return self._fill_rate
class TestTokenBucket:
    """Unit tests for TokenBucket construction, acquisition, refill and rate updates."""

    @pytest.mark.asyncio
    async def test_initial_state(self):
        """A default bucket starts at the class minimums."""
        token_bucket = TokenBucket()
        assert token_bucket.current_capacity == token_bucket.MIN_CAPACITY
        assert token_bucket.max_capacity == token_bucket.MIN_CAPACITY
        assert token_bucket.fill_rate == token_bucket.MIN_FILL_RATE

    @pytest.mark.asyncio
    async def test_acquire_succeeds_immediately_within_capacity(self):
        """Acquiring an available token must not sleep."""
        token_bucket = TokenBucket()

        with patch("asyncio.sleep") as mock_sleep:
            await token_bucket.acquire(1)
            mock_sleep.assert_not_called()

        assert token_bucket.current_capacity == 0

    @pytest.mark.asyncio
    async def test_acquire_waits_when_capacity_insufficient(self):
        """An empty bucket sleeps until enough tokens have been refilled."""
        # Fake clock advanced only by the fake sleep. Without it the mocked sleep
        # returns instantly and acquire() spins on the real clock until the
        # bucket refills.
        now = [0.0]

        async def fake_sleep(delay: float) -> None:
            now[0] += delay

        with patch("time.monotonic", side_effect=lambda: now[0]):
            with patch("asyncio.sleep", side_effect=fake_sleep) as mock_sleep:
                token_bucket = TokenBucket(fill_rate=1.0)
                await token_bucket.acquire(1)
                mock_sleep.assert_not_called()

                await token_bucket.acquire(1)
                mock_sleep.assert_called()

        assert token_bucket.current_capacity == 0.0

    @pytest.mark.asyncio
    async def test_update_bucket_updates_rate(self):
        """Raising the rate raises fill rate and max capacity but not the tokens held."""
        token_bucket = TokenBucket()

        await token_bucket.update_rate(5.0)
        assert token_bucket.fill_rate == 5.0
        assert token_bucket.max_capacity == 5.0
        assert token_bucket.current_capacity == 1.0

    @pytest.mark.asyncio
    async def test_rate_can_never_be_zero(self):
        """A zero rate is not stored as the fill rate."""
        token_bucket = TokenBucket()
        await token_bucket.update_rate(0.0)

        assert token_bucket.fill_rate != 0.0

    @pytest.mark.asyncio
    async def test_refill_caps_at_max_capacity(self):
        """A refill right after a rate update leaves the capacity at about 1.0."""
        token_bucket = TokenBucket()
        # Max and current capacity of the bucket is set to 1.0 initially
        await token_bucket.update_rate(10.0)

        async with token_bucket._lock:  # type: ignore
            token_bucket._refill()  # type: ignore

        assert round(token_bucket.current_capacity, 1) == 1.0

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "actions,expected_capacity",
        [
            ([("acquire", 1)], 0.0),
            ([("acquire", 1), ("update", 4)], 1.0),
            ([("acquire", 1), ("update", 4), ("acquire", 1)], 3.0),
            ([("acquire", 1), ("update", 4), ("acquire", 1), ("acquire", 1)], 3.0),
        ],
    )
    async def test_multiple_refills_over_time(
        self, actions: list[tuple[str, int]], expected_capacity: float
    ):
        """Replay acquire/update sequences against a scripted monotonic clock."""
        # One value is consumed per time.monotonic() call: one for construction,
        # two per successful acquire (start time + refill), one per update (refill).
        time_values = [0.0, 1.0, 1.0, 1.5, 1.5, 4.0, 4.0, 5.0]

        with patch("time.monotonic", side_effect=time_values):
            token_bucket = TokenBucket(curr_capacity=0, fill_rate=2.0)

            for action, value in actions:
                if action == "acquire":
                    await token_bucket.acquire(value)
                elif action == "update":
                    await token_bucket.update_rate(value)

        assert token_bucket.current_capacity == expected_capacity