Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 67 additions & 9 deletions packages/smithy-core/src/smithy_core/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..auth import AuthParams
from ..deserializers import DeserializeableShape, ShapeDeserializer
from ..endpoints import EndpointResolverParams
from ..exceptions import ClientTimeoutError, RetryError, SmithyError
from ..exceptions import CallError, ClientTimeoutError, RetryError, SmithyError
from ..interceptors import (
InputContext,
Interceptor,
Expand All @@ -23,6 +23,7 @@
from ..interfaces import Endpoint, TypedProperties
from ..interfaces.auth import AuthOption, AuthSchemeResolver
from ..interfaces.retries import RetryStrategy
from ..retries import AdaptiveRetryStrategy
from ..schemas import APIOperation
from ..serializers import SerializeableShape
from ..shapes import ShapeID
Expand Down Expand Up @@ -338,16 +339,47 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape](
if retry_token.retry_delay:
await sleep(retry_token.retry_delay)

output_context = await self._handle_attempt(
call,
replace(
request_context,
transport_request=copy(request_context.transport_request),
),
request_future,
)
try:
# Rate limiting before request (adaptive only)
await self._handle_pre_request_rate_limiting(retry_strategy)

output_context = await self._handle_attempt(
call,
replace(
request_context,
transport_request=copy(request_context.transport_request),
),
request_future,
)
except TimeoutError as timeout_error:
error = CallError(
fault="client",
message=str(timeout_error),
is_retry_safe=True, # Make it retryable
)

# Token acquisition timeout will be treated as retryable error
try:
retry_token = retry_strategy.refresh_retry_token_for_retry(
token_to_renew=retry_token,
error=error,
)
except RetryError:
raise timeout_error

_LOGGER.debug(
"Token acquisition timeout. Attempting request #%s in %.4f seconds.",
retry_token.retry_count + 1,
retry_token.retry_delay,
)
continue # Skip to next retry iteration

if isinstance(output_context.response, Exception):
# Update rate limiter after failed response (adaptive only)
await self._handle_post_error_response_rate_limiting(
retry_strategy, output_context.response
)

try:
retry_token = retry_strategy.refresh_retry_token_for_retry(
token_to_renew=retry_token,
Expand All @@ -364,9 +396,35 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape](

await seek(request_context.transport_request.body, 0)
else:
# Update rate limiter after successful response (adaptive only)
await self._handle_success_rate_limiting(retry_strategy)
retry_strategy.record_success(token=retry_token)
return output_context

async def _handle_pre_request_rate_limiting(
self, retry_strategy: RetryStrategy
) -> None:
"""Handle rate limiting before sending request."""
if isinstance(retry_strategy, AdaptiveRetryStrategy):
await retry_strategy.acquire_from_token_bucket()

async def _handle_post_error_response_rate_limiting(
self, retry_strategy: RetryStrategy, error: Exception
) -> None:
"""Handle rate limiting after failed response."""
if isinstance(retry_strategy, AdaptiveRetryStrategy):
is_throttling = retry_strategy.is_throttling_error(error)
await retry_strategy.rate_limiter.after_receiving_response(is_throttling)

async def _handle_success_rate_limiting(
self, retry_strategy: RetryStrategy
) -> None:
"""Handle rate limiting after successful response."""
if isinstance(retry_strategy, AdaptiveRetryStrategy):
await retry_strategy.rate_limiter.after_receiving_response(
throttling_error=False
)

async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape](
self,
call: ClientCall[I, O],
Expand Down
55 changes: 54 additions & 1 deletion packages/smithy-core/src/smithy_core/retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .interfaces import retries as retries_interface
from .interfaces.retries import RetryStrategy

RetryStrategyType = Literal["simple", "standard"]
RetryStrategyType = Literal["simple", "standard", "adaptive"]


@dataclass(kw_only=True, frozen=True)
Expand Down Expand Up @@ -69,6 +69,8 @@ def _create_retry_strategy(
return SimpleRetryStrategy(**filtered_kwargs)
case "standard":
return StandardRetryStrategy(**filtered_kwargs)
case "adaptive":
return AdaptiveRetryStrategy(**filtered_kwargs)
case _:
raise ValueError(f"Unknown retry mode: {retry_mode}")

Expand Down Expand Up @@ -820,3 +822,54 @@ async def after_receiving_response(self, throttling_error: bool) -> None:
    @property
    def rate_limit_enabled(self) -> bool:
        """Whether client-side rate limiting is currently active."""
        return self._rate_limiter_enabled


class AdaptiveRetryStrategy(StandardRetryStrategy):
    """Adaptive retry strategy with client-side rate limiting using CUBIC algorithm.

    Builds on top of StandardRetryStrategy by adding token bucket rate limiting and
    CUBIC congestion control. Rate limiting is enabled after the first throttling
    response and dynamically adjusts sending rates based on the response type.
    """

    # Initial ceiling for the CUBIC rate calculation when building the default
    # limiter, passed as starting_max_rate below.
    STARTING_MAX_RATE = 0.5

    def __init__(self, *, rate_limiter: ClientRateLimiter | None = None, **kwargs):  # type: ignore
        """Initialize AdaptiveRetryStrategy.

        :param rate_limiter: Optional pre-configured rate limiter. If None, creates
            default components with rate limiting initially disabled.
        :param kwargs: Forwarded unchanged to ``StandardRetryStrategy.__init__``.
        """
        super().__init__(**kwargs)  # type: ignore

        if rate_limiter is None:
            # Create default rate limiting components
            token_bucket = TokenBucket()
            cubic_calculator = CubicCalculator(
                starting_max_rate=self.STARTING_MAX_RATE, start_time=time.monotonic()
            )
            rate_tracker = RequestRateTracker()
            self._rate_limiter = ClientRateLimiter(
                token_bucket=token_bucket,
                cubic_calculator=cubic_calculator,
                rate_tracker=rate_tracker,
                rate_limiter_enabled=False,  # Disabled until first throttle
            )
        else:
            self._rate_limiter = rate_limiter

    @property
    def rate_limiter(self) -> ClientRateLimiter:
        """Get the rate limiter for integration with request pipeline."""
        return self._rate_limiter

    def is_throttling_error(self, error: Exception) -> bool:
        """Check if error is a throttling error using existing ErrorRetryInfo.

        :param error: The exception to classify.
        :returns: True only for ``ErrorRetryInfo`` errors whose
            ``is_throttling_error`` attribute is truthy; False otherwise.
        """
        if isinstance(error, retries_interface.ErrorRetryInfo):
            # Check if ErrorRetryInfo has throttling detection; default to
            # False when the error doesn't expose the flag.
            return getattr(error, "is_throttling_error", False)
        return False

    async def acquire_from_token_bucket(self) -> None:
        """Wait for the rate limiter to permit sending, if limiting is enabled.

        Delegates to ``ClientRateLimiter.before_sending_request``; callers
        treat a ``TimeoutError`` raised during acquisition as a retryable
        client-side error.
        """
        if self._rate_limiter.rate_limit_enabled:
            await self._rate_limiter.before_sending_request()
139 changes: 139 additions & 0 deletions packages/smithy-core/tests/functional/test_retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@
# SPDX-License-Identifier: Apache-2.0

from asyncio import gather, sleep
from unittest.mock import patch

import pytest
from smithy_core.exceptions import CallError, ClientTimeoutError, RetryError
from smithy_core.interfaces import retries as retries_interface
from smithy_core.retries import (
AdaptiveRetryStrategy,
ClientRateLimiter,
CubicCalculator,
ExponentialBackoffJitterType,
ExponentialRetryBackoffStrategy,
RequestRateTracker,
StandardRetryQuota,
StandardRetryStrategy,
TokenBucket,
)


Expand Down Expand Up @@ -151,3 +157,136 @@ async def test_retry_quota_handles_timeout_errors():
assert result == "success"
assert attempts == 3
assert quota.available_capacity == 490


class TestAdaptiveRetryStrategy:
    async def retry_operation_with_rate_limiting(
        self,
        strategy: AdaptiveRetryStrategy,
        responses: list[int | Exception],
    ) -> tuple[str, int]:
        """Drive *strategy* through *responses*, applying rate limiting.

        Each response is either an HTTP status code or an exception. Returns
        ("success", attempt_number) when a 200 is reached; otherwise raises
        the final error once retries are exhausted.
        """
        token = strategy.acquire_initial_retry_token()
        response_iter = iter(responses)

        while True:
            if token.retry_delay:
                await sleep(token.retry_delay)

            try:
                # Rate limiting step - acquire token from bucket (can raise TimeoutError)
                await strategy.acquire_from_token_bucket()
            except TimeoutError as timeout_error:
                error = CallError(
                    fault="client",
                    message=str(timeout_error),
                    is_retry_safe=True,  # Make it retryable
                )
                # Timeout should be treated as a retryable error
                try:
                    token = strategy.refresh_retry_token_for_retry(
                        token_to_renew=token, error=error
                    )
                    continue  # Retry without consuming a response
                except RetryError:
                    raise timeout_error

            response = next(response_iter)
            attempt = token.retry_count + 1

            # Success case
            if response == 200:
                await strategy.rate_limiter.after_receiving_response(
                    throttling_error=False
                )
                strategy.record_success(token=token)
                return "success", attempt

            # Error case - we got a response (even if it's an error)
            if isinstance(response, Exception):
                error = response
                is_throttling = False
            else:
                error = CallError(
                    fault="server" if response >= 500 else "client",
                    message=f"HTTP {response}",
                    is_retry_safe=response >= 500,
                )
                is_throttling = response == 429
            # Update rate limiter after error response
            await strategy.rate_limiter.after_receiving_response(
                throttling_error=is_throttling
            )

            try:
                token = strategy.refresh_retry_token_for_retry(
                    token_to_renew=token, error=error
                )
            except RetryError:
                raise error

    async def test_adaptive_retry_eventually_succeeds(self):
        quota = StandardRetryQuota(initial_capacity=500)
        strategy = AdaptiveRetryStrategy(max_attempts=3, retry_quota=quota)

        result, attempts = await retry_operation(strategy, [500, 500, 200])

        assert result == "success"
        assert attempts == 3
        assert quota.available_capacity == 495

    async def test_adaptive_retry_fails_due_to_max_attempts(self):
        quota = StandardRetryQuota(initial_capacity=500)
        strategy = AdaptiveRetryStrategy(max_attempts=3, retry_quota=quota)

        with pytest.raises(CallError, match="502"):
            await retry_operation(strategy, [502, 502, 502])

        assert quota.available_capacity == 490

    async def test_adaptive_retry_timeout_counts_as_attempt_main1(self):
        """Test that token acquisition timeout counts as a retry attempt and continues retrying."""
        quota = StandardRetryQuota(initial_capacity=500)

        call_index = [0]

        def mock_monotonic():
            # Return 0.0, 0.1, ..., 0.6 for the first seven calls, then jump
            # past the 30-second token-acquisition timeout (31.0, 31.1, ...):
            # first attempt sees start_time=0.6 and a timeout check at 31.0,
            # so elapsed 30.4s > 30s raises TimeoutError (counted as attempt
            # 1) and retries continue.
            #
            # NOTE: the trigger is a call counter rather than a float
            # comparison — the previous `time_counter[0] == 0.6` check never
            # fired because repeated `+= 0.1` additions accumulate binary
            # rounding error (0.1 * 6 != 0.6 exactly).
            i = call_index[0]
            call_index[0] += 1
            if i < 7:
                return i * 0.1
            return 31.0 + (i - 7) * 0.1

        with (
            patch("time.monotonic", side_effect=mock_monotonic),
            patch("asyncio.sleep"),
        ):  # mock asyncio.sleep while acquiring token
            token_bucket = TokenBucket()
            rate_limiter = ClientRateLimiter(
                token_bucket=token_bucket,
                cubic_calculator=CubicCalculator(),
                rate_tracker=RequestRateTracker(),
                rate_limiter_enabled=True,
            )

            # Drain the initial token
            await rate_limiter.before_sending_request()

            strategy = AdaptiveRetryStrategy(
                max_attempts=3, retry_quota=quota, rate_limiter=rate_limiter
            )

            result, attempts = await self.retry_operation_with_rate_limiting(
                strategy, [500, 200]
            )

            assert result == "success"
            assert attempts == 3
            assert quota.available_capacity == 495
Loading
Loading