22
33import time
44from collections import defaultdict
5- from collections .abc import Callable
5+ from collections .abc import Awaitable , Callable
66from dataclasses import dataclass , field
7+ from typing import TYPE_CHECKING , cast
78
89from fastapi import HTTPException , Request , status
10+ from fastapi .responses import Response
911
1012from src .lib .config import settings
1113from src .lib .logging import get_logger
1214
15+ if TYPE_CHECKING :
16+ import redis .asyncio as redis_module
17+
1318logger = get_logger (__name__ )
1419
1520
@@ -74,14 +79,17 @@ class RedisRateLimiter:
7479 def __init__ (self , requests : int , window : int ):
7580 self .requests = requests
7681 self .window = window
77- self ._redis = None
82+ self ._redis : "redis_module.Redis | None" = None
7883
79- async def _get_redis (self ):
84+ async def _get_redis (self ) -> "redis_module.Redis" :
8085 """Lazy Redis connection."""
8186 if self ._redis is None :
8287 import redis .asyncio as redis
8388
84- self ._redis = redis .from_url (settings .REDIS_URL )
89+ self ._redis = cast (
90+ redis_module .Redis ,
91+ redis .from_url (settings .REDIS_URL ), # type: ignore[no-untyped-call]
92+ )
8593 return self ._redis
8694
8795 async def is_allowed (self , key : str ) -> tuple [bool , int , int ]:
@@ -114,7 +122,7 @@ async def is_allowed(self, key: str) -> tuple[bool, int, int]:
114122
115123 return True , remaining , reset_after
116124
117- async def close (self ):
125+ async def close (self ) -> None :
118126 """Close Redis connection."""
119127 if self ._redis :
120128 await self ._redis .aclose ()
@@ -144,7 +152,7 @@ def rate_limit(
144152 requests : int = 100 ,
145153 window : int = 60 ,
146154 key_func : Callable [[Request ], str ] | None = None ,
147- ):
155+ ) -> Callable [[ Callable [..., Awaitable [ Response ]]], Callable [..., Awaitable [ Response ]]] :
148156 """
149157 Rate limit decorator for FastAPI endpoints.
150158
@@ -162,16 +170,18 @@ async def get_resource():
162170 config = RateLimitConfig (requests = requests , window = window , key_func = key_func )
163171 actual_key_func = key_func or default_key_func
164172
165- def decorator (func ):
166- async def wrapper (* args , ** kwargs ):
173+ def decorator (
174+ func : Callable [..., Awaitable [Response ]],
175+ ) -> Callable [..., Awaitable [Response ]]:
176+ async def wrapper (* args : object , ** kwargs : object ) -> Response :
167177 # Find request in args/kwargs
168178 request : Request | None = None
169179 for arg in args :
170180 if isinstance (arg , Request ):
171181 request = arg
172182 break
173183 if request is None :
174- request = kwargs .get ("request" )
184+ request = cast ( Request | None , kwargs .get ("request" ) )
175185
176186 if request is None :
177187 return await func (* args , ** kwargs )
@@ -208,8 +218,10 @@ async def wrapper(*args, **kwargs):
208218
209219
210220async def rate_limit_middleware (
211- request : Request , call_next , config : RateLimitConfig | None = None
212- ):
221+ request : Request ,
222+ call_next : Callable [[Request ], Awaitable [Response ]],
223+ config : RateLimitConfig | None = None ,
224+ ) -> Response :
213225 """
214226 Rate limit middleware for global application.
215227
0 commit comments