From a46a6e088207475506b714b47c3475daf221b6c8 Mon Sep 17 00:00:00 2001 From: ubaskota <19787410+ubaskota@users.noreply.github.com> Date: Mon, 22 Dec 2025 12:21:50 -0500 Subject: [PATCH 1/5] Implement a token bucket used for adaptive retries --- .../smithy-core/src/smithy_core/retries.py | 73 +++++++++++ .../smithy-core/tests/unit/test_retries.py | 113 ++++++++++++++++++ 2 files changed, 186 insertions(+) diff --git a/packages/smithy-core/src/smithy_core/retries.py b/packages/smithy-core/src/smithy_core/retries.py index 67c63c1a..1d7ef733 100644 --- a/packages/smithy-core/src/smithy_core/retries.py +++ b/packages/smithy-core/src/smithy_core/retries.py @@ -1,7 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import asyncio import random import threading +import time from collections.abc import Callable from dataclasses import dataclass from enum import Enum @@ -480,3 +482,74 @@ def record_success(self, *, token: retries_interface.RetryToken) -> None: def __deepcopy__(self, memo: Any) -> "StandardRetryStrategy": return self + + +class TokenBucket: + MIN_FILL_RATE = 0.5 + MIN_CAPACITY = 1.0 + + def __init__( + self, + *, + curr_capacity: float | None = None, + max_capacity: float | None = None, + fill_rate: float | None = None, + ): + self._curr_capacity: float = ( + curr_capacity if curr_capacity is not None else self.MIN_CAPACITY + ) + self._max_capacity: float = ( + max_capacity + if max_capacity is not None + else max(self._curr_capacity, self.MIN_CAPACITY) + ) + self._fill_rate: float = ( + max(self.MIN_FILL_RATE, fill_rate) + if fill_rate is not None + else self.MIN_FILL_RATE + ) + self._last_timestamp: float | None = None + self._lock = asyncio.Lock() + + async def acquire(self, amount: float) -> None: + while True: + async with self._lock: + self._refill() + if self._curr_capacity >= amount: + self._curr_capacity -= amount + return + + wait_time = (amount - self._curr_capacity) / self._fill_rate + await asyncio.sleep(wait_time) + + def _refill(self) -> None: + curr_time = time.monotonic() + if self._last_timestamp is None: + self._last_timestamp = curr_time + return + + elapsed = curr_time - self._last_timestamp + refill_amount = elapsed * self._fill_rate + self._curr_capacity = min( + self._max_capacity, self._curr_capacity + refill_amount + ) + self._last_timestamp = curr_time + + async def update_bucket(self, rate: float) -> None: + async with self._lock: + self._refill() + self._fill_rate = max(rate, self.MIN_FILL_RATE) + self._max_capacity = max(rate, self.MIN_CAPACITY) + self._curr_capacity = min(self._curr_capacity, self._max_capacity) + + @property + def current_capacity(self) -> float: + return self._curr_capacity + + @property + def max_capacity(self) -> float: + return self._max_capacity + + @property + def fill_rate(self) -> float: + return self._fill_rate diff --git a/packages/smithy-core/tests/unit/test_retries.py b/packages/smithy-core/tests/unit/test_retries.py index c36c5b75..f697b37b 100644 --- a/packages/smithy-core/tests/unit/test_retries.py +++ b/packages/smithy-core/tests/unit/test_retries.py @@ -1,5 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import asyncio +import time +from unittest.mock import AsyncMock, patch + import pytest from smithy_core.exceptions import CallError, RetryError from smithy_core.retries import ExponentialBackoffJitterType as EBJT @@ -10,6 +14,7 @@ SimpleRetryStrategy, StandardRetryQuota, StandardRetryStrategy, + TokenBucket, ) @@ -275,3 +280,111 @@ async def test_retry_strategy_resolver_rejects_invalid_type() -> None: match="retry_strategy must be RetryStrategy, RetryStrategyOptions, or None", ): await resolver.resolve_retry_strategy(retry_strategy="invalid") # type: ignore + + +class TestTokenBucket: + @pytest.mark.asyncio + async def test_initial_state(self): + token_bucket = TokenBucket() + assert token_bucket.current_capacity == token_bucket.MIN_CAPACITY + assert token_bucket.max_capacity == token_bucket.MIN_CAPACITY + assert token_bucket.fill_rate == token_bucket.MIN_FILL_RATE + + @pytest.mark.asyncio + async def test_acquire_succeeds_immediately_within_capacity(self): + token_bucket = TokenBucket() + start_time = time.monotonic() + await token_bucket.acquire(1) + elapsed = time.monotonic() - start_time + + assert elapsed < 0.001 # Should be near instant + assert token_bucket.current_capacity == 0 + + @pytest.mark.asyncio + async def test_acquire_waits_when_capacity_insufficient(self): + token_bucket = TokenBucket(curr_capacity=0) + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + + async def side_effect(delay: float): + async with token_bucket._lock: # type: ignore + token_bucket._curr_capacity = 1.0 # type: ignore + + mock_sleep.side_effect = side_effect + await token_bucket.acquire(1) + assert mock_sleep.call_count == 1 + + actual_delay = mock_sleep.call_args[0][0] + assert actual_delay == pytest.approx(2.0, abs=0.05) # type: ignore + + @pytest.mark.asyncio + async def test_multiple_refills_over_time(self): + token_bucket = TokenBucket(curr_capacity=0, max_capacity=10, fill_rate=2.0) + + time_values = iter([1.0, 1.5, 4.0]) + with patch("time.monotonic", side_effect=lambda: next(time_values)): + token_bucket._last_timestamp = 0.0 # type: ignore + + async with token_bucket._lock: # type: ignore + token_bucket._refill() # type: ignore + assert token_bucket.current_capacity == pytest.approx(2.0, abs=0.05) # type: ignore + + with patch("time.monotonic", side_effect=lambda: next(time_values)): + async with token_bucket._lock: # type: ignore + token_bucket._refill() # type: ignore + assert token_bucket.current_capacity == pytest.approx(3.0, abs=0.05) # type: ignore + + with patch("time.monotonic", side_effect=lambda: next(time_values)): + async with token_bucket._lock: # type: ignore + token_bucket._refill() # type: ignore + assert token_bucket.current_capacity == pytest.approx(8.0, abs=0.05) # type: ignore + + @pytest.mark.asyncio + async def test_update_bucket_updates_capacity(self): + token_bucket = TokenBucket() + + await token_bucket.update_bucket(5.0) + assert token_bucket.fill_rate == 5.0 + assert token_bucket.max_capacity == 5.0 + assert token_bucket.current_capacity == 1.0 + + @pytest.mark.asyncio + async def test_rate_can_never_be_zero(self): + token_bucket = TokenBucket() + await token_bucket.update_bucket(0.0) + + assert token_bucket.fill_rate != 0.0 + + @pytest.mark.asyncio + async def test_refill_caps_at_max_capacity(self): + token_bucket = TokenBucket() + # Max and current capacity of the bucket is set to 1.0 initially + await token_bucket.update_bucket(10.0) + + async with token_bucket._lock: # type: ignore + token_bucket._refill() # type: ignore + + assert token_bucket.current_capacity == pytest.approx(1.0, abs=0.05) # type: ignore + + @pytest.mark.asyncio + async def test_many_concurrent_tasks_succeed(self): + token_bucket = TokenBucket(curr_capacity=2.0) + await token_bucket.update_bucket(4.0) + completed_tasks: list[int] = [] + + async def worker(worker_id: int): + await token_bucket.acquire(0.1) + completed_tasks.append(worker_id) + + try: + # At the fill rate of 4/second and acquire cost of 0.1, it should take + # around 2 seconds to process 100 tasks. + await asyncio.wait_for( + asyncio.gather(*[worker(i) for i in range(100)]), timeout=3 + ) + except TimeoutError: + pytest.fail("Deadlock detected: concurrent acquire operations timed out") + + assert len(completed_tasks) == 100 + assert len(set(completed_tasks)) == 100 + assert token_bucket.current_capacity >= 0.0 From 927a9fa99d998680e1d418e8cc4ff855199ccf01 Mon Sep 17 00:00:00 2001 From: ubaskota <19787410+ubaskota@users.noreply.github.com> Date: Mon, 29 Dec 2025 10:08:11 -0500 Subject: [PATCH 2/5] Address reviews --- .../smithy-core/src/smithy_core/retries.py | 44 +++++++------ .../smithy-core/tests/unit/test_retries.py | 64 +++++++------------ 2 files changed, 47 insertions(+), 61 deletions(-) diff --git a/packages/smithy-core/src/smithy_core/retries.py b/packages/smithy-core/src/smithy_core/retries.py index 1d7ef733..a8332c7e 100644 --- a/packages/smithy-core/src/smithy_core/retries.py +++ b/packages/smithy-core/src/smithy_core/retries.py @@ -487,31 +487,32 @@ def __deepcopy__(self, memo: Any) -> "StandardRetryStrategy": class TokenBucket: MIN_FILL_RATE = 0.5 MIN_CAPACITY = 1.0 + DEFAULT_TIMEOUT = 30.0 def __init__( self, *, - curr_capacity: float | None = None, - max_capacity: float | None = None, - fill_rate: float | None = None, + curr_capacity: float = MIN_CAPACITY, + fill_rate: float = MIN_FILL_RATE, + default_timeout: float = DEFAULT_TIMEOUT, ): - self._curr_capacity: float = ( - curr_capacity if curr_capacity is not None else self.MIN_CAPACITY - ) - self._max_capacity: float = ( - max_capacity - if max_capacity is not None - else max(self._curr_capacity, self.MIN_CAPACITY) - ) - self._fill_rate: float = ( - max(self.MIN_FILL_RATE, fill_rate) - if fill_rate is not None - else self.MIN_FILL_RATE - ) - self._last_timestamp: float | None = None + self._curr_capacity: float = max(curr_capacity, self.MIN_CAPACITY) + self._max_capacity: float = self._curr_capacity + self._fill_rate: float = max(fill_rate, self.MIN_FILL_RATE) + self._default_timeout = default_timeout + self._last_timestamp: float = time.monotonic() self._lock = asyncio.Lock() async def acquire(self, amount: float) -> None: + """Acquire tokens from the bucket. + + Args: + amount: Number of tokens to acquire + + Raises: + TimeoutError: If acquisition takes longer than default_timeout + """ + start_time = time.monotonic() while True: async with self._lock: self._refill() @@ -519,15 +520,16 @@ async def acquire(self, amount: float) -> None: self._curr_capacity -= amount return + elapsed = time.monotonic() - start_time + if elapsed > self._default_timeout: + raise TimeoutError( + f"Failed to acquire {amount} tokens within {self._default_timeout}s" + ) wait_time = (amount - self._curr_capacity) / self._fill_rate await asyncio.sleep(wait_time) def _refill(self) -> None: curr_time = time.monotonic() - if self._last_timestamp is None: - self._last_timestamp = curr_time - return - elapsed = curr_time - self._last_timestamp refill_amount = elapsed * self._fill_rate self._curr_capacity = min( diff --git a/packages/smithy-core/tests/unit/test_retries.py b/packages/smithy-core/tests/unit/test_retries.py index f697b37b..2143bd58 100644 --- a/packages/smithy-core/tests/unit/test_retries.py +++ b/packages/smithy-core/tests/unit/test_retries.py @@ -1,8 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import asyncio -import time -from unittest.mock import AsyncMock, patch +from unittest.mock import patch import pytest from smithy_core.exceptions import CallError, RetryError @@ -293,51 +292,43 @@ async def test_initial_state(self): @pytest.mark.asyncio async def test_acquire_succeeds_immediately_within_capacity(self): token_bucket = TokenBucket() - start_time = time.monotonic() - await token_bucket.acquire(1) - elapsed = time.monotonic() - start_time - assert elapsed < 0.001 # Should be near instant + with patch("asyncio.sleep") as mock_sleep: + await token_bucket.acquire(1) + mock_sleep.assert_not_called() + assert token_bucket.current_capacity == 0 @pytest.mark.asyncio async def test_acquire_waits_when_capacity_insufficient(self): - token_bucket = TokenBucket(curr_capacity=0) - - with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: - - async def side_effect(delay: float): - async with token_bucket._lock: # type: ignore - token_bucket._curr_capacity = 1.0 # type: ignore + token_bucket = TokenBucket(fill_rate=1.0) + await token_bucket.acquire(1) - mock_sleep.side_effect = side_effect + with patch("asyncio.sleep") as mock_sleep: await token_bucket.acquire(1) - assert mock_sleep.call_count == 1 + mock_sleep.assert_called() - actual_delay = mock_sleep.call_args[0][0] - assert actual_delay == pytest.approx(2.0, abs=0.05) # type: ignore + assert token_bucket.current_capacity == 0.0 @pytest.mark.asyncio async def test_multiple_refills_over_time(self): - token_bucket = TokenBucket(curr_capacity=0, max_capacity=10, fill_rate=2.0) - - time_values = iter([1.0, 1.5, 4.0]) + time_values = iter([0.0, 1.0, 1.0, 1.5, 1.5, 4.0, 4.0, 5.0]) with patch("time.monotonic", side_effect=lambda: next(time_values)): - token_bucket._last_timestamp = 0.0 # type: ignore + token_bucket = TokenBucket(curr_capacity=0, fill_rate=2.0) + await token_bucket.acquire(1) + assert token_bucket.current_capacity == 0.0 - async with token_bucket._lock: # type: ignore - token_bucket._refill() # type: ignore - assert token_bucket.current_capacity == pytest.approx(2.0, abs=0.05) # type: ignore + with patch("time.monotonic", side_effect=lambda: next(time_values)): + await token_bucket.update_bucket(4) # Update the rate of refill + assert token_bucket.current_capacity == 1.0 with patch("time.monotonic", side_effect=lambda: next(time_values)): - async with token_bucket._lock: # type: ignore - token_bucket._refill() # type: ignore - assert token_bucket.current_capacity == pytest.approx(3.0, abs=0.05) # type: ignore + await token_bucket.acquire(1) + assert token_bucket.current_capacity == 3 with patch("time.monotonic", side_effect=lambda: next(time_values)): - async with token_bucket._lock: # type: ignore - token_bucket._refill() # type: ignore - assert token_bucket.current_capacity == pytest.approx(8.0, abs=0.05) # type: ignore + await token_bucket.acquire(1) + assert token_bucket.current_capacity == 3 @pytest.mark.asyncio async def test_update_bucket_updates_capacity(self): @@ -364,10 +355,10 @@ async def test_refill_caps_at_max_capacity(self): async with token_bucket._lock: # type: ignore token_bucket._refill() # type: ignore - assert token_bucket.current_capacity == pytest.approx(1.0, abs=0.05) # type: ignore + assert round(token_bucket.current_capacity, 1) == 1.0 @pytest.mark.asyncio - async def test_many_concurrent_tasks_succeed(self): + async def test_many_tasks_succeed(self): token_bucket = TokenBucket(curr_capacity=2.0) await token_bucket.update_bucket(4.0) completed_tasks: list[int] = [] @@ -376,14 +367,7 @@ async def worker(worker_id: int): await token_bucket.acquire(0.1) completed_tasks.append(worker_id) - try: - # At the fill rate of 4/second and acquire cost of 0.1, it should take - # around 2 seconds to process 100 tasks. - await asyncio.wait_for( - asyncio.gather(*[worker(i) for i in range(100)]), timeout=3 - ) - except TimeoutError: - pytest.fail("Deadlock detected: concurrent acquire operations timed out") + await asyncio.gather(*[worker(i) for i in range(100)]) assert len(completed_tasks) == 100 assert len(set(completed_tasks)) == 100 From 1ac8dabc1f0388d6c4ae0bff00d6b631c46a8f2a Mon Sep 17 00:00:00 2001 From: ubaskota <19787410+ubaskota@users.noreply.github.com> Date: Tue, 30 Dec 2025 17:19:41 -0500 Subject: [PATCH 3/5] Add docstrings to Token Bucket class and its public functions, and adddress comments --- .../smithy-core/src/smithy_core/retries.py | 58 +++++++++++++++++-- .../smithy-core/tests/unit/test_retries.py | 17 ------ 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/packages/smithy-core/src/smithy_core/retries.py b/packages/smithy-core/src/smithy_core/retries.py index a8332c7e..13ac4ccc 100644 --- a/packages/smithy-core/src/smithy_core/retries.py +++ b/packages/smithy-core/src/smithy_core/retries.py @@ -485,6 +485,18 @@ def __deepcopy__(self, memo: Any) -> "StandardRetryStrategy": class TokenBucket: + """ + Implements the token bucket algorithm for rate limiting with configurable fill + rate and capacity. Tokens are added to the bucket at a fixed rate (fill_rate). + Requests must acquire tokens before proceeding, and will wait if insufficient + tokens are available. + + Attributes: + MIN_FILL_RATE: Minimum allowed fill rate (0.5 tokens/second) + MIN_CAPACITY: Minimum allowed bucket capacity (1.0 tokens) + DEFAULT_TIMEOUT: Default timeout for token acquisition (30.0 seconds) + """ + MIN_FILL_RATE = 0.5 MIN_CAPACITY = 1.0 DEFAULT_TIMEOUT = 30.0 @@ -494,23 +506,35 @@ def __init__( *, curr_capacity: float = MIN_CAPACITY, fill_rate: float = MIN_FILL_RATE, - default_timeout: float = DEFAULT_TIMEOUT, + timeout: float = DEFAULT_TIMEOUT, ): + """Initialize a new TokenBucket. + + Args: + curr_capacity: Initial number of tokens in the bucket. + fill_rate: Rate at which tokens are added to the bucket (tokens/second). + timeout: Maximum time to wait for token acquisition before + raising TimeoutError. + """ self._curr_capacity: float = max(curr_capacity, self.MIN_CAPACITY) self._max_capacity: float = self._curr_capacity self._fill_rate: float = max(fill_rate, self.MIN_FILL_RATE) - self._default_timeout = default_timeout + self._timeout = timeout self._last_timestamp: float = time.monotonic() self._lock = asyncio.Lock() async def acquire(self, amount: float) -> None: """Acquire tokens from the bucket. + If sufficient tokens are available, they are immediately consumed and the + method returns. If insufficient tokens are available, the method will wait + until enough tokens have been refilled or the timeout is reached. + Args: - amount: Number of tokens to acquire + amount: Number of tokens to acquire. Raises: - TimeoutError: If acquisition takes longer than default_timeout + TimeoutError: If acquisition takes longer than the configured timeout. """ start_time = time.monotonic() while True: @@ -521,9 +545,9 @@ async def acquire(self, amount: float) -> None: return elapsed = time.monotonic() - start_time - if elapsed > self._default_timeout: + if elapsed > self._timeout: raise TimeoutError( - f"Failed to acquire {amount} tokens within {self._default_timeout}s" + f"Failed to acquire {amount} tokens within {self._timeout}s" ) wait_time = (amount - self._curr_capacity) / self._fill_rate await asyncio.sleep(wait_time) @@ -538,6 +562,13 @@ def _refill(self) -> None: self._last_timestamp = curr_time async def update_bucket(self, rate: float) -> None: + """Update the bucket's fill rate, maximum capacity and current capacity (if its + greater than maximum capacity). + + Args: + rate: New fill rate (tokens/second). It won't be less than MIN_FILL_RATE. + Current capacity will be reduced if it exceeds the new maximum capacity. + """ async with self._lock: self._refill() self._fill_rate = max(rate, self.MIN_FILL_RATE) @@ -546,12 +577,27 @@ async def update_bucket(self, rate: float) -> None: @property def current_capacity(self) -> float: + """Get the current number of tokens in the bucket. + + Returns: + The current token count as of the last refill operation. + """ return self._curr_capacity @property def max_capacity(self) -> float: + """Get the maximum capacity of the bucket. + + Returns: + The maximum number of tokens the bucket can hold. + """ return self._max_capacity @property def fill_rate(self) -> float: + """Get the current fill rate of the bucket. + + Returns: + The rate at which tokens are added to the bucket (tokens/second). + """ return self._fill_rate diff --git a/packages/smithy-core/tests/unit/test_retries.py b/packages/smithy-core/tests/unit/test_retries.py index 2143bd58..eab71533 100644 --- a/packages/smithy-core/tests/unit/test_retries.py +++ b/packages/smithy-core/tests/unit/test_retries.py @@ -1,6 +1,5 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -import asyncio from unittest.mock import patch import pytest @@ -356,19 +355,3 @@ async def test_refill_caps_at_max_capacity(self): token_bucket._refill() # type: ignore assert round(token_bucket.current_capacity, 1) == 1.0 - - @pytest.mark.asyncio - async def test_many_tasks_succeed(self): - token_bucket = TokenBucket(curr_capacity=2.0) - await token_bucket.update_bucket(4.0) - completed_tasks: list[int] = [] - - async def worker(worker_id: int): - await token_bucket.acquire(0.1) - completed_tasks.append(worker_id) - - await asyncio.gather(*[worker(i) for i in range(100)]) - - assert len(completed_tasks) == 100 - assert len(set(completed_tasks)) == 100 - assert token_bucket.current_capacity >= 0.0 From a591b3d86422873d637716cf9dd66952e2a5dd52 Mon Sep 17 00:00:00 2001 From: ubaskota <19787410+ubaskota@users.noreply.github.com> Date: Fri, 2 Jan 2026 17:34:58 -0500 Subject: [PATCH 4/5] Update a test and address reviews --- .../smithy-core/src/smithy_core/retries.py | 52 ++++++++---------- .../smithy-core/tests/unit/test_retries.py | 54 ++++++++++--------- 2 files changed, 51 insertions(+), 55 deletions(-) diff --git a/packages/smithy-core/src/smithy_core/retries.py b/packages/smithy-core/src/smithy_core/retries.py index 13ac4ccc..ec542924 100644 --- a/packages/smithy-core/src/smithy_core/retries.py +++ b/packages/smithy-core/src/smithy_core/retries.py @@ -486,20 +486,16 @@ def __deepcopy__(self, memo: Any) -> "StandardRetryStrategy": class TokenBucket: """ - Implements the token bucket algorithm for rate limiting with configurable fill - rate and capacity. Tokens are added to the bucket at a fixed rate (fill_rate). - Requests must acquire tokens before proceeding, and will wait if insufficient - tokens are available. - - Attributes: - MIN_FILL_RATE: Minimum allowed fill rate (0.5 tokens/second) - MIN_CAPACITY: Minimum allowed bucket capacity (1.0 tokens) - DEFAULT_TIMEOUT: Default timeout for token acquisition (30.0 seconds) + TokenBucket provides a collection of arbitrary tokens while managing issuance + and refilling over time. This is controlled by a fill rate that can be variably + adjusted. When tokens aren't available, the bucket will enforce a delay before + attempting to reacquire tokens until one is available or the defined timeout is + reached. """ - MIN_FILL_RATE = 0.5 - MIN_CAPACITY = 1.0 - DEFAULT_TIMEOUT = 30.0 + MIN_FILL_RATE = 0.5 # Minimum allowed fill rate (0.5 tokens/second) + MIN_CAPACITY = 1.0 # Minimum allowed bucket capacity (1.0 tokens) + DEFAULT_TIMEOUT = 30.0 # Default timeout for token acquisition (30.0 seconds) def __init__( self, @@ -510,11 +506,10 @@ def __init__( ): """Initialize a new TokenBucket. - Args: - curr_capacity: Initial number of tokens in the bucket. - fill_rate: Rate at which tokens are added to the bucket (tokens/second). - timeout: Maximum time to wait for token acquisition before - raising TimeoutError. + :param curr_capacity: Initial number of tokens in the bucket. + :param fill_rate: Rate at which tokens are added to the bucket (tokens/second). + :param timeout: Maximum time to wait for token acquisition before + raising TimeoutError. """ self._curr_capacity: float = max(curr_capacity, self.MIN_CAPACITY) self._max_capacity: float = self._curr_capacity @@ -530,11 +525,9 @@ async def acquire(self, amount: float) -> None: method returns. If insufficient tokens are available, the method will wait until enough tokens have been refilled or the timeout is reached. - Args: - amount: Number of tokens to acquire. + :param amount: Number of tokens to acquire. - Raises: - TimeoutError: If acquisition takes longer than the configured timeout. + :raises TimeoutError: Acquisition took longer than the configured timeout. """ start_time = time.monotonic() while True: @@ -546,6 +539,7 @@ async def acquire(self, amount: float) -> None: elapsed = time.monotonic() - start_time if elapsed > self._timeout: + # This will be caught in retry strategy and used as part of the retry count raise TimeoutError( f"Failed to acquire {amount} tokens within {self._timeout}s" ) @@ -561,13 +555,12 @@ def _refill(self) -> None: ) self._last_timestamp = curr_time - async def update_bucket(self, rate: float) -> None: + async def update_rate(self, rate: float) -> None: """Update the bucket's fill rate, maximum capacity and current capacity (if its greater than maximum capacity). - Args: - rate: New fill rate (tokens/second). It won't be less than MIN_FILL_RATE. - Current capacity will be reduced if it exceeds the new maximum capacity. + :param rate: New fill rate (tokens/second). It won't be less than MIN_FILL_RATE. + Current capacity will be reduced if it exceeds the new maximum capacity. """ async with self._lock: self._refill() @@ -579,8 +572,7 @@ async def update_bucket(self, rate: float) -> None: def current_capacity(self) -> float: """Get the current number of tokens in the bucket. - Returns: - The current token count as of the last refill operation. + :return: The current token count as of the last refill operation. """ return self._curr_capacity @@ -588,8 +580,7 @@ def current_capacity(self) -> float: def max_capacity(self) -> float: """Get the maximum capacity of the bucket. - Returns: - The maximum number of tokens the bucket can hold. + :return: The maximum number of tokens the bucket can hold. """ return self._max_capacity @@ -597,7 +588,6 @@ def max_capacity(self) -> float: def fill_rate(self) -> float: """Get the current fill rate of the bucket. - Returns: - The rate at which tokens are added to the bucket (tokens/second). + :return: The rate at which tokens are added to the bucket (tokens/second). """ return self._fill_rate diff --git a/packages/smithy-core/tests/unit/test_retries.py b/packages/smithy-core/tests/unit/test_retries.py index eab71533..6341bd89 100644 --- a/packages/smithy-core/tests/unit/test_retries.py +++ b/packages/smithy-core/tests/unit/test_retries.py @@ -310,30 +310,10 @@ async def test_acquire_waits_when_capacity_insufficient(self): assert token_bucket.current_capacity == 0.0 @pytest.mark.asyncio - async def test_multiple_refills_over_time(self): - time_values = iter([0.0, 1.0, 1.0, 1.5, 1.5, 4.0, 4.0, 5.0]) - with patch("time.monotonic", side_effect=lambda: next(time_values)): - token_bucket = TokenBucket(curr_capacity=0, fill_rate=2.0) - await token_bucket.acquire(1) - assert token_bucket.current_capacity == 0.0 - - with patch("time.monotonic", side_effect=lambda: next(time_values)): - await token_bucket.update_bucket(4) # Update the rate of refill - assert token_bucket.current_capacity == 1.0 - - with patch("time.monotonic", side_effect=lambda: next(time_values)): - await token_bucket.acquire(1) - assert token_bucket.current_capacity == 3 - - with patch("time.monotonic", side_effect=lambda: next(time_values)): - await token_bucket.acquire(1) - assert token_bucket.current_capacity == 3 - - @pytest.mark.asyncio - async def test_update_bucket_updates_capacity(self): + async def test_update_bucket_updates_rate(self): token_bucket = TokenBucket() - await token_bucket.update_bucket(5.0) + await token_bucket.update_rate(5.0) assert token_bucket.fill_rate == 5.0 assert token_bucket.max_capacity == 5.0 assert token_bucket.current_capacity == 1.0 @@ -341,7 +321,7 @@ async def test_update_bucket_updates_capacity(self): @pytest.mark.asyncio async def test_rate_can_never_be_zero(self): token_bucket = TokenBucket() - await token_bucket.update_bucket(0.0) + await token_bucket.update_rate(0.0) assert token_bucket.fill_rate != 0.0 @@ -349,9 +329,35 @@ async def test_rate_can_never_be_zero(self): async def test_refill_caps_at_max_capacity(self): token_bucket = TokenBucket() # Max and current capacity of the bucket is set to 1.0 initially - await token_bucket.update_bucket(10.0) + await token_bucket.update_rate(10.0) async with token_bucket._lock: # type: ignore token_bucket._refill() # type: ignore assert round(token_bucket.current_capacity, 1) == 1.0 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "actions,expected_capacity", + [ + ([("acquire", 1)], 0.0), + ([("acquire", 1), ("update", 4)], 1.0), + ([("acquire", 1), ("update", 4), ("acquire", 1)], 3.0), + ([("acquire", 1), ("update", 4), ("acquire", 1), ("acquire", 1)], 3.0), + ], + ) + async def test_multiple_refills_over_time( + self, actions: list[tuple[str, int]], expected_capacity: float + ): + time_values = [0.0, 1.0, 1.0, 1.5, 1.5, 4.0, 4.0, 5.0] + + with patch("time.monotonic", side_effect=time_values): + token_bucket = TokenBucket(curr_capacity=0, fill_rate=2.0) + + for action, value in actions: + if action == "acquire": + await token_bucket.acquire(value) + elif action == "update": + await token_bucket.update_rate(value) + + assert token_bucket.current_capacity == expected_capacity From dede6fe407583321569baf984c234a220c3a8275 Mon Sep 17 00:00:00 2001 From: ubaskota <19787410+ubaskota@users.noreply.github.com> Date: Thu, 8 Jan 2026 12:16:34 -0500 Subject: [PATCH 5/5] Update multiline docstring for TokenBucket class --- packages/smithy-core/src/smithy_core/retries.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/smithy-core/src/smithy_core/retries.py b/packages/smithy-core/src/smithy_core/retries.py index ec542924..d2318851 100644 --- a/packages/smithy-core/src/smithy_core/retries.py +++ b/packages/smithy-core/src/smithy_core/retries.py @@ -485,7 +485,8 @@ def __deepcopy__(self, memo: Any) -> "StandardRetryStrategy": class TokenBucket: - """ + """Token bucket for rate limiting with configurable fill rate. + TokenBucket provides a collection of arbitrary tokens while managing issuance and refilling over time. This is controlled by a fill rate that can be variably adjusted. When tokens aren't available, the bucket will enforce a delay before