From 0b106e6f3d9f6e46d6d0877917b59c3f031fc121 Mon Sep 17 00:00:00 2001 From: Erik Vroon Date: Sun, 23 Feb 2025 13:14:05 +0100 Subject: [PATCH 1/5] Implement sso --- backend/Pipfile | 2 +- backend/bracket/config.py | 10 +++- backend/bracket/logic/sso.py | 86 ++++++++++++++++++++++++++++++++++ backend/bracket/models/sso.py | 25 ++++++++++ backend/bracket/routes/auth.py | 57 ++++++++++------------ 5 files changed, 147 insertions(+), 33 deletions(-) create mode 100644 backend/bracket/logic/sso.py create mode 100644 backend/bracket/models/sso.py diff --git a/backend/Pipfile b/backend/Pipfile index f487d024b..1d33be024 100644 --- a/backend/Pipfile +++ b/backend/Pipfile @@ -14,7 +14,7 @@ click = ">=8.1.3" databases = {extras = ["asyncpg"], version = "<=0.8.0"} fastapi = "0.115.6" fastapi-cache2 = ">=0.2.0" -fastapi-sso = ">=0.6.4" +fastapi-sso = "0.17.0" gunicorn = ">=20.1.0" heliclockter = ">=1.3.0" parameterized = ">=0.8.1" diff --git a/backend/bracket/config.py b/backend/bracket/config.py index 38c675260..7064898b0 100644 --- a/backend/bracket/config.py +++ b/backend/bracket/config.py @@ -8,6 +8,7 @@ from pydantic import Field, PostgresDsn from pydantic_settings import BaseSettings, SettingsConfigDict +from bracket.models.sso import SSOProvider from bracket.utils.types import EnumAutoStr @@ -29,8 +30,8 @@ def get_log_level(self) -> int: class Config(BaseSettings): admin_email: str | None = None admin_password: str | None = None - allow_insecure_http_sso: bool = False allow_user_registration: bool = True + allow_user_basic_login: bool = True allow_demo_user_registration: bool = True captcha_secret: str | None = None base_url: str = "http://localhost:8400" @@ -41,6 +42,13 @@ class Config(BaseSettings): pg_dsn: PostgresDsn = "postgresql://user:pass@localhost:5432/db" # type: ignore[assignment] sentry_dsn: str | None = None + sso_1_provider: SSOProvider | None = None + sso_1_client_id: str | None = None + sso_1_client_secret: str | None = None + sso_1_allow_insecure_http_sso: bool = False + sso_1_openid_discovery_url: str | None = None + sso_1_openid_scopes: str | None = None + def is_cors_enabled(self) -> bool: return self.cors_origins != "*" diff --git a/backend/bracket/logic/sso.py b/backend/bracket/logic/sso.py new file mode 100644 index 000000000..b5004d2aa --- /dev/null +++ b/backend/bracket/logic/sso.py @@ -0,0 +1,86 @@ +from functools import cache +from typing import Any + +import aiohttp +from fastapi_sso import GithubSSO, GoogleSSO, OpenID, SSOBase, create_provider +from fastapi_sso.sso.base import DiscoveryDocument +from httpx import AsyncClient + +from bracket.config import config +from bracket.models.sso import SSOID, SSOConfig, SSOProvider + + +async def get_discovery_document(discovery_url: str) -> DiscoveryDocument: + async with aiohttp.ClientSession() as session: + response = await session.get(discovery_url) + response.raise_for_status() + response_json = await response.json() + return { + "authorization_endpoint": response_json["authorization_endpoint"], + "token_endpoint": response_json["token_endpoint"], + "userinfo_endpoint": response_json["userinfo_endpoint"], + } + + +async def build_sso(sso_config: SSOConfig) -> SSOBase: + match sso_config.provider: + case SSOProvider.google: + return GoogleSSO( + client_id=sso_config.client_id, + client_secret=sso_config.client_secret, + redirect_uri=sso_config.redirect_uri, + allow_insecure_http=sso_config.allow_insecure_http, + ) + case SSOProvider.github: + return GithubSSO( + client_id=sso_config.client_id, + client_secret=sso_config.client_secret, + redirect_uri=sso_config.redirect_uri, + allow_insecure_http=sso_config.allow_insecure_http, + ) + case SSOProvider.openid: + assert sso_config.openid_discovery_url is not None, ( + "`openid_discovery_url` should be set for OpenID SSO" + ) + assert sso_config.openid_scopes is not None, ( + "`openid_scopes` should be set for OpenID SSO" + ) + + def convert_openid(response: dict[str, Any], _client: AsyncClient | None) -> OpenID: + return OpenID(display_name=response["sub"]) + + GenericSSO = create_provider( + name="oidc", + discovery_document=await get_discovery_document(sso_config.openid_discovery_url), + response_convertor=convert_openid, + ) + + return GenericSSO( + client_id=sso_config.client_id, + client_secret=sso_config.client_secret, + redirect_uri=sso_config.redirect_uri, + allow_insecure_http=sso_config.allow_insecure_http, + scope=sso_config.openid_scopes.split(","), + ) + + +@cache +async def get_sso_providers() -> dict[SSOID, SSOBase]: + configs = [] + if ( + config.sso_1_provider is not None + and config.sso_1_client_id is not None + and config.sso_1_client_secret is not None + ): + configs.append( + SSOConfig( + id=SSOID(1), + provider=config.sso_1_provider, + client_id=config.sso_1_client_id, + client_secret=config.sso_1_client_secret, + redirect_uri=f"{config.base_url}/sso-callback", + allow_insecure_http=config.sso_1_allow_insecure_http_sso, + ) + ) + + return {sso_config.id: await build_sso(sso_config) for sso_config in configs} diff --git a/backend/bracket/models/sso.py b/backend/bracket/models/sso.py new file mode 100644 index 000000000..844409418 --- /dev/null +++ b/backend/bracket/models/sso.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from enum import StrEnum, auto +from typing import NewType + +from pydantic import BaseModel + +SSOID = NewType("SSOID", int) + + +class SSOConfig(BaseModel): + id: SSOID + provider: SSOProvider + client_id: str + client_secret: str + redirect_uri: str + allow_insecure_http: bool + openid_discovery_url: str | None = None + openid_scopes: str | None = None + + +class SSOProvider(StrEnum): + google = auto() + github = auto() + openid = auto() diff --git a/backend/bracket/routes/auth.py b/backend/bracket/routes/auth.py index cbb0ebe67..2d2c748fd 100644 --- a/backend/bracket/routes/auth.py +++ b/backend/bracket/routes/auth.py @@ -1,3 +1,4 @@ +import os from typing import Any import jwt @@ -7,11 +8,14 @@ from jwt import DecodeError, ExpiredSignatureError from pydantic import BaseModel from starlette.requests import Request +from starlette.responses import RedirectResponse from bracket.config import config from bracket.database import database +from bracket.logic.sso import get_sso_providers from bracket.models.db.tournament import Tournament from bracket.models.db.user import UserInDB, UserPublic +from bracket.models.sso import SSOID from bracket.schema import tournaments from bracket.sql.tournaments import sql_get_tournament_by_endpoint_name from bracket.sql.users import get_user, get_user_access_to_club, get_user_access_to_tournament @@ -28,20 +32,8 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -# def convert_openid(response: dict[str, Any]) -> OpenID: -# """Convert user information returned by OIDC""" -# return OpenID(display_name=response["sub"]) - - # os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" -# sso = GoogleSSO( -# client_id="test", -# client_secret="secret", -# redirect_uri="http://localhost:8080/sso_callback", -# allow_insecure_http=config.allow_insecure_http_sso, -# ) - class Token(BaseModel): access_token: str @@ -184,22 +176,25 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends( return Token(access_token=access_token, token_type="bearer", user_id=user.id) -# @router.get("/login", summary='SSO login') -# async def sso_login() -> RedirectResponse: -# """Generate login url and redirect""" -# return cast(RedirectResponse, await sso.get_login_redirect()) -# -# -# @router.get("/sso_callback", summary='SSO callback') -# async def sso_callback(request: Request) -> dict[str, Any]: -# """Process login response from OIDC and return user info""" -# user = await sso.verify_and_process(request) -# if user is None: -# raise HTTPException(401, "Failed to fetch user information") -# return { -# "id": user.id, -# "picture": user.picture, -# "display_name": user.display_name, -# "email": user.email, -# "provider": user.provider, -# } +@router.get("/sso-login/{sso_id}") +async def sso_login(sso_id: SSOID) -> RedirectResponse: + """Generate login url and redirect""" + sso_providers = await get_sso_providers() + return await sso_providers[sso_id].get_login_redirect() + + +@router.get("/sso-callback/{sso_id}") +async def sso_callback(request: Request, sso_id: SSOID) -> dict[str, Any]: + """Process login response from OIDC and return user info""" + sso_providers = await get_sso_providers() + user = await sso_providers[sso_id].verify_and_process(request) + if user is None: + raise HTTPException(401, "Failed to fetch user information") + + return { + "id": user.id, + "picture": user.picture, + "display_name": user.display_name, + "email": user.email, + "provider": user.provider, + } From 7d0637c8e240ae96e69fe13a79cfaa81323e3fee Mon Sep 17 00:00:00 2001 From: Erik Vroon Date: Sun, 23 Feb 2025 13:15:58 +0100 Subject: [PATCH 2/5] fixup! Implement sso --- backend/bracket/logic/sso.py | 48 ++++++++++++++++++---------------- backend/bracket/routes/auth.py | 1 - 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/backend/bracket/logic/sso.py b/backend/bracket/logic/sso.py index b5004d2aa..618ff416a 100644 --- a/backend/bracket/logic/sso.py +++ b/backend/bracket/logic/sso.py @@ -22,6 +22,30 @@ async def get_discovery_document(discovery_url: str) -> DiscoveryDocument: } +async def build_openid_sso(sso_config: SSOConfig) -> SSOBase: + assert sso_config.openid_discovery_url is not None, ( + "`openid_discovery_url` should be set for OpenID SSO" + ) + assert sso_config.openid_scopes is not None, "`openid_scopes` should be set for OpenID SSO" + + def convert_openid(response: dict[str, Any], _client: AsyncClient | None) -> OpenID: + return OpenID(display_name=response["sub"]) + + GenericSSO = create_provider( + name="oidc", + discovery_document=await get_discovery_document(sso_config.openid_discovery_url), + response_convertor=convert_openid, + ) + + return GenericSSO( + client_id=sso_config.client_id, + client_secret=sso_config.client_secret, + redirect_uri=sso_config.redirect_uri, + allow_insecure_http=sso_config.allow_insecure_http, + scope=sso_config.openid_scopes.split(","), + ) + + async def build_sso(sso_config: SSOConfig) -> SSOBase: match sso_config.provider: case SSOProvider.google: @@ -39,29 +63,7 @@ async def build_sso(sso_config: SSOConfig) -> SSOBase: allow_insecure_http=sso_config.allow_insecure_http, ) case SSOProvider.openid: - assert sso_config.openid_discovery_url is not None, ( - "`openid_discovery_url` should be set for OpenID SSO" - ) - assert sso_config.openid_scopes is not None, ( - "`openid_scopes` should be set for OpenID SSO" - ) - - def convert_openid(response: dict[str, Any], _client: AsyncClient | None) -> OpenID: - return OpenID(display_name=response["sub"]) - - GenericSSO = create_provider( - name="oidc", - discovery_document=await get_discovery_document(sso_config.openid_discovery_url), - response_convertor=convert_openid, - ) - - return GenericSSO( - client_id=sso_config.client_id, - client_secret=sso_config.client_secret, - redirect_uri=sso_config.redirect_uri, - allow_insecure_http=sso_config.allow_insecure_http, - scope=sso_config.openid_scopes.split(","), - ) + return await build_openid_sso(sso_config) @cache diff --git a/backend/bracket/routes/auth.py b/backend/bracket/routes/auth.py index 2d2c748fd..234424cc9 100644 --- a/backend/bracket/routes/auth.py +++ b/backend/bracket/routes/auth.py @@ -1,4 +1,3 @@ -import os from typing import Any import jwt From 8ee53a07e151521f830a562da6e7a06ad4b954ba Mon Sep 17 00:00:00 2001 From: Erik Vroon Date: Sun, 23 Feb 2025 13:26:08 +0100 Subject: [PATCH 3/5] fixup! fixup! Implement sso --- backend/bracket/logic/sso.py | 36 ++++++++++++------------------------ 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/backend/bracket/logic/sso.py b/backend/bracket/logic/sso.py index 618ff416a..4bf6d5b9d 100644 --- a/backend/bracket/logic/sso.py +++ b/backend/bracket/logic/sso.py @@ -22,7 +22,7 @@ async def get_discovery_document(discovery_url: str) -> DiscoveryDocument: } -async def build_openid_sso(sso_config: SSOConfig) -> SSOBase: +async def get_openid_provider(sso_config: SSOConfig) -> type[SSOBase]: assert sso_config.openid_discovery_url is not None, ( "`openid_discovery_url` should be set for OpenID SSO" ) @@ -31,41 +31,29 @@ async def build_openid_sso(sso_config: SSOConfig) -> SSOBase: def convert_openid(response: dict[str, Any], _client: AsyncClient | None) -> OpenID: return OpenID(display_name=response["sub"]) - GenericSSO = create_provider( + return create_provider( name="oidc", discovery_document=await get_discovery_document(sso_config.openid_discovery_url), response_convertor=convert_openid, + default_scope=sso_config.openid_scopes.split(","), ) - return GenericSSO( + +async def build_sso(sso_config: SSOConfig) -> SSOBase: + provider_class_lookup: dict[SSOProvider, type[SSOBase]] = { + SSOProvider.google: GoogleSSO, + SSOProvider.github: GithubSSO, + SSOProvider.openid: await get_openid_provider(sso_config), + } + + return provider_class_lookup[sso_config.provider]( client_id=sso_config.client_id, client_secret=sso_config.client_secret, redirect_uri=sso_config.redirect_uri, allow_insecure_http=sso_config.allow_insecure_http, - scope=sso_config.openid_scopes.split(","), ) -async def build_sso(sso_config: SSOConfig) -> SSOBase: - match sso_config.provider: - case SSOProvider.google: - return GoogleSSO( - client_id=sso_config.client_id, - client_secret=sso_config.client_secret, - redirect_uri=sso_config.redirect_uri, - allow_insecure_http=sso_config.allow_insecure_http, - ) - case SSOProvider.github: - return GithubSSO( - client_id=sso_config.client_id, - client_secret=sso_config.client_secret, - redirect_uri=sso_config.redirect_uri, - allow_insecure_http=sso_config.allow_insecure_http, - ) - case SSOProvider.openid: - return await build_openid_sso(sso_config) - - @cache async def get_sso_providers() -> dict[SSOID, SSOBase]: configs = [] From 8a7a1a5c8e78e4474f1a378feec8b062bf74a5e1 Mon Sep 17 00:00:00 2001 From: Erik Vroon Date: Tue, 25 Feb 2025 14:52:24 +0100 Subject: [PATCH 4/5] fixup! fixup! fixup! Implement sso --- backend/bracket/logic/sso.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/backend/bracket/logic/sso.py b/backend/bracket/logic/sso.py index 4bf6d5b9d..4a6eb9a58 100644 --- a/backend/bracket/logic/sso.py +++ b/backend/bracket/logic/sso.py @@ -1,14 +1,15 @@ -from functools import cache -from typing import Any +from typing import Any, assert_never import aiohttp -from fastapi_sso import GithubSSO, GoogleSSO, OpenID, SSOBase, create_provider +from fastapi_sso import GoogleSSO, OpenID, SSOBase, create_provider from fastapi_sso.sso.base import DiscoveryDocument from httpx import AsyncClient from bracket.config import config from bracket.models.sso import SSOID, SSOConfig, SSOProvider +providers_cache: dict[SSOID, SSOBase] | None = None + async def get_discovery_document(discovery_url: str) -> DiscoveryDocument: async with aiohttp.ClientSession() as session: @@ -40,13 +41,17 @@ def convert_openid(response: dict[str, Any], _client: AsyncClient | None) -> Ope async def build_sso(sso_config: SSOConfig) -> SSOBase: - provider_class_lookup: dict[SSOProvider, type[SSOBase]] = { - SSOProvider.google: GoogleSSO, - SSOProvider.github: GithubSSO, - SSOProvider.openid: await get_openid_provider(sso_config), - } + match sso_config.provider: + case SSOProvider.google: + provider: type[SSOBase] = GoogleSSO + case SSOProvider.github: + provider = GoogleSSO + case SSOProvider.openid: + provider = await get_openid_provider(sso_config) + case _ as fallback: + assert_never(fallback) - return provider_class_lookup[sso_config.provider]( + return provider( client_id=sso_config.client_id, client_secret=sso_config.client_secret, redirect_uri=sso_config.redirect_uri, @@ -54,8 +59,12 @@ async def build_sso(sso_config: SSOConfig) -> SSOBase: ) -@cache async def get_sso_providers() -> dict[SSOID, SSOBase]: + global providers_cache # noqa: PLW0603 + + if providers_cache is not None: + return providers_cache + configs = [] if ( config.sso_1_provider is not None @@ -68,9 +77,11 @@ async def get_sso_providers() -> dict[SSOID, SSOBase]: provider=config.sso_1_provider, client_id=config.sso_1_client_id, client_secret=config.sso_1_client_secret, - redirect_uri=f"{config.base_url}/sso-callback", + redirect_uri=f"{config.base_url}/sso-callback/1", allow_insecure_http=config.sso_1_allow_insecure_http_sso, ) ) - return {sso_config.id: await build_sso(sso_config) for sso_config in configs} + providers = {sso_config.id: await build_sso(sso_config) for sso_config in configs} + providers_cache = providers + return providers_cache From 144cddd182b1befbe5aa24c39febe38d1ea09753 Mon Sep 17 00:00:00 2001 From: Erik Vroon Date: Fri, 23 May 2025 17:00:10 +0200 Subject: [PATCH 5/5] fixup! fixup! fixup! fixup! Implement sso --- .gitignore | 3 +++ backend/bracket/config.py | 1 - backend/bracket/logic/sso.py | 2 -- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index be583b176..62fd981cb 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,6 @@ backend/yarn.lock backend/static /process-compose.yml + +.next +next-env.d.ts diff --git a/backend/bracket/config.py b/backend/bracket/config.py index 7064898b0..01f7c4619 100644 --- a/backend/bracket/config.py +++ b/backend/bracket/config.py @@ -44,7 +44,6 @@ class Config(BaseSettings): sso_1_provider: SSOProvider | None = None sso_1_client_id: str | None = None - sso_1_client_secret: str | None = None sso_1_allow_insecure_http_sso: bool = False sso_1_openid_discovery_url: str | None = None sso_1_openid_scopes: str | None = None diff --git a/backend/bracket/logic/sso.py b/backend/bracket/logic/sso.py index 4a6eb9a58..c4f01d24c 100644 --- a/backend/bracket/logic/sso.py +++ b/backend/bracket/logic/sso.py @@ -69,14 +69,12 @@ async def get_sso_providers() -> dict[SSOID, SSOBase]: if ( config.sso_1_provider is not None and config.sso_1_client_id is not None - and config.sso_1_client_secret is not None ): configs.append( SSOConfig( id=SSOID(1), provider=config.sso_1_provider, client_id=config.sso_1_client_id, - client_secret=config.sso_1_client_secret, redirect_uri=f"{config.base_url}/sso-callback/1", allow_insecure_http=config.sso_1_allow_insecure_http_sso, )