diff --git a/.changeset/fix-rate-limiter-abort-signal.md b/.changeset/fix-rate-limiter-abort-signal.md new file mode 100644 index 00000000000..5e06a12563d --- /dev/null +++ b/.changeset/fix-rate-limiter-abort-signal.md @@ -0,0 +1,6 @@ +--- +"@smithy/util-retry": minor +"@smithy/middleware-retry": minor +--- + +feat(util-retry): support AbortSignal in DefaultRateLimiter.getSendToken diff --git a/packages/middleware-retry/src/types.ts b/packages/middleware-retry/src/types.ts index 4f6176d702f..0b89674137f 100644 --- a/packages/middleware-retry/src/types.ts +++ b/packages/middleware-retry/src/types.ts @@ -59,8 +59,10 @@ export interface RateLimiter { * If there is not sufficient capacity, it will either sleep a certain amount * of time until the rate limiter can retrieve a token from its token bucket * or raise an exception indicating there is insufficient capacity. + * + * @param abortSignal - optional signal to abort the token wait early. */ - getSendToken: () => Promise; + getSendToken: (abortSignal?: AbortSignal) => Promise; /** * Updates the client sending rate based on response. diff --git a/packages/util-retry/src/DefaultRateLimiter.spec.ts b/packages/util-retry/src/DefaultRateLimiter.spec.ts index 45c5b34753d..6a8d1ae0e6b 100644 --- a/packages/util-retry/src/DefaultRateLimiter.spec.ts +++ b/packages/util-retry/src/DefaultRateLimiter.spec.ts @@ -40,6 +40,56 @@ describe(DefaultRateLimiter.name, () => { vi.runAllTimers(); expect(spy).toHaveBeenLastCalledWith(expect.any(Function), delay); }); + + it("rejects when abortSignal is already aborted", async () => { + vi.spyOn(Date, "now").mockImplementation(() => 0); + const rateLimiter = new DefaultRateLimiter(); + + vi.mocked(isThrottlingError).mockReturnValueOnce(true); + vi.spyOn(Date, "now").mockImplementation(() => 500); + rateLimiter.updateClientSendingRate({}); + + const abortController = new AbortController(); + const reason = new Error("Lambda timeout approaching"); + abortController.abort(reason); + + await expect(rateLimiter.getSendToken(abortController.signal)).rejects.toBe(reason); + }); + + it("rejects when abortSignal fires during wait", async () => { + vi.spyOn(Date, "now").mockImplementation(() => 0); + const rateLimiter = new DefaultRateLimiter(); + + vi.mocked(isThrottlingError).mockReturnValueOnce(true); + vi.spyOn(Date, "now").mockImplementation(() => 500); + rateLimiter.updateClientSendingRate({}); + + const abortController = new AbortController(); + const reason = new Error("Lambda timeout approaching"); + + const promise = rateLimiter.getSendToken(abortController.signal); + abortController.abort(reason); + + await expect(promise).rejects.toBe(reason); + }); + + it("resolves normally when abortSignal is not aborted", async () => { + vi.spyOn(Date, "now").mockImplementation(() => 0); + const rateLimiter = new DefaultRateLimiter(); + + // Use a spy to immediately resolve the setTimeout callback + vi.spyOn(DefaultRateLimiter as any, "setTimeoutFn").mockImplementation((cb: () => void) => { + cb(); + return 0; + }); + + vi.mocked(isThrottlingError).mockReturnValueOnce(true); + vi.spyOn(Date, "now").mockImplementation(() => 500); + rateLimiter.updateClientSendingRate({}); + + const abortController = new AbortController(); + await expect(rateLimiter.getSendToken(abortController.signal)).resolves.toBeUndefined(); + }); }); describe("cubicSuccess", () => { diff --git a/packages/util-retry/src/DefaultRateLimiter.ts b/packages/util-retry/src/DefaultRateLimiter.ts index dc345d9b1db..e98ed4db42c 100644 --- a/packages/util-retry/src/DefaultRateLimiter.ts +++ b/packages/util-retry/src/DefaultRateLimiter.ts @@ -63,11 +63,11 @@ export class DefaultRateLimiter implements RateLimiter { return Date.now() / 1000; } - public async getSendToken() { - return this.acquireTokenBucket(1); + public async getSendToken(abortSignal?: AbortSignal) { + return this.acquireTokenBucket(1, abortSignal); } - private async acquireTokenBucket(amount: number) { + private async acquireTokenBucket(amount: number, abortSignal?: AbortSignal) { // Client side throttling is not enabled until we see a throttling error. if (!this.enabled) { return; @@ -76,7 +76,22 @@ export class DefaultRateLimiter implements RateLimiter { this.refillTokenBucket(); if (amount > this.currentCapacity) { const delay = ((amount - this.currentCapacity) / this.fillRate) * 1000; - await new Promise((resolve) => DefaultRateLimiter.setTimeoutFn(resolve, delay)); + await new Promise((resolve, reject) => { + const timer = DefaultRateLimiter.setTimeoutFn(resolve, delay); + + if (abortSignal) { + if (abortSignal.aborted) { + clearTimeout(timer); + reject(abortSignal.reason ?? new Error("Request aborted")); + return; + } + const onAbort = () => { + clearTimeout(timer); + reject(abortSignal.reason ?? new Error("Request aborted")); + }; + abortSignal.addEventListener("abort", onAbort, { once: true }); + } + }); } this.currentCapacity = this.currentCapacity - amount; } diff --git a/packages/util-retry/src/types.ts b/packages/util-retry/src/types.ts index f1001e8ddd3..e7744ca4b26 100644 --- a/packages/util-retry/src/types.ts +++ b/packages/util-retry/src/types.ts @@ -7,8 +7,10 @@ export interface RateLimiter { * If there is not sufficient capacity, it will either sleep a certain amount * of time until the rate limiter can retrieve a token from its token bucket * or raise an exception indicating there is insufficient capacity. + * + * @param abortSignal - optional signal to abort the token wait early. */ - getSendToken: () => Promise; + getSendToken: (abortSignal?: AbortSignal) => Promise; /** * Updates the client sending rate based on response.