Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions packages/smithy-core/src/smithy_core/retries.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import random
import threading
import time
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -480,3 +482,74 @@ def record_success(self, *, token: retries_interface.RetryToken) -> None:

def __deepcopy__(self, memo: Any) -> "StandardRetryStrategy":
return self


class TokenBucket:
MIN_FILL_RATE = 0.5
MIN_CAPACITY = 1.0

def __init__(
self,
*,
curr_capacity: float | None = None,
max_capacity: float | None = None,
fill_rate: float | None = None,
):
self._curr_capacity: float = (
curr_capacity if curr_capacity is not None else self.MIN_CAPACITY
)
self._max_capacity: float = (
max_capacity
if max_capacity is not None
else max(self._curr_capacity, self.MIN_CAPACITY)
)
self._fill_rate: float = (
max(self.MIN_FILL_RATE, fill_rate)
if fill_rate is not None
else self.MIN_FILL_RATE
)
self._last_timestamp: float | None = None
self._lock = asyncio.Lock()

async def acquire(self, amount: float) -> None:
while True:
async with self._lock:
self._refill()
if self._curr_capacity >= amount:
self._curr_capacity -= amount
return

wait_time = (amount - self._curr_capacity) / self._fill_rate
await asyncio.sleep(wait_time)

def _refill(self) -> None:
curr_time = time.monotonic()
if self._last_timestamp is None:
self._last_timestamp = curr_time
return

elapsed = curr_time - self._last_timestamp
refill_amount = elapsed * self._fill_rate
self._curr_capacity = min(
self._max_capacity, self._curr_capacity + refill_amount
)
self._last_timestamp = curr_time

async def update_bucket(self, rate: float) -> None:
async with self._lock:
self._refill()
self._fill_rate = max(rate, self.MIN_FILL_RATE)
self._max_capacity = max(rate, self.MIN_CAPACITY)
self._curr_capacity = min(self._curr_capacity, self._max_capacity)

@property
def current_capacity(self) -> float:
return self._curr_capacity

@property
def max_capacity(self) -> float:
return self._max_capacity

@property
def fill_rate(self) -> float:
return self._fill_rate
113 changes: 113 additions & 0 deletions packages/smithy-core/tests/unit/test_retries.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import time
from unittest.mock import AsyncMock, patch

import pytest
from smithy_core.exceptions import CallError, RetryError
from smithy_core.retries import ExponentialBackoffJitterType as EBJT
Expand All @@ -10,6 +14,7 @@
SimpleRetryStrategy,
StandardRetryQuota,
StandardRetryStrategy,
TokenBucket,
)


Expand Down Expand Up @@ -275,3 +280,111 @@ async def test_retry_strategy_resolver_rejects_invalid_type() -> None:
match="retry_strategy must be RetryStrategy, RetryStrategyOptions, or None",
):
await resolver.resolve_retry_strategy(retry_strategy="invalid") # type: ignore


class TestTokenBucket:
@pytest.mark.asyncio
async def test_initial_state(self):
token_bucket = TokenBucket()
assert token_bucket.current_capacity == token_bucket.MIN_CAPACITY
assert token_bucket.max_capacity == token_bucket.MIN_CAPACITY
assert token_bucket.fill_rate == token_bucket.MIN_FILL_RATE

@pytest.mark.asyncio
async def test_acquire_succeeds_immediately_within_capacity(self):
token_bucket = TokenBucket()
start_time = time.monotonic()
await token_bucket.acquire(1)
elapsed = time.monotonic() - start_time

assert elapsed < 0.001 # Should be near instant
assert token_bucket.current_capacity == 0

@pytest.mark.asyncio
async def test_acquire_waits_when_capacity_insufficient(self):
token_bucket = TokenBucket(curr_capacity=0)

with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:

async def side_effect(delay: float):
async with token_bucket._lock: # type: ignore
token_bucket._curr_capacity = 1.0 # type: ignore

mock_sleep.side_effect = side_effect
await token_bucket.acquire(1)
assert mock_sleep.call_count == 1

actual_delay = mock_sleep.call_args[0][0]
assert actual_delay == pytest.approx(2.0, abs=0.05) # type: ignore

@pytest.mark.asyncio
async def test_multiple_refills_over_time(self):
token_bucket = TokenBucket(curr_capacity=0, max_capacity=10, fill_rate=2.0)

time_values = iter([1.0, 1.5, 4.0])
with patch("time.monotonic", side_effect=lambda: next(time_values)):
token_bucket._last_timestamp = 0.0 # type: ignore

async with token_bucket._lock: # type: ignore
token_bucket._refill() # type: ignore
assert token_bucket.current_capacity == pytest.approx(2.0, abs=0.05) # type: ignore

with patch("time.monotonic", side_effect=lambda: next(time_values)):
async with token_bucket._lock: # type: ignore
token_bucket._refill() # type: ignore
assert token_bucket.current_capacity == pytest.approx(3.0, abs=0.05) # type: ignore

with patch("time.monotonic", side_effect=lambda: next(time_values)):
async with token_bucket._lock: # type: ignore
token_bucket._refill() # type: ignore
assert token_bucket.current_capacity == pytest.approx(8.0, abs=0.05) # type: ignore

@pytest.mark.asyncio
async def test_update_bucket_updates_capacity(self):
token_bucket = TokenBucket()

await token_bucket.update_bucket(5.0)
assert token_bucket.fill_rate == 5.0
assert token_bucket.max_capacity == 5.0
assert token_bucket.current_capacity == 1.0

@pytest.mark.asyncio
async def test_rate_can_never_be_zero(self):
token_bucket = TokenBucket()
await token_bucket.update_bucket(0.0)

assert token_bucket.fill_rate != 0.0

@pytest.mark.asyncio
async def test_refill_caps_at_max_capacity(self):
token_bucket = TokenBucket()
# Max and current capacity of the bucket is set to 1.0 initially
await token_bucket.update_bucket(10.0)

async with token_bucket._lock: # type: ignore
token_bucket._refill() # type: ignore
Comment on lines +334 to +335

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we talked about this previously, but we don't want to be calling private methods if we can avoid it. Is there another way we can test this that better reflects the public interfaces?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we did talk about this earlier. I updated the other tests to use a public method, but for this we are specifically testing whether the refill exceeds the max capacity when the rate is increased, so I decided to leave it as it is.


assert token_bucket.current_capacity == pytest.approx(1.0, abs=0.05) # type: ignore

@pytest.mark.asyncio
async def test_many_concurrent_tasks_succeed(self):
token_bucket = TokenBucket(curr_capacity=2.0)
await token_bucket.update_bucket(4.0)
completed_tasks: list[int] = []

async def worker(worker_id: int):
await token_bucket.acquire(0.1)
completed_tasks.append(worker_id)

try:
# At the fill rate of 4/second and acquire cost of 0.1, it should take
# around 2 seconds to process 100 tasks.
await asyncio.wait_for(
asyncio.gather(*[worker(i) for i in range(100)]), timeout=3
)
except TimeoutError:
pytest.fail("Deadlock detected: concurrent acquire operations timed out")

assert len(completed_tasks) == 100
assert len(set(completed_tasks)) == 100
assert token_bucket.current_capacity >= 0.0
Loading