Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions packages/smithy-core/src/smithy_core/retries.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import random
import threading
import time
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -480,3 +482,122 @@ def record_success(self, *, token: retries_interface.RetryToken) -> None:

def __deepcopy__(self, memo: Any) -> "StandardRetryStrategy":
    """Return this same instance instead of building a copy.

    Args:
        memo: The memo dictionary supplied by ``copy.deepcopy``; it is not
            consulted or modified.

    Returns:
        ``self`` -- a deep copy of an object holding this strategy therefore
        keeps referring to the very same strategy object.
    """
    # NOTE(review): presumably done so retry state stays shared across copied
    # configurations -- confirm against the enclosing class.
    return self


class TokenBucket:
    """
    Implements the token bucket algorithm for rate limiting with configurable fill
    rate and capacity. Tokens are added to the bucket at a fixed rate (fill_rate).
    Requests must acquire tokens before proceeding, and will wait if insufficient
    tokens are available.

    Attributes:
        MIN_FILL_RATE: Minimum allowed fill rate (0.5 tokens/second)
        MIN_CAPACITY: Minimum allowed bucket capacity (1.0 tokens)
        DEFAULT_TIMEOUT: Default timeout for token acquisition (30.0 seconds)
    """

    MIN_FILL_RATE = 0.5
    MIN_CAPACITY = 1.0
    DEFAULT_TIMEOUT = 30.0

    def __init__(
        self,
        *,
        curr_capacity: float = MIN_CAPACITY,
        fill_rate: float = MIN_FILL_RATE,
        timeout: float = DEFAULT_TIMEOUT,
    ):
        """Initialize a new TokenBucket.

        Args:
            curr_capacity: Initial number of tokens in the bucket. Values below
                MIN_CAPACITY are raised to MIN_CAPACITY. The resulting value is
                also used as the initial maximum capacity.
            fill_rate: Rate at which tokens are added to the bucket (tokens/second).
                Values below MIN_FILL_RATE are raised to MIN_FILL_RATE.
            timeout: Maximum time (seconds) to wait for token acquisition before
                raising TimeoutError.
        """
        self._curr_capacity: float = max(curr_capacity, self.MIN_CAPACITY)
        self._max_capacity: float = self._curr_capacity
        self._fill_rate: float = max(fill_rate, self.MIN_FILL_RATE)
        self._timeout = timeout
        self._last_timestamp: float = time.monotonic()
        self._lock = asyncio.Lock()

    async def acquire(self, amount: float) -> None:
        """Acquire tokens from the bucket.

        If sufficient tokens are available, they are immediately consumed and the
        method returns. If insufficient tokens are available, the method will wait
        until enough tokens have been refilled or the timeout is reached.

        Each wait is limited to the time remaining before the timeout, so a
        request that cannot be satisfied in time fails close to the configured
        timeout instead of sleeping for the full projected refill duration.

        Args:
            amount: Number of tokens to acquire.

        Raises:
            TimeoutError: If acquisition takes longer than the configured timeout.
        """
        start_time = time.monotonic()
        while True:
            async with self._lock:
                self._refill()
                if self._curr_capacity >= amount:
                    self._curr_capacity -= amount
                    return
                # Work out the deficit while still holding the lock so it is
                # consistent with the refill that was just performed.
                wait_time = (amount - self._curr_capacity) / self._fill_rate

            elapsed = time.monotonic() - start_time
            if elapsed > self._timeout:
                raise TimeoutError(
                    f"Failed to acquire {amount} tokens within {self._timeout}s"
                )
            # Never sleep past the deadline: without this clamp a large deficit
            # (or a slow fill rate) could block far longer than the timeout.
            remaining = self._timeout - elapsed
            await asyncio.sleep(min(wait_time, remaining))

    def _refill(self) -> None:
        """Add the tokens accrued since the last refill, capped at max capacity.

        Callers in this class invoke it while holding ``self._lock``.
        """
        curr_time = time.monotonic()
        elapsed = curr_time - self._last_timestamp
        refill_amount = elapsed * self._fill_rate
        self._curr_capacity = min(
            self._max_capacity, self._curr_capacity + refill_amount
        )
        self._last_timestamp = curr_time

    async def update_bucket(self, rate: float) -> None:
        """Update the bucket's fill rate, maximum capacity and current capacity
        (if it is greater than the new maximum capacity).

        Args:
            rate: New fill rate (tokens/second). It won't be less than MIN_FILL_RATE.
                The maximum capacity is set to the same value, but never below
                MIN_CAPACITY. Current capacity will be reduced if it exceeds the
                new maximum capacity.
        """
        async with self._lock:
            # Credit tokens earned at the old rate before switching rates.
            self._refill()
            self._fill_rate = max(rate, self.MIN_FILL_RATE)
            self._max_capacity = max(rate, self.MIN_CAPACITY)
            self._curr_capacity = min(self._curr_capacity, self._max_capacity)

    @property
    def current_capacity(self) -> float:
        """Get the current number of tokens in the bucket.

        Returns:
            The current token count as of the last refill operation.
        """
        return self._curr_capacity

    @property
    def max_capacity(self) -> float:
        """Get the maximum capacity of the bucket.

        Returns:
            The maximum number of tokens the bucket can hold.
        """
        return self._max_capacity

    @property
    def fill_rate(self) -> float:
        """Get the current fill rate of the bucket.

        Returns:
            The rate at which tokens are added to the bucket (tokens/second).
        """
        return self._fill_rate
80 changes: 80 additions & 0 deletions packages/smithy-core/tests/unit/test_retries.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import patch

import pytest
from smithy_core.exceptions import CallError, RetryError
from smithy_core.retries import ExponentialBackoffJitterType as EBJT
Expand All @@ -10,6 +12,7 @@
SimpleRetryStrategy,
StandardRetryQuota,
StandardRetryStrategy,
TokenBucket,
)


Expand Down Expand Up @@ -275,3 +278,80 @@ async def test_retry_strategy_resolver_rejects_invalid_type() -> None:
match="retry_strategy must be RetryStrategy, RetryStrategyOptions, or None",
):
await resolver.resolve_retry_strategy(retry_strategy="invalid") # type: ignore


class TestTokenBucket:
    """Tests for ``TokenBucket`` driven through its public interface."""

    @pytest.mark.asyncio
    async def test_initial_state(self):
        """A default bucket reports the class minimums for all three properties."""
        bucket = TokenBucket()

        assert bucket.current_capacity == bucket.MIN_CAPACITY
        assert bucket.max_capacity == bucket.MIN_CAPACITY
        assert bucket.fill_rate == bucket.MIN_FILL_RATE

    @pytest.mark.asyncio
    async def test_acquire_succeeds_immediately_within_capacity(self):
        """Acquiring what is already in the bucket must not sleep."""
        bucket = TokenBucket()

        with patch("asyncio.sleep") as sleep_mock:
            await bucket.acquire(1)

        sleep_mock.assert_not_called()
        assert bucket.current_capacity == 0

    @pytest.mark.asyncio
    async def test_acquire_waits_when_capacity_insufficient(self):
        """A second acquire on a drained bucket has to go through sleep."""
        bucket = TokenBucket(fill_rate=1.0)
        # Drain the single starting token.
        await bucket.acquire(1)

        with patch("asyncio.sleep") as sleep_mock:
            await bucket.acquire(1)

        sleep_mock.assert_called()
        assert bucket.current_capacity == 0.0

    @pytest.mark.asyncio
    async def test_multiple_refills_over_time(self):
        """Refills follow the patched clock across acquires and a rate update."""
        # One value is consumed per time.monotonic() call, in this order:
        # constructor, acquire start, refill, update_bucket refill,
        # acquire start, refill, acquire start, refill.
        ticks = iter([0.0, 1.0, 1.0, 1.5, 1.5, 4.0, 4.0, 5.0])

        def fake_clock():
            return next(ticks)

        with patch("time.monotonic", side_effect=fake_clock):
            bucket = TokenBucket(curr_capacity=0, fill_rate=2.0)
            await bucket.acquire(1)
            assert bucket.current_capacity == 0.0

            await bucket.update_bucket(4)  # Update the rate of refill
            assert bucket.current_capacity == 1.0

            await bucket.acquire(1)
            assert bucket.current_capacity == 3

            await bucket.acquire(1)
            assert bucket.current_capacity == 3

    @pytest.mark.asyncio
    async def test_update_bucket_updates_capacity(self):
        """Raising the rate lifts fill rate and max capacity, not the tokens held."""
        bucket = TokenBucket()
        await bucket.update_bucket(5.0)

        assert bucket.fill_rate == 5.0
        assert bucket.max_capacity == 5.0
        assert bucket.current_capacity == 1.0

    @pytest.mark.asyncio
    async def test_rate_can_never_be_zero(self):
        """A zero rate passed to update_bucket does not leave fill_rate at zero."""
        bucket = TokenBucket()
        await bucket.update_bucket(0.0)

        assert bucket.fill_rate != 0.0

@pytest.mark.asyncio
async def test_refill_caps_at_max_capacity(self):
token_bucket = TokenBucket()
# Max and current capacity of the bucket is set to 1.0 initially
await token_bucket.update_bucket(10.0)

async with token_bucket._lock: # type: ignore
token_bucket._refill() # type: ignore
Comment on lines +334 to +335

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we talked about this previously, but we don't want to be calling private methods if we can avoid it. Is there another way we can test this that better reflects the public interfaces?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we did talk about this earlier. I updated the other tests to use a public method, but for this we are specifically testing whether the refill exceeds the max capacity when the rate is increased, so I decided to leave it as it is.


assert round(token_bucket.current_capacity, 1) == 1.0
Loading