diff --git a/packages/smithy-core/src/smithy_core/aio/client.py b/packages/smithy-core/src/smithy_core/aio/client.py index 6060727b..7824b28a 100644 --- a/packages/smithy-core/src/smithy_core/aio/client.py +++ b/packages/smithy-core/src/smithy_core/aio/client.py @@ -12,7 +12,7 @@ from ..auth import AuthParams from ..deserializers import DeserializeableShape, ShapeDeserializer from ..endpoints import EndpointResolverParams -from ..exceptions import ClientTimeoutError, RetryError, SmithyError +from ..exceptions import CallError, ClientTimeoutError, RetryError, SmithyError from ..interceptors import ( InputContext, Interceptor, @@ -23,6 +23,7 @@ from ..interfaces import Endpoint, TypedProperties from ..interfaces.auth import AuthOption, AuthSchemeResolver from ..interfaces.retries import RetryStrategy +from ..retries import AdaptiveRetryStrategy from ..schemas import APIOperation from ..serializers import SerializeableShape from ..shapes import ShapeID @@ -338,16 +339,47 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape]( if retry_token.retry_delay: await sleep(retry_token.retry_delay) - output_context = await self._handle_attempt( - call, - replace( - request_context, - transport_request=copy(request_context.transport_request), - ), - request_future, - ) + try: + # Rate limiting before request (adaptive only) + await self._handle_pre_request_rate_limiting(retry_strategy) + + output_context = await self._handle_attempt( + call, + replace( + request_context, + transport_request=copy(request_context.transport_request), + ), + request_future, + ) + except TimeoutError as timeout_error: + error = CallError( + fault="client", + message=str(timeout_error), + is_retry_safe=True, # Make it retryable + ) + + # Token acquisition timeout will be treated as retryable error + try: + retry_token = retry_strategy.refresh_retry_token_for_retry( + token_to_renew=retry_token, + error=error, + ) + except RetryError: + raise timeout_error + + _LOGGER.debug( + "Token acquisition timeout. Attempting request #%s in %.4f seconds.", + retry_token.retry_count + 1, + retry_token.retry_delay, + ) + continue # Skip to next retry iteration if isinstance(output_context.response, Exception): + # Update rate limiter after failed response (adaptive only) + await self._handle_post_error_response_rate_limiting( + retry_strategy, output_context.response + ) + try: retry_token = retry_strategy.refresh_retry_token_for_retry( token_to_renew=retry_token, @@ -364,9 +396,35 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape]( await seek(request_context.transport_request.body, 0) else: + # Update rate limiter after successful response (adaptive only) + await self._handle_success_rate_limiting(retry_strategy) retry_strategy.record_success(token=retry_token) return output_context + async def _handle_pre_request_rate_limiting( + self, retry_strategy: RetryStrategy + ) -> None: + """Handle rate limiting before sending request.""" + if isinstance(retry_strategy, AdaptiveRetryStrategy): + await retry_strategy.acquire_from_token_bucket() + + async def _handle_post_error_response_rate_limiting( + self, retry_strategy: RetryStrategy, error: Exception + ) -> None: + """Handle rate limiting after failed response.""" + if isinstance(retry_strategy, AdaptiveRetryStrategy): + is_throttling = retry_strategy.is_throttling_error(error) + await retry_strategy.rate_limiter.after_receiving_response(is_throttling) + + async def _handle_success_rate_limiting( + self, retry_strategy: RetryStrategy + ) -> None: + """Handle rate limiting after successful response.""" + if isinstance(retry_strategy, AdaptiveRetryStrategy): + await retry_strategy.rate_limiter.after_receiving_response( + throttling_error=False + ) + async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape]( self, call: ClientCall[I, O], diff --git a/packages/smithy-core/src/smithy_core/retries.py b/packages/smithy-core/src/smithy_core/retries.py index a185b9ce..477e2767 100644 --- a/packages/smithy-core/src/smithy_core/retries.py +++ b/packages/smithy-core/src/smithy_core/retries.py @@ -15,7 +15,7 @@ from .interfaces import retries as retries_interface from .interfaces.retries import RetryStrategy -RetryStrategyType = Literal["simple", "standard"] +RetryStrategyType = Literal["simple", "standard", "adaptive"] @dataclass(kw_only=True, frozen=True) @@ -69,6 +69,8 @@ def _create_retry_strategy( return SimpleRetryStrategy(**filtered_kwargs) case "standard": return StandardRetryStrategy(**filtered_kwargs) + case "adaptive": + return AdaptiveRetryStrategy(**filtered_kwargs) case _: raise ValueError(f"Unknown retry mode: {retry_mode}") @@ -820,3 +822,54 @@ async def after_receiving_response(self, throttling_error: bool) -> None: @property def rate_limit_enabled(self) -> bool: return self._rate_limiter_enabled + + +class AdaptiveRetryStrategy(StandardRetryStrategy): + """Adaptive retry strategy with client-side rate limiting using CUBIC algorithm. + + Builds on top of StandardRetryStrategy by adding token bucket rate limiting and + CUBIC congestion control. Rate limiting is enabled after the first throttling + response and dynamically adjusts sending rates based on the response type. + """ + + STARTING_MAX_RATE = 0.5 + + def __init__(self, *, rate_limiter: ClientRateLimiter | None = None, **kwargs): # type: ignore + """Initialize AdaptiveRetryStrategy. + + :param rate_limiter: Optional pre-configured rate limiter. If None, creates + default components with rate limiting initially disabled. + """ + super().__init__(**kwargs) # type: ignore + + if rate_limiter is None: + # Create default rate limiting components + token_bucket = TokenBucket() + cubic_calculator = CubicCalculator( + starting_max_rate=self.STARTING_MAX_RATE, start_time=time.monotonic() + ) + rate_tracker = RequestRateTracker() + self._rate_limiter = ClientRateLimiter( + token_bucket=token_bucket, + cubic_calculator=cubic_calculator, + rate_tracker=rate_tracker, + rate_limiter_enabled=False, # Disabled until first throttle + ) + else: + self._rate_limiter = rate_limiter + + @property + def rate_limiter(self) -> ClientRateLimiter: + """Get the rate limiter for integration with request pipeline.""" + return self._rate_limiter + + def is_throttling_error(self, error: Exception) -> bool: + """Check if error is a throttling error using existing ErrorRetryInfo.""" + if isinstance(error, retries_interface.ErrorRetryInfo): + # Check if ErrorRetryInfo has throttling detection + return getattr(error, "is_throttling_error", False) + return False + + async def acquire_from_token_bucket(self): + if self._rate_limiter.rate_limit_enabled: + await self._rate_limiter.before_sending_request() diff --git a/packages/smithy-core/tests/functional/test_retries.py b/packages/smithy-core/tests/functional/test_retries.py index 9a72b491..9fc46d22 100644 --- a/packages/smithy-core/tests/functional/test_retries.py +++ b/packages/smithy-core/tests/functional/test_retries.py @@ -2,15 +2,21 @@ # SPDX-License-Identifier: Apache-2.0 from asyncio import gather, sleep +from unittest.mock import patch import pytest from smithy_core.exceptions import CallError, ClientTimeoutError, RetryError from smithy_core.interfaces import retries as retries_interface from smithy_core.retries import ( + AdaptiveRetryStrategy, + ClientRateLimiter, + CubicCalculator, ExponentialBackoffJitterType, ExponentialRetryBackoffStrategy, + RequestRateTracker, StandardRetryQuota, StandardRetryStrategy, + TokenBucket, ) @@ -151,3 +157,136 @@ async def test_retry_quota_handles_timeout_errors(): assert result == "success" assert attempts == 3 assert quota.available_capacity == 490 + + +class TestAdaptiveRetryStrategy: + async def retry_operation_with_rate_limiting( + self, + strategy: AdaptiveRetryStrategy, + responses: list[int | Exception], + ) -> tuple[str, int]: + token = strategy.acquire_initial_retry_token() + response_iter = iter(responses) + + while True: + if token.retry_delay: + await sleep(token.retry_delay) + + try: + # Rate limiting step - acquire token from bucket (can raise TimeoutError) + await strategy.acquire_from_token_bucket() + except TimeoutError as timeout_error: + error = CallError( + fault="client", + message=str(timeout_error), + is_retry_safe=True, # Make it retryable + ) + # Timeout should be treated as a retryable error + try: + token = strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=error + ) + continue # Retry without consuming a response + except RetryError: + raise timeout_error + + response = next(response_iter) + attempt = token.retry_count + 1 + + # Success case + if response == 200: + await strategy.rate_limiter.after_receiving_response( + throttling_error=False + ) + strategy.record_success(token=token) + return "success", attempt + + # Error case - we got a response (even if it's an error) + if isinstance(response, Exception): + error = response + is_throttling = False + else: + error = CallError( + fault="server" if response >= 500 else "client", + message=f"HTTP {response}", + is_retry_safe=response >= 500, + ) + is_throttling = response == 429 + # Update rate limiter after error response + await strategy.rate_limiter.after_receiving_response( + throttling_error=is_throttling + ) + + try: + token = strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=error + ) + except RetryError: + raise error + + async def test_adaptive_retry_eventually_succeeds(self): + quota = StandardRetryQuota(initial_capacity=500) + strategy = AdaptiveRetryStrategy(max_attempts=3, retry_quota=quota) + + result, attempts = await retry_operation(strategy, [500, 500, 200]) + + assert result == "success" + assert attempts == 3 + assert quota.available_capacity == 495 + + async def test_adaptive_retry_fails_due_to_max_attempts(self): + quota = StandardRetryQuota(initial_capacity=500) + strategy = AdaptiveRetryStrategy(max_attempts=3, retry_quota=quota) + + with pytest.raises(CallError, match="502"): + await retry_operation(strategy, [502, 502, 502]) + + assert quota.available_capacity == 490 + + async def test_adaptive_retry_timeout_counts_as_attempt_main1(self): + """Test that token acquisition timeout counts as a retry attempt and continues retrying.""" + quota = StandardRetryQuota(initial_capacity=500) + + time_counter = [0.0] + + def mock_monotonic(): + # Mock time progression to trigger timeout on first attempt: + # Time values: [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 31, 31.1...] + # First attempt: start_time=0.6, timeout_check=31.0 + # Elapsed time: 31.0 - 0.6 = 30.4 seconds > 30s timeout threshold + # Result: TimeoutError raised, counted as attempt 1, then retries continue + + current = time_counter[0] + # Jump to timeout on specific call (e.g., 8th call) + if time_counter[0] == 0.6: # After initial setup + time_counter[0] = 31.0 # Jump to timeout + else: + time_counter[0] += 0.1 + return current + + with ( + patch("time.monotonic", side_effect=mock_monotonic), + patch("asyncio.sleep"), + ): # mock asyncio.sleep while acquiring token + token_bucket = TokenBucket() + rate_limiter = ClientRateLimiter( + token_bucket=token_bucket, + cubic_calculator=CubicCalculator(), + rate_tracker=RequestRateTracker(), + rate_limiter_enabled=True, + ) + + # Drain the initial token + await rate_limiter.before_sending_request() + + strategy = AdaptiveRetryStrategy( + max_attempts=3, retry_quota=quota, rate_limiter=rate_limiter + ) + + result, attempts = await self.retry_operation_with_rate_limiting( + strategy, [500, 200] + ) + + assert result == "success" + assert attempts == 3 + assert quota.available_capacity == 495 diff --git a/packages/smithy-core/tests/unit/test_retries.py b/packages/smithy-core/tests/unit/test_retries.py index 6ae98ffb..b48dafc2 100644 --- a/packages/smithy-core/tests/unit/test_retries.py +++ b/packages/smithy-core/tests/unit/test_retries.py @@ -5,6 +5,7 @@ import pytest from smithy_core.exceptions import CallError, RetryError from smithy_core.retries import ( + AdaptiveRetryStrategy, ClientRateLimiter, CubicCalculator, ExponentialRetryBackoffStrategy, @@ -449,6 +450,47 @@ async def test_throttling_response_enables_rate_limiter(self): await limiter.after_receiving_response(throttling_error=True) assert limiter.rate_limit_enabled is True + async def test_raises_timeout_error_when_token_acquisition_takes_30_secs(self): + with patch("time.monotonic") as mock_time: + mock_time.side_effect = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 31] + token_bucket = TokenBucket(curr_capacity=0) + calculator = CubicCalculator() + tracker = RequestRateTracker() + limiter = ClientRateLimiter( + token_bucket=token_bucket, + cubic_calculator=calculator, + rate_tracker=tracker, + rate_limiter_enabled=True, + ) + + await limiter.before_sending_request() + with pytest.raises(TimeoutError): + await limiter.before_sending_request() + + async def test_tokens_accumulation_after_numerous_successful_responses(self): + with patch("time.monotonic") as mock_time: + mock_time.side_effect = [i * 0.01 for i in range(1000000)] + token_bucket = TokenBucket(curr_capacity=0) + calculator = CubicCalculator() + tracker = RequestRateTracker() + limiter = ClientRateLimiter( + token_bucket=token_bucket, + cubic_calculator=calculator, + rate_tracker=tracker, + rate_limiter_enabled=True, + ) + + # First throttled request + await limiter.after_receiving_response(throttling_error=True) + # followed by numerous successful requests + for _ in range(300000): + await limiter.after_receiving_response(throttling_error=False) + # followed by a throttled request + await limiter.after_receiving_response(throttling_error=True) + + # validate that the number of tokens accumulated isn't insanely high + assert token_bucket.current_capacity < 25 + async def test_calculated_rate_is_capped_at_2x_measured_rate(self): with patch("time.monotonic") as mock_time: mock_time.side_effect = [0.0, 0.1, 0.2, 0.3, 0.4] @@ -624,3 +666,275 @@ async def test_measure_rate_new_bucket(self): assert rate > 0 assert tracker.request_count == 0 + + +class TestAdaptiveRetryStrategy: + """Tests for AdaptiveRetryStrategy with rate limiting.""" + + def test_initialization_with_default_rate_limiter(self): + """Test that AdaptiveRetryStrategy creates default rate limiter components.""" + strategy = AdaptiveRetryStrategy() + + assert strategy.rate_limiter is not None + assert isinstance(strategy.rate_limiter, ClientRateLimiter) + assert strategy.rate_limiter.rate_limit_enabled is False # Disabled by default + + def test_initialization_with_custom_rate_limiter(self): + """Test that AdaptiveRetryStrategy accepts custom rate limiter.""" + token_bucket = TokenBucket() + calculator = CubicCalculator(starting_max_rate=5.0, start_time=0.0) + tracker = RequestRateTracker() + custom_limiter = ClientRateLimiter( + token_bucket=token_bucket, + cubic_calculator=calculator, + rate_tracker=tracker, + rate_limiter_enabled=True, + ) + + strategy = AdaptiveRetryStrategy(rate_limiter=custom_limiter) + + assert strategy.rate_limiter is custom_limiter + assert strategy.rate_limiter.rate_limit_enabled is True + + def test_is_throttling_error_with_throttling_error(self): + """Test detection of throttling errors.""" + strategy = AdaptiveRetryStrategy() + + # Create an error with throttling flag set + error = CallError(message="Throttled", is_throttling_error=True) + + assert strategy.is_throttling_error(error) is True + + def test_is_throttling_error_with_non_throttling_error(self): + """Test that non-throttling errors return False.""" + strategy = AdaptiveRetryStrategy() + + error = CallError(message="Server error", is_throttling_error=False) + + assert strategy.is_throttling_error(error) is False + + def test_is_throttling_error_with_non_retry_info_error(self): + """Test that errors without ErrorRetryInfo return False.""" + strategy = AdaptiveRetryStrategy() + + error = Exception("Generic error") + + assert strategy.is_throttling_error(error) is False + + @pytest.mark.asyncio + async def test_acquire_from_token_bucket_when_enabled(self): + """Test that acquire is called when rate limiting is enabled.""" + with patch("time.monotonic") as mock_time: + mock_time.return_value = 0.0 + + token_bucket = TokenBucket(curr_capacity=1.0) + calculator = CubicCalculator(starting_max_rate=1.0, start_time=0.0) + tracker = RequestRateTracker() + limiter = ClientRateLimiter( + token_bucket=token_bucket, + cubic_calculator=calculator, + rate_tracker=tracker, + rate_limiter_enabled=True, + ) + strategy = AdaptiveRetryStrategy(rate_limiter=limiter) + await strategy.acquire_from_token_bucket() + + # Should have consumed one 1 token + assert token_bucket.current_capacity == 0 + + @pytest.mark.asyncio + async def test_acquire_from_token_bucket_when_disabled(self): + """Test that acquire is not called when rate limiting is disabled.""" + token_bucket = TokenBucket(curr_capacity=1.0) + calculator = CubicCalculator(starting_max_rate=1.0, start_time=0.0) + tracker = RequestRateTracker() + limiter = ClientRateLimiter( + token_bucket=token_bucket, + cubic_calculator=calculator, + rate_tracker=tracker, + rate_limiter_enabled=False, + ) + strategy = AdaptiveRetryStrategy(rate_limiter=limiter) + await strategy.acquire_from_token_bucket() + + # Should not have consumed any tokens + assert token_bucket.current_capacity == 1.0 + + @pytest.mark.asyncio + async def test_rate_limiter_enabled_after_throttling(self): + """Test that rate limiter is enabled after first throttling error.""" + with patch("time.monotonic") as mock_time: + mock_time.side_effect = [0.0, 0.1, 0.2, 0.3] + + token_bucket = TokenBucket() + calculator = CubicCalculator(starting_max_rate=1.0, start_time=0.0) + tracker = RequestRateTracker() + limiter = ClientRateLimiter( + token_bucket=token_bucket, + cubic_calculator=calculator, + rate_tracker=tracker, + rate_limiter_enabled=False, + ) + strategy = AdaptiveRetryStrategy(rate_limiter=limiter) + + assert strategy.rate_limiter.rate_limit_enabled is False + + # Simulate throttling response + with patch.object(tracker, "measure_rate", return_value=5.0): + await strategy.rate_limiter.after_receiving_response( + throttling_error=True + ) + + assert strategy.rate_limiter.rate_limit_enabled is True + + @pytest.mark.asyncio + async def test_resolver_creates_adaptive_strategy(self): + """Test that RetryStrategyResolver can create AdaptiveRetryStrategy.""" + resolver = RetryStrategyResolver() + option1 = RetryStrategyOptions(retry_mode="adaptive") + + strategy = await resolver.resolve_retry_strategy(retry_strategy=option1) + + assert isinstance(strategy, AdaptiveRetryStrategy) + # default max_attempts for adaptive retries is 3 + assert strategy.max_attempts == 3 + + @pytest.mark.parametrize("max_attempts", [2, 3, 10]) + def test_adaptive_retry_strategy(self, max_attempts: int) -> None: + strategy = AdaptiveRetryStrategy(max_attempts=max_attempts) + error = CallError(is_retry_safe=True) + token = strategy.acquire_initial_retry_token() + for _ in range(max_attempts - 1): + token = strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=error + ) + with pytest.raises(RetryError): + strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) + + @pytest.mark.parametrize( + "error", + [ + Exception(), + CallError(is_retry_safe=None), + CallError(fault="client", is_retry_safe=False), + ], + ids=[ + "unclassified_error", + "safety_unknown_error", + "unsafe_error", + ], + ) + def test_adaptive_retry_does_not_retry_for_non_retryable_errors( + self, error: Exception | CallError + ) -> None: + strategy = AdaptiveRetryStrategy() + token = strategy.acquire_initial_retry_token() + with pytest.raises(RetryError): + strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) + + def test_adaptive_retry_after_overrides_backoff(self) -> None: + strategy = AdaptiveRetryStrategy() + error = CallError(is_retry_safe=True, retry_after=5.5) + token = strategy.acquire_initial_retry_token() + token = strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=error + ) + assert token.retry_delay == 5.5 + + +class TestRequestPipelineRateLimiting: + @pytest.mark.asyncio + async def test_pre_request_rate_limiting_with_adaptive_strategy(self): + """Test that pre-request rate limiting is called for adaptive strategy.""" + with patch("time.monotonic") as mock_time: + mock_time.return_value = 0.0 + + token_bucket = TokenBucket(curr_capacity=1.0) + calculator = CubicCalculator(starting_max_rate=10.0, start_time=0.0) + tracker = RequestRateTracker() + limiter = ClientRateLimiter( + token_bucket=token_bucket, + cubic_calculator=calculator, + rate_tracker=tracker, + rate_limiter_enabled=True, + ) + + strategy = AdaptiveRetryStrategy(rate_limiter=limiter) + + # Simulate what RequestPipeline does + if isinstance(strategy, AdaptiveRetryStrategy): # type: ignore[reportUnnecessaryIsInstance] + await strategy.acquire_from_token_bucket() + + # Token should be consumed + assert token_bucket.current_capacity == 0.0 + + @pytest.mark.asyncio + async def test_pre_request_rate_limiting_with_standard_strategy(self): + """Test that pre-request rate limiting is skipped for standard strategy.""" + strategy = StandardRetryStrategy() + + if isinstance(strategy, AdaptiveRetryStrategy): + try: + await strategy.acquire_from_token_bucket() + except Exception as e: + pytest.fail(f"Unexpected exception raised: {e}") + + @pytest.mark.asyncio + async def test_post_error_rate_limiting_with_throttling_error(self): + """Test rate limiter update after throttling error.""" + with patch("time.monotonic") as mock_time: + mock_time.side_effect = [0.0, 0.1, 0.2, 0.3] + + token_bucket = TokenBucket() + calculator = CubicCalculator(starting_max_rate=10.0, start_time=0.0) + tracker = RequestRateTracker() + limiter = ClientRateLimiter( + token_bucket=token_bucket, + cubic_calculator=calculator, + rate_tracker=tracker, + rate_limiter_enabled=True, + ) + + strategy = AdaptiveRetryStrategy(rate_limiter=limiter) + error = CallError(message="Throttled", is_throttling_error=True) + + # Simulate what RequestPipeline does + if isinstance(strategy, AdaptiveRetryStrategy): # type: ignore[reportUnnecessaryIsInstance] + is_throttling = strategy.is_throttling_error(error) + with patch.object(tracker, "measure_rate", return_value=5.0): + await strategy.rate_limiter.after_receiving_response(is_throttling) + + # Fill rate should be reduced due to throttling + assert token_bucket.fill_rate < 10.0 + + @pytest.mark.asyncio + async def test_success_rate_limiting_increases_rate(self): + """Test rate limiter update after successful response.""" + with patch("time.monotonic") as mock_time: + mock_time.side_effect = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5] + + token_bucket = TokenBucket() + calculator = CubicCalculator(starting_max_rate=5.0, start_time=0.0) + tracker = RequestRateTracker() + limiter = ClientRateLimiter( + token_bucket=token_bucket, + cubic_calculator=calculator, + rate_tracker=tracker, + rate_limiter_enabled=True, + ) + + strategy = AdaptiveRetryStrategy(rate_limiter=limiter) + + initial_rate = token_bucket.fill_rate + + # Simulate successful responses + with patch.object(tracker, "measure_rate", return_value=3.0): + await strategy.rate_limiter.after_receiving_response( + throttling_error=False + ) + await strategy.rate_limiter.after_receiving_response( + throttling_error=False + ) + + # Fill rate should increase after successful responses + assert token_bucket.fill_rate > initial_rate