diff --git a/docs/advance.md b/docs/advance.md index a2d77c5..b03e75a 100644 --- a/docs/advance.md +++ b/docs/advance.md @@ -49,3 +49,27 @@ app = FastAPI() # add your routers .... app.openapi = enrich_openapi(app, ["your_module_name", "some-other-package"]) ``` + +### Finding exceptions in specific API functions of FastAPI + +If you want to limit the search of exceptions to specific functions, you can pass the list of functions to the `find_exceptions_in_api_functions` function. + +```python +from richapi import find_exceptions_in_api_functions, BaseHTTPException +from fastapi import FastAPI + +app = FastAPI() + +class NotFoundException(BaseHTTPException): + status_code = 404 + detail: str = "Item not found" + +@app.get("/items/{item_id}") +async def read_item(item_id: int): + if item_id != 42: + raise NotFoundException() + return {"item_id": item_id} + +exceptions = find_exceptions_in_api_functions(funcs=[read_item], fastapi_app=app) +print(exceptions) # {NotFoundException} +``` diff --git a/richapi/__init__.py b/richapi/__init__.py index 7ab8050..a7b057b 100644 --- a/richapi/__init__.py +++ b/richapi/__init__.py @@ -1,5 +1,9 @@ from richapi.exc_parser.handler import add_exc_handler -from richapi.exc_parser.openapi import enrich_openapi, load_openapi +from richapi.exc_parser.openapi import ( + enrich_openapi, + find_exceptions_in_api_functions, + load_openapi, +) from richapi.exc_parser.protocol import BaseHTTPException, RichHTTPException __all__ = [ @@ -8,4 +12,5 @@ "BaseHTTPException", "RichHTTPException", "add_exc_handler", + "find_exceptions_in_api_functions", ] diff --git a/richapi/exc_parser/openapi.py b/richapi/exc_parser/openapi.py index 5c63510..a8534c6 100644 --- a/richapi/exc_parser/openapi.py +++ b/richapi/exc_parser/openapi.py @@ -37,6 +37,38 @@ def load_openapi( return lambda: openapi_json +def _validate_target_module( + target_module: Union[list[str], str, None], + target_obj: FastAPI | Callable, +) -> list[str]: + if target_module is not None: + return [target_module] if isinstance(target_module, str) else target_module + + if isinstance(target_obj, FastAPI): + target_module = _find_module_name_where_app_defined_in(target_obj) + else: + target_module = _find_module_name_where_func_defined_in(target_obj) + + if target_module is not None: + target_module = target_module.split(".")[0] # get the top-level module + + if target_module is None or target_module == "__main__": + if isinstance(target_obj, FastAPI): + raise BaseRichAPIException( + "Could not determine the module where the FastAPI instance was created.\n" + "Please provide the module name as a string or list of strings.\n" + "Example: enrich_openapi(app, target_module='src')\n" + ) + else: + raise BaseRichAPIException( + "Could not determine the module where the function was created.\n" + "Please provide the module name as a string or list of strings.\n" + "Example: find_exceptions_in_api_functions([func], target_module='src')\n" + ) + + return [target_module] + + def enrich_openapi( app: FastAPI, target_module: Union[list[str], str, None] = None, @@ -47,16 +79,7 @@ def enrich_openapi( routes=app.routes, ), ) -> Callable: - if target_module is None: - target_module = _find_module_name_where_app_defined_in(app) - if target_module is None or target_module == "__main__": - raise BaseRichAPIException( - "Could not determine the module where the FastAPI instance was created.\n" - "Please provide the module name as a string or list of strings.\n" - "Example: enrich_openapi(app, target_module='src')\n" - ) - - target_module = target_module.split(".")[0] # get the top-level module + target_module = _validate_target_module(target_module, app) def _custom_openapi() -> dict: if app.openapi_schema: # pragma: no cover @@ -72,7 +95,7 @@ def _custom_openapi() -> dict: def compile_openapi_from_fastapi( app: FastAPI, - target_module: Union[list[str], str], + target_module: list[str] | str, open_api_getter: Callable[[FastAPI], dict] = lambda app: _get_openapi( title=app.title, version=app.version, @@ -81,15 +104,16 @@ def compile_openapi_from_fastapi( ), ) -> dict: target_module = [target_module] if isinstance(target_module, str) else target_module - target_module.append("fastapi") + openapi_schema = open_api_getter(app) for route in app.routes: if not isinstance(route, APIRoute): continue if route.include_in_schema: - exceptions = _extract_starlette_exceptions(route, target_module) - + exceptions = _extract_starlette_exceptions( + route, target_module + ["fastapi"] + ) _fill_openapi_with_excpetions(openapi_schema, route, exceptions) ExceptionFinder.clear_cache() @@ -97,6 +121,41 @@ def compile_openapi_from_fastapi( return openapi_schema +def _find_fastapi_route_from_callback( + app: FastAPI, callback: Callable +) -> Union[APIRoute, None]: + for route in app.routes: + if not isinstance(route, APIRoute): + continue + if route.endpoint is callback: + return route + + return None + + +def find_exceptions_in_api_functions( + funcs: list[Callable], + fastapi_app: FastAPI, + target_module: str | list[str] | None = None, +) -> set[type[StarletteHTTPException]]: + result: set[type[StarletteHTTPException]] = set() + for func in funcs: + this_func_target_module = _validate_target_module(target_module, func) + this_route = _find_fastapi_route_from_callback(fastapi_app, func) + if this_route is None: + raise BaseRichAPIException( + f"Could not find the route for function {func} in FastAPI app {fastapi_app}" + ) + + exceptions = _extract_starlette_exceptions( + this_route, this_func_target_module + ["fastapi"] + ) + for exc, _ in exceptions: + result.add(exc) + + return result + + def _find_module_name_where_app_defined_in(app: FastAPI) -> Union[str, None]: frame = inspect.currentframe() target_module = None @@ -112,6 +171,21 @@ def _find_module_name_where_app_defined_in(app: FastAPI) -> Union[str, None]: return target_module +def _find_module_name_where_func_defined_in(func: Callable) -> Union[str, None]: + frame = inspect.currentframe() + target_module = None + while frame: + for var_name, var_value in frame.f_globals.items(): + if isinstance(var_value, Callable) and var_value is func: + target_module = frame.f_globals["__name__"] + break + if target_module: + break + frame = frame.f_back + + return target_module + + def _resolve_status_and_detail_from_exc_type( exc_type: type[Exception], ast_raise: ast.Raise, @@ -142,7 +216,7 @@ def _resolve_status_and_detail_from_exc_type( status_code_value.value, (int, str, NoneType), ): - raise ValueError( + raise BaseRichAPIException( f"Status code value must be an integer, string or None, got {type(status_code_value.value)}" ) if status_code_value.value is not None: @@ -166,7 +240,7 @@ def _resolve_status_and_detail_from_exc_type( detail_value = kwarg.value if isinstance(detail_value, ast.Constant): if not isinstance(detail_value.value, (str, NoneType)): - raise ValueError( + raise BaseRichAPIException( f"Detail value must be a string or None, got {type(detail_value.value)}" ) diff --git a/tests/test_router_exception.py b/tests/test_router_exception.py new file mode 100644 index 0000000..6d9bfaa --- /dev/null +++ b/tests/test_router_exception.py @@ -0,0 +1,57 @@ +# from dataclasses import dataclass +# from typing import Literal + +import random +from dataclasses import dataclass + +from fastapi import FastAPI +from fastapi.encoders import jsonable_encoder +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from richapi.exc_parser.handler import add_exc_handler +from richapi.exc_parser.openapi import enrich_openapi, find_exceptions_in_api_functions +from richapi.exc_parser.protocol import RichHTTPException + + +@dataclass +class Exception1(RichHTTPException): + status_code = 409 + + +@dataclass +class Exception2(RichHTTPException): + status_code = 408 + + +class NWrapper(BaseModel): + value: int + + +def lol(): + if random.randint(1, 3) == 2: + raise Exception1() + return lol() + + +app = FastAPI() + +app.openapi = enrich_openapi(app, target_module="tests.test_recursive_func") +add_exc_handler(app) + + +@app.get("/{n}", response_model=NWrapper) +async def index(n: int) -> JSONResponse: + lol() + if n % 2 == 0: + raise Exception2() + return JSONResponse(content=jsonable_encoder(NWrapper(value=n))) + + +def test_find_exceptions_in_api_functions(): + result = find_exceptions_in_api_functions( + [index], + target_module="tests.test_router_exception", + fastapi_app=app, + ) + assert result == {Exception1, Exception2}