diff --git a/.gitignore b/.gitignore index be583b176..62fd981cb 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,6 @@ backend/yarn.lock backend/static /process-compose.yml + +.next +next-env.d.ts diff --git a/backend/Pipfile b/backend/Pipfile index f487d024b..1d33be024 100644 --- a/backend/Pipfile +++ b/backend/Pipfile @@ -14,7 +14,7 @@ click = ">=8.1.3" databases = {extras = ["asyncpg"], version = "<=0.8.0"} fastapi = "0.115.6" fastapi-cache2 = ">=0.2.0" -fastapi-sso = ">=0.6.4" +fastapi-sso = "0.17.0" gunicorn = ">=20.1.0" heliclockter = ">=1.3.0" parameterized = ">=0.8.1" diff --git a/backend/bracket/config.py b/backend/bracket/config.py index 38c675260..01f7c4619 100644 --- a/backend/bracket/config.py +++ b/backend/bracket/config.py @@ -8,6 +8,7 @@ from pydantic import Field, PostgresDsn from pydantic_settings import BaseSettings, SettingsConfigDict +from bracket.models.sso import SSOProvider from bracket.utils.types import EnumAutoStr @@ -29,8 +30,8 @@ def get_log_level(self) -> int: class Config(BaseSettings): admin_email: str | None = None admin_password: str | None = None - allow_insecure_http_sso: bool = False allow_user_registration: bool = True + allow_user_basic_login: bool = True allow_demo_user_registration: bool = True captcha_secret: str | None = None base_url: str = "http://localhost:8400" @@ -41,6 +42,12 @@ class Config(BaseSettings): pg_dsn: PostgresDsn = "postgresql://user:pass@localhost:5432/db" # type: ignore[assignment] sentry_dsn: str | None = None + sso_1_provider: SSOProvider | None = None + sso_1_client_id: str | None = None + sso_1_allow_insecure_http_sso: bool = False + sso_1_openid_discovery_url: str | None = None + sso_1_openid_scopes: str | None = None + def is_cors_enabled(self) -> bool: return self.cors_origins != "*" diff --git a/backend/bracket/logic/sso.py b/backend/bracket/logic/sso.py new file mode 100644 index 000000000..c4f01d24c --- /dev/null +++ b/backend/bracket/logic/sso.py @@ -0,0 +1,85 @@ +from typing import Any, assert_never + +import aiohttp +from fastapi_sso import GoogleSSO, OpenID, SSOBase, create_provider +from fastapi_sso.sso.base import DiscoveryDocument +from httpx import AsyncClient + +from bracket.config import config +from bracket.models.sso import SSOID, SSOConfig, SSOProvider + +providers_cache: dict[SSOID, SSOBase] | None = None + + +async def get_discovery_document(discovery_url: str) -> DiscoveryDocument: + async with aiohttp.ClientSession() as session: + response = await session.get(discovery_url) + response.raise_for_status() + response_json = await response.json() + return { + "authorization_endpoint": response_json["authorization_endpoint"], + "token_endpoint": response_json["token_endpoint"], + "userinfo_endpoint": response_json["userinfo_endpoint"], + } + + +async def get_openid_provider(sso_config: SSOConfig) -> type[SSOBase]: + assert sso_config.openid_discovery_url is not None, ( + "`openid_discovery_url` should be set for OpenID SSO" + ) + assert sso_config.openid_scopes is not None, "`openid_scopes` should be set for OpenID SSO" + + def convert_openid(response: dict[str, Any], _client: AsyncClient | None) -> OpenID: + return OpenID(display_name=response["sub"]) + + return create_provider( + name="oidc", + discovery_document=await get_discovery_document(sso_config.openid_discovery_url), + response_convertor=convert_openid, + default_scope=sso_config.openid_scopes.split(","), + ) + + +async def build_sso(sso_config: SSOConfig) -> SSOBase: + match sso_config.provider: + case SSOProvider.google: + provider: type[SSOBase] = GoogleSSO + case SSOProvider.github: + provider = GoogleSSO + case SSOProvider.openid: + provider = await get_openid_provider(sso_config) + case _ as fallback: + assert_never(fallback) + + return provider( + client_id=sso_config.client_id, + client_secret=sso_config.client_secret, + redirect_uri=sso_config.redirect_uri, + allow_insecure_http=sso_config.allow_insecure_http, + ) + + +async def get_sso_providers() -> dict[SSOID, SSOBase]: + global providers_cache # noqa: PLW0603 + + if providers_cache is not None: + return providers_cache + + configs = [] + if ( + config.sso_1_provider is not None + and config.sso_1_client_id is not None + ): + configs.append( + SSOConfig( + id=SSOID(1), + provider=config.sso_1_provider, + client_id=config.sso_1_client_id, + redirect_uri=f"{config.base_url}/sso-callback/1", + allow_insecure_http=config.sso_1_allow_insecure_http_sso, + ) + ) + + providers = {sso_config.id: await build_sso(sso_config) for sso_config in configs} + providers_cache = providers + return providers_cache diff --git a/backend/bracket/models/sso.py b/backend/bracket/models/sso.py new file mode 100644 index 000000000..844409418 --- /dev/null +++ b/backend/bracket/models/sso.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from enum import StrEnum, auto +from typing import NewType + +from pydantic import BaseModel + +SSOID = NewType("SSOID", int) + + +class SSOConfig(BaseModel): + id: SSOID + provider: SSOProvider + client_id: str + client_secret: str + redirect_uri: str + allow_insecure_http: bool + openid_discovery_url: str | None = None + openid_scopes: str | None = None + + +class SSOProvider(StrEnum): + google = auto() + github = auto() + openid = auto() diff --git a/backend/bracket/routes/auth.py b/backend/bracket/routes/auth.py index cbb0ebe67..234424cc9 100644 --- a/backend/bracket/routes/auth.py +++ b/backend/bracket/routes/auth.py @@ -7,11 +7,14 @@ from jwt import DecodeError, ExpiredSignatureError from pydantic import BaseModel from starlette.requests import Request +from starlette.responses import RedirectResponse from bracket.config import config from bracket.database import database +from bracket.logic.sso import get_sso_providers from bracket.models.db.tournament import Tournament from bracket.models.db.user import UserInDB, UserPublic +from bracket.models.sso import SSOID from bracket.schema import tournaments from bracket.sql.tournaments import sql_get_tournament_by_endpoint_name from bracket.sql.users import get_user, get_user_access_to_club, get_user_access_to_tournament @@ -28,20 +31,8 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -# def convert_openid(response: dict[str, Any]) -> OpenID: -# """Convert user information returned by OIDC""" -# return OpenID(display_name=response["sub"]) - - # os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" -# sso = GoogleSSO( -# client_id="test", -# client_secret="secret", -# redirect_uri="http://localhost:8080/sso_callback", -# allow_insecure_http=config.allow_insecure_http_sso, -# ) - class Token(BaseModel): access_token: str @@ -184,22 +175,25 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends( return Token(access_token=access_token, token_type="bearer", user_id=user.id) -# @router.get("/login", summary='SSO login') -# async def sso_login() -> RedirectResponse: -# """Generate login url and redirect""" -# return cast(RedirectResponse, await sso.get_login_redirect()) -# -# -# @router.get("/sso_callback", summary='SSO callback') -# async def sso_callback(request: Request) -> dict[str, Any]: -# """Process login response from OIDC and return user info""" -# user = await sso.verify_and_process(request) -# if user is None: -# raise HTTPException(401, "Failed to fetch user information") -# return { -# "id": user.id, -# "picture": user.picture, -# "display_name": user.display_name, -# "email": user.email, -# "provider": user.provider, -# } +@router.get("/sso-login/{sso_id}") +async def sso_login(sso_id: SSOID) -> RedirectResponse: + """Generate login url and redirect""" + sso_providers = await get_sso_providers() + return await sso_providers[sso_id].get_login_redirect() + + +@router.get("/sso-callback/{sso_id}") +async def sso_callback(request: Request, sso_id: SSOID) -> dict[str, Any]: + """Process login response from OIDC and return user info""" + sso_providers = await get_sso_providers() + user = await sso_providers[sso_id].verify_and_process(request) + if user is None: + raise HTTPException(401, "Failed to fetch user information") + + return { + "id": user.id, + "picture": user.picture, + "display_name": user.display_name, + "email": user.email, + "provider": user.provider, + }