Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions docs/advance.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,27 @@ app = FastAPI()
# add your routers ....
app.openapi = enrich_openapi(app, ["your_module_name", "some-other-package"])
```

### Finding exceptions in specific API functions of FastAPI

If you want to limit the search of exceptions to specific functions, you can pass the list of functions to the `find_exceptions_in_api_functions` function.

```python
from richapi import find_exceptions_in_api_functions, BaseHTTPException
from fastapi import FastAPI

app = FastAPI()

class NotFoundException(BaseHTTPException):
status_code = 404
detail: str = "Item not found"

@app.get("/items/{item_id}")
async def read_item(item_id: int):
if item_id != 42:
raise NotFoundException()
return {"item_id": item_id}

exceptions = find_exceptions_in_api_functions(funcs=[read_item], fastapi_app=app)
print(exceptions) # {NotFoundException}
```
7 changes: 6 additions & 1 deletion richapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from richapi.exc_parser.handler import add_exc_handler
from richapi.exc_parser.openapi import enrich_openapi, load_openapi
from richapi.exc_parser.openapi import (
enrich_openapi,
find_exceptions_in_api_functions,
load_openapi,
)
from richapi.exc_parser.protocol import BaseHTTPException, RichHTTPException

__all__ = [
Expand All @@ -8,4 +12,5 @@
"BaseHTTPException",
"RichHTTPException",
"add_exc_handler",
"find_exceptions_in_api_functions",
]
106 changes: 90 additions & 16 deletions richapi/exc_parser/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,38 @@ def load_openapi(
return lambda: openapi_json


def _validate_target_module(
target_module: Union[list[str], str, None],
target_obj: FastAPI | Callable,
) -> list[str]:
if target_module is not None:
return [target_module] if isinstance(target_module, str) else target_module

if isinstance(target_obj, FastAPI):
target_module = _find_module_name_where_app_defined_in(target_obj)
else:
target_module = _find_module_name_where_func_defined_in(target_obj)

if target_module is not None:
target_module = target_module.split(".")[0] # get the top-level module

if target_module is None or target_module == "__main__":
if isinstance(target_obj, FastAPI):
raise BaseRichAPIException(
"Could not determine the module where the FastAPI instance was created.\n"
"Please provide the module name as a string or list of strings.\n"
"Example: enrich_openapi(app, target_module='src')\n"
)
else:
raise BaseRichAPIException(
"Could not determine the module where the function was created.\n"
"Please provide the module name as a string or list of strings.\n"
"Example: find_exceptions_in_api_functions([func], target_module='src')\n"
)

return [target_module]


def enrich_openapi(
app: FastAPI,
target_module: Union[list[str], str, None] = None,
Expand All @@ -47,16 +79,7 @@ def enrich_openapi(
routes=app.routes,
),
) -> Callable:
if target_module is None:
target_module = _find_module_name_where_app_defined_in(app)
if target_module is None or target_module == "__main__":
raise BaseRichAPIException(
"Could not determine the module where the FastAPI instance was created.\n"
"Please provide the module name as a string or list of strings.\n"
"Example: enrich_openapi(app, target_module='src')\n"
)

target_module = target_module.split(".")[0] # get the top-level module
target_module = _validate_target_module(target_module, app)

def _custom_openapi() -> dict:
if app.openapi_schema: # pragma: no cover
Expand All @@ -72,7 +95,7 @@ def _custom_openapi() -> dict:

def compile_openapi_from_fastapi(
app: FastAPI,
target_module: Union[list[str], str],
target_module: list[str] | str,
open_api_getter: Callable[[FastAPI], dict] = lambda app: _get_openapi(
title=app.title,
version=app.version,
Expand All @@ -81,22 +104,58 @@ def compile_openapi_from_fastapi(
),
) -> dict:
target_module = [target_module] if isinstance(target_module, str) else target_module
target_module.append("fastapi")

openapi_schema = open_api_getter(app)
for route in app.routes:
if not isinstance(route, APIRoute):
continue

if route.include_in_schema:
exceptions = _extract_starlette_exceptions(route, target_module)

exceptions = _extract_starlette_exceptions(
route, target_module + ["fastapi"]
)
_fill_openapi_with_excpetions(openapi_schema, route, exceptions)

ExceptionFinder.clear_cache()

return openapi_schema


def _find_fastapi_route_from_callback(
app: FastAPI, callback: Callable
) -> Union[APIRoute, None]:
for route in app.routes:
if not isinstance(route, APIRoute):
continue
if route.endpoint is callback:
return route

return None


def find_exceptions_in_api_functions(
funcs: list[Callable],
fastapi_app: FastAPI,
target_module: str | list[str] | None = None,
) -> set[type[StarletteHTTPException]]:
result: set[type[StarletteHTTPException]] = set()
for func in funcs:
this_func_target_module = _validate_target_module(target_module, func)
this_route = _find_fastapi_route_from_callback(fastapi_app, func)
if this_route is None:
raise BaseRichAPIException(
f"Could not find the route for function {func} in FastAPI app {fastapi_app}"
)

exceptions = _extract_starlette_exceptions(
this_route, this_func_target_module + ["fastapi"]
)
for exc, _ in exceptions:
result.add(exc)

return result


def _find_module_name_where_app_defined_in(app: FastAPI) -> Union[str, None]:
frame = inspect.currentframe()
target_module = None
Expand All @@ -112,6 +171,21 @@ def _find_module_name_where_app_defined_in(app: FastAPI) -> Union[str, None]:
return target_module


def _find_module_name_where_func_defined_in(func: Callable) -> Union[str, None]:
frame = inspect.currentframe()
target_module = None
while frame:
for var_name, var_value in frame.f_globals.items():
if isinstance(var_value, Callable) and var_value is func:
target_module = frame.f_globals["__name__"]
break
if target_module:
break
frame = frame.f_back

return target_module


def _resolve_status_and_detail_from_exc_type(
exc_type: type[Exception],
ast_raise: ast.Raise,
Expand Down Expand Up @@ -142,7 +216,7 @@ def _resolve_status_and_detail_from_exc_type(
status_code_value.value,
(int, str, NoneType),
):
raise ValueError(
raise BaseRichAPIException(
f"Status code value must be an integer, string or None, got {type(status_code_value.value)}"
)
if status_code_value.value is not None:
Expand All @@ -166,7 +240,7 @@ def _resolve_status_and_detail_from_exc_type(
detail_value = kwarg.value
if isinstance(detail_value, ast.Constant):
if not isinstance(detail_value.value, (str, NoneType)):
raise ValueError(
raise BaseRichAPIException(
f"Detail value must be a string or None, got {type(detail_value.value)}"
)

Expand Down
57 changes: 57 additions & 0 deletions tests/test_router_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# from dataclasses import dataclass
# from typing import Literal

import random
from dataclasses import dataclass

from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from richapi.exc_parser.handler import add_exc_handler
from richapi.exc_parser.openapi import enrich_openapi, find_exceptions_in_api_functions
from richapi.exc_parser.protocol import RichHTTPException


@dataclass
class Exception1(RichHTTPException):
status_code = 409


@dataclass
class Exception2(RichHTTPException):
status_code = 408


class NWrapper(BaseModel):
value: int


def lol():
if random.randint(1, 3) == 2:
raise Exception1()
return lol()


app = FastAPI()

app.openapi = enrich_openapi(app, target_module="tests.test_recursive_func")
add_exc_handler(app)


@app.get("/{n}", response_model=NWrapper)
async def index(n: int) -> JSONResponse:
lol()
if n % 2 == 0:
raise Exception2()
return JSONResponse(content=jsonable_encoder(NWrapper(value=n)))


def test_find_exceptions_in_api_functions():
result = find_exceptions_in_api_functions(
[index],
target_module="tests.test_router_exception",
fastapi_app=app,
)
assert result == {Exception1, Exception2}
Loading