Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions packages/smithy-core/src/smithy_core/retries.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import random
import threading
import time
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -480,3 +482,112 @@ def record_success(self, *, token: retries_interface.RetryToken) -> None:

def __deepcopy__(self, memo: Any) -> "StandardRetryStrategy":
    """Return this same instance instead of a copy when deep-copied.

    :param memo: Memo dictionary supplied by the deep-copy machinery; unused.

    :return: ``self``, so the "copy" and the original are the same object.
    """
    # NOTE(review): presumably this keeps retry state shared across copies of
    # whatever object holds the strategy -- confirm against the callers.
    return self


class TokenBucket:
    """Token bucket for rate limiting with a configurable fill rate.

    Manages token issuance and refilling over time. Tokens are replenished
    lazily, based on the time elapsed since the previous refill, whenever
    :meth:`acquire` or :meth:`update_rate` runs. When tokens aren't available,
    :meth:`acquire` sleeps until enough tokens have been refilled or the
    configured timeout is reached.
    """

    MIN_FILL_RATE = 0.5  # Minimum allowed fill rate (0.5 tokens/second)
    MIN_CAPACITY = 1.0  # Minimum allowed bucket capacity (1.0 tokens)
    DEFAULT_TIMEOUT = 30.0  # Default timeout for token acquisition (30.0 seconds)

    def __init__(
        self,
        *,
        curr_capacity: float = MIN_CAPACITY,
        fill_rate: float = MIN_FILL_RATE,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> None:
        """Initialize a new TokenBucket.

        :param curr_capacity: Initial number of tokens in the bucket. Values below
            MIN_CAPACITY are raised to MIN_CAPACITY. The resulting value is also
            used as the bucket's initial maximum capacity.
        :param fill_rate: Rate at which tokens are added to the bucket
            (tokens/second). Values below MIN_FILL_RATE are raised to
            MIN_FILL_RATE.
        :param timeout: Maximum time in seconds to wait for token acquisition
            before raising TimeoutError.
        """
        self._curr_capacity: float = max(curr_capacity, self.MIN_CAPACITY)
        self._max_capacity: float = self._curr_capacity
        self._fill_rate: float = max(fill_rate, self.MIN_FILL_RATE)
        self._timeout = timeout
        self._last_timestamp: float = time.monotonic()
        self._lock = asyncio.Lock()

    async def acquire(self, amount: float) -> None:
        """Acquire tokens from the bucket.

        If sufficient tokens are available, they are immediately consumed and the
        method returns. If insufficient tokens are available, the method will wait
        until enough tokens have been refilled or the timeout is reached. A single
        wait never extends past the remaining timeout budget.

        :param amount: Number of tokens to acquire.

        :raises TimeoutError: Acquisition took longer than the configured timeout.
        """
        start_time = time.monotonic()
        while True:
            async with self._lock:
                self._refill()
                if self._curr_capacity >= amount:
                    self._curr_capacity -= amount
                    return
                # Snapshot the shortfall while the bucket state is consistent.
                shortfall = amount - self._curr_capacity
                fill_rate = self._fill_rate

            elapsed = time.monotonic() - start_time
            if elapsed >= self._timeout:
                # This will be caught in retry strategy and used as part of the retry count
                raise TimeoutError(
                    f"Failed to acquire {amount} tokens within {self._timeout}s"
                )
            # Sleep until the shortfall should be refilled, but never beyond the
            # remaining timeout: otherwise a large request could block far longer
            # than the configured timeout before TimeoutError is raised.
            wait_time = min(shortfall / fill_rate, self._timeout - elapsed)
            await asyncio.sleep(wait_time)

    def _refill(self) -> None:
        """Add the tokens accrued since the last refill, capped at max capacity."""
        curr_time = time.monotonic()
        elapsed = curr_time - self._last_timestamp
        refill_amount = elapsed * self._fill_rate
        self._curr_capacity = min(
            self._max_capacity, self._curr_capacity + refill_amount
        )
        self._last_timestamp = curr_time

    async def update_rate(self, rate: float) -> None:
        """Update the bucket's fill rate, maximum capacity and current capacity.

        Tokens accrued at the old rate are credited first. The current capacity
        is only changed afterwards if it exceeds the new maximum capacity.

        :param rate: New fill rate (tokens/second). The fill rate won't be less
            than MIN_FILL_RATE. The maximum capacity is set to the same value,
            but won't be less than MIN_CAPACITY.
        """
        async with self._lock:
            self._refill()
            self._fill_rate = max(rate, self.MIN_FILL_RATE)
            self._max_capacity = max(rate, self.MIN_CAPACITY)
            self._curr_capacity = min(self._curr_capacity, self._max_capacity)

    @property
    def current_capacity(self) -> float:
        """Get the current number of tokens in the bucket.

        :return: The current token count as of the last refill operation.
        """
        return self._curr_capacity

    @property
    def max_capacity(self) -> float:
        """Get the maximum capacity of the bucket.

        :return: The maximum number of tokens the bucket can hold.
        """
        return self._max_capacity

    @property
    def fill_rate(self) -> float:
        """Get the current fill rate of the bucket.

        :return: The rate at which tokens are added to the bucket (tokens/second).
        """
        return self._fill_rate
86 changes: 86 additions & 0 deletions packages/smithy-core/tests/unit/test_retries.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import patch

import pytest
from smithy_core.exceptions import CallError, RetryError
from smithy_core.retries import ExponentialBackoffJitterType as EBJT
Expand All @@ -10,6 +12,7 @@
SimpleRetryStrategy,
StandardRetryQuota,
StandardRetryStrategy,
TokenBucket,
)


Expand Down Expand Up @@ -275,3 +278,86 @@ async def test_retry_strategy_resolver_rejects_invalid_type() -> None:
match="retry_strategy must be RetryStrategy, RetryStrategyOptions, or None",
):
await resolver.resolve_retry_strategy(retry_strategy="invalid") # type: ignore


class TestTokenBucket:
    """Unit tests for TokenBucket, driven through its public interface only."""

    @pytest.mark.asyncio
    async def test_initial_state(self):
        token_bucket = TokenBucket()
        assert token_bucket.current_capacity == token_bucket.MIN_CAPACITY
        assert token_bucket.max_capacity == token_bucket.MIN_CAPACITY
        assert token_bucket.fill_rate == token_bucket.MIN_FILL_RATE

    @pytest.mark.asyncio
    async def test_acquire_succeeds_immediately_within_capacity(self):
        token_bucket = TokenBucket()

        with patch("asyncio.sleep") as mock_sleep:
            await token_bucket.acquire(1)
            mock_sleep.assert_not_called()

        assert token_bucket.current_capacity == 0

    @pytest.mark.asyncio
    async def test_acquire_waits_when_capacity_insufficient(self):
        token_bucket = TokenBucket(fill_rate=1.0)
        await token_bucket.acquire(1)

        with patch("asyncio.sleep") as mock_sleep:
            await token_bucket.acquire(1)
            mock_sleep.assert_called()

        assert token_bucket.current_capacity == 0.0

    @pytest.mark.asyncio
    async def test_update_bucket_updates_rate(self):
        token_bucket = TokenBucket()

        await token_bucket.update_rate(5.0)
        assert token_bucket.fill_rate == 5.0
        assert token_bucket.max_capacity == 5.0
        assert token_bucket.current_capacity == 1.0

    @pytest.mark.asyncio
    async def test_rate_can_never_be_zero(self):
        token_bucket = TokenBucket()
        await token_bucket.update_rate(0.0)

        assert token_bucket.fill_rate != 0.0

    @pytest.mark.asyncio
    async def test_refill_caps_at_max_capacity(self):
        # Clock readings in call order: __init__, the refill inside update_rate,
        # the start time of acquire, the refill inside acquire.
        time_values = [0.0, 0.0, 100.0, 100.0]

        with patch("time.monotonic", side_effect=time_values):
            # Max and current capacity of the bucket is set to 1.0 initially
            token_bucket = TokenBucket()
            await token_bucket.update_rate(10.0)
            # 100s at 10 tokens/s would add 1000 tokens; the refill must be
            # clamped to the new maximum (10.0) before 1 token is consumed.
            await token_bucket.acquire(1)

        assert token_bucket.max_capacity == 10.0
        assert token_bucket.current_capacity == 9.0

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "actions,expected_capacity",
        [
            ([("acquire", 1)], 0.0),
            ([("acquire", 1), ("update", 4)], 1.0),
            ([("acquire", 1), ("update", 4), ("acquire", 1)], 3.0),
            ([("acquire", 1), ("update", 4), ("acquire", 1), ("acquire", 1)], 3.0),
        ],
    )
    async def test_multiple_refills_over_time(
        self, actions: list[tuple[str, int]], expected_capacity: float
    ):
        # One reading for __init__, two per acquire (start time + refill) and
        # one per update (refill), consumed in call order.
        time_values = [0.0, 1.0, 1.0, 1.5, 1.5, 4.0, 4.0, 5.0]

        with patch("time.monotonic", side_effect=time_values):
            token_bucket = TokenBucket(curr_capacity=0, fill_rate=2.0)

            for action, value in actions:
                if action == "acquire":
                    await token_bucket.acquire(value)
                elif action == "update":
                    await token_bucket.update_rate(value)

        assert token_bucket.current_capacity == expected_capacity
Loading