|
| 1 | +import json |
| 2 | +import os |
| 3 | +import re |
| 4 | +import secrets |
| 5 | +import time |
| 6 | +from urllib.parse import parse_qs, urlencode |
| 7 | + |
| 8 | +import aiohttp |
| 9 | +from aiohttp import web |
| 10 | + |
| 11 | + |
| 12 | +def install(ctx): |
| 13 | + g_app = ctx.app |
| 14 | + |
| 15 | + auth_config_file = os.path.join(ctx.get_user_path(), "github_auth", "config.json") |
| 16 | + |
| 17 | + auth_config = None |
| 18 | + if os.path.exists(auth_config_file): |
| 19 | + try: |
| 20 | + with open(auth_config_file, encoding="utf-8") as f: |
| 21 | + auth_config = json.load(f) |
| 22 | + if "enabled" in auth_config and not auth_config["enabled"]: |
| 23 | + ctx.log("GitHub Auth is disabled in config") |
| 24 | + auth_config = None |
| 25 | + except Exception as e: |
| 26 | + ctx.err("Failed to load GitHub auth config", e) |
| 27 | + else: |
| 28 | + ctx.dbg(f"GitHub Auth config file '{auth_config_file}' not found") |
| 29 | + |
| 30 | + if not auth_config: |
| 31 | + # don't load extension if auth_config is not found or is disabled |
| 32 | + ctx.disabled = True |
| 33 | + return |
| 34 | + |
| 35 | + client_id = auth_config.get("client_id", "") |
| 36 | + client_secret = auth_config.get("client_secret", "") |
| 37 | + redirect_uri = auth_config.get("redirect_uri", "") |
| 38 | + restrict_to = auth_config.get("restrict_to", "") |
| 39 | + |
| 40 | + # Expand environment variables |
| 41 | + if client_id.startswith("$"): |
| 42 | + client_id = client_id[1:] |
| 43 | + if client_secret.startswith("$"): |
| 44 | + client_secret = client_secret[1:] |
| 45 | + client_secret = os.getenv(client_secret) |
| 46 | + if redirect_uri.startswith("$"): |
| 47 | + redirect_uri = redirect_uri[1:] |
| 48 | + redirect_uri = os.getenv(redirect_uri) |
| 49 | + if restrict_to.startswith("$"): |
| 50 | + restrict_to = restrict_to[1:] |
| 51 | + restrict_to = os.getenv(restrict_to) |
| 52 | + |
| 53 | + # check if client_id is set |
| 54 | + if client_id == "GITHUB_CLIENT_ID": |
| 55 | + client_id = os.getenv(client_id) |
| 56 | + if client_secret == "GITHUB_CLIENT_SECRET": |
| 57 | + client_secret = os.getenv(client_secret) |
| 58 | + if restrict_to == "GITHUB_USERS": |
| 59 | + restrict_to = os.getenv(restrict_to) |
| 60 | + |
| 61 | + if not client_id or not redirect_uri or not client_secret: |
| 62 | + ctx.disabled = True |
| 63 | + ctx.log("GitHub OAuth client_id, client_secret and redirect_uri are not configured") |
| 64 | + return |
| 65 | + |
| 66 | + from llms.main import AuthProvider |
| 67 | + |
| 68 | + class GitHubAuthProvider(AuthProvider): |
| 69 | + def __init__(self, app): |
| 70 | + super().__init__(app) |
| 71 | + |
| 72 | + # Adding an Auth Provider forces Authentication to be enabled |
| 73 | + auth_provider = GitHubAuthProvider(g_app) |
| 74 | + g_app.auth_providers.append(auth_provider) |
| 75 | + |
| 76 | + # OAuth handlers |
| 77 | + async def github_auth_handler(request): |
| 78 | + # Generate CSRF state token |
| 79 | + state = secrets.token_urlsafe(32) |
| 80 | + ctx.oauth_states[state] = {"created": time.time(), "redirect_uri": redirect_uri} |
| 81 | + |
| 82 | + # Clean up old states (older than 10 minutes) |
| 83 | + current_time = time.time() |
| 84 | + expired_states = [s for s, data in ctx.oauth_states.items() if current_time - data["created"] > 600] |
| 85 | + for s in expired_states: |
| 86 | + del ctx.oauth_states[s] |
| 87 | + |
| 88 | + # Build GitHub authorization URL |
| 89 | + params = { |
| 90 | + "client_id": client_id, |
| 91 | + "redirect_uri": redirect_uri, |
| 92 | + "state": state, |
| 93 | + "scope": "read:user user:email", |
| 94 | + } |
| 95 | + auth_url = f"https://github.com/login/oauth/authorize?{urlencode(params)}" |
| 96 | + |
| 97 | + return web.HTTPFound(auth_url) |
| 98 | + |
| 99 | + def validate_user(github_username): |
| 100 | + # If restrict_to is configured, validate the user |
| 101 | + if restrict_to: |
| 102 | + # Parse allowed users (comma or space delimited) |
| 103 | + allowed_users = [u.strip() for u in re.split(r"[,\s]+", restrict_to) if u.strip()] |
| 104 | + |
| 105 | + # Check if user is in the allowed list |
| 106 | + if not github_username or github_username not in allowed_users: |
| 107 | + ctx.log(f"Access denied for user: {github_username}. Not in allowed list: {allowed_users}") |
| 108 | + return web.Response( |
| 109 | + text=f"Access denied. User '{github_username}' is not authorized to access this application.", |
| 110 | + status=403, |
| 111 | + ) |
| 112 | + return None |
| 113 | + |
| 114 | + async def github_callback_handler(request): |
| 115 | + """Handle GitHub OAuth callback""" |
| 116 | + code = request.query.get("code") |
| 117 | + state = request.query.get("state") |
| 118 | + |
| 119 | + # Handle malformed URLs where query params are appended with & instead of ? |
| 120 | + if not code and "tail" in request.match_info: |
| 121 | + tail = request.match_info["tail"] |
| 122 | + if tail.startswith("&"): |
| 123 | + params = parse_qs(tail[1:]) |
| 124 | + code = params.get("code", [None])[0] |
| 125 | + state = params.get("state", [None])[0] |
| 126 | + |
| 127 | + if not code or not state: |
| 128 | + return web.Response(text="Missing code or state parameter", status=400) |
| 129 | + |
| 130 | + # Verify state token (CSRF protection) |
| 131 | + if state not in ctx.oauth_states: |
| 132 | + return web.Response(text="Invalid state parameter", status=400) |
| 133 | + |
| 134 | + ctx.oauth_states.pop(state) |
| 135 | + |
| 136 | + # Exchange code for access token |
| 137 | + async with aiohttp.ClientSession() as session: |
| 138 | + token_url = "https://github.com/login/oauth/access_token" |
| 139 | + token_data = { |
| 140 | + "client_id": client_id, |
| 141 | + "client_secret": client_secret, |
| 142 | + "code": code, |
| 143 | + "redirect_uri": redirect_uri, |
| 144 | + } |
| 145 | + headers = {"Accept": "application/json"} |
| 146 | + |
| 147 | + async with session.post(token_url, data=token_data, headers=headers) as resp: |
| 148 | + token_response = await resp.json() |
| 149 | + access_token = token_response.get("access_token") |
| 150 | + |
| 151 | + if not access_token: |
| 152 | + error = token_response.get("error_description", "Failed to get access token") |
| 153 | + return web.json_response(ctx.create_error_response(f"OAuth error: {error}"), status=400) |
| 154 | + |
| 155 | + # Fetch user info |
| 156 | + user_url = "https://api.github.com/user" |
| 157 | + headers = {"Authorization": f"Bearer {access_token}", "Accept": "application/json"} |
| 158 | + |
| 159 | + async with session.get(user_url, headers=headers) as resp: |
| 160 | + user_data = await resp.json() |
| 161 | + |
| 162 | + # Validate user |
| 163 | + error_response = validate_user(user_data.get("login", "")) |
| 164 | + if error_response: |
| 165 | + return error_response |
| 166 | + |
| 167 | + # Create session |
| 168 | + session_token = secrets.token_urlsafe(32) |
| 169 | + ctx.sessions[session_token] = { |
| 170 | + "userId": str(user_data.get("id", "")), |
| 171 | + "userName": user_data.get("login", ""), |
| 172 | + "displayName": user_data.get("name", ""), |
| 173 | + "profileUrl": user_data.get("avatar_url", ""), |
| 174 | + "email": user_data.get("email", ""), |
| 175 | + "created": time.time(), |
| 176 | + } |
| 177 | + |
| 178 | + # Redirect to UI with session token |
| 179 | + response = web.HTTPFound(f"/?session={session_token}") |
| 180 | + response.set_cookie("llms-token", session_token, httponly=True, path="/", max_age=86400) |
| 181 | + return response |
| 182 | + |
| 183 | + async def session_handler(request): |
| 184 | + """Validate and return session info""" |
| 185 | + session_token = auth_provider.get_session_token(request) |
| 186 | + |
| 187 | + if not session_token or session_token not in ctx.sessions: |
| 188 | + return web.json_response(ctx.create_error_response("Invalid or expired session"), status=401) |
| 189 | + |
| 190 | + session_data = ctx.sessions[session_token] |
| 191 | + |
| 192 | + # Clean up old sessions (older than 24 hours) |
| 193 | + current_time = time.time() |
| 194 | + expired_sessions = [token for token, data in ctx.sessions.items() if current_time - data["created"] > 86400] |
| 195 | + for token in expired_sessions: |
| 196 | + del ctx.sessions[token] |
| 197 | + |
| 198 | + return web.json_response({**session_data, "sessionToken": session_token}) |
| 199 | + |
| 200 | + async def logout_handler(request): |
| 201 | + """End OAuth session""" |
| 202 | + session_token = auth_provider.get_session_token(request) |
| 203 | + |
| 204 | + if session_token and session_token in g_app.sessions: |
| 205 | + del g_app.sessions[session_token] |
| 206 | + |
| 207 | + response = web.json_response({"success": True}) |
| 208 | + response.del_cookie("llms-token") |
| 209 | + return response |
| 210 | + |
| 211 | + async def auth_handler(request): |
| 212 | + """Check authentication status and return user info""" |
| 213 | + # Check for OAuth session token |
| 214 | + session_token = auth_provider.get_session_token(request) |
| 215 | + |
| 216 | + if session_token and session_token in g_app.sessions: |
| 217 | + session_data = g_app.sessions[session_token] |
| 218 | + return web.json_response( |
| 219 | + { |
| 220 | + "userId": session_data.get("userId", ""), |
| 221 | + "userName": session_data.get("userName", ""), |
| 222 | + "displayName": session_data.get("displayName", ""), |
| 223 | + "profileUrl": session_data.get("profileUrl", ""), |
| 224 | + "authProvider": "github", |
| 225 | + } |
| 226 | + ) |
| 227 | + |
| 228 | + # Check for API key in Authorization header |
| 229 | + # auth_header = request.headers.get('Authorization', '') |
| 230 | + # if auth_header.startswith('Bearer '): |
| 231 | + # # For API key auth, return a basic response |
| 232 | + # # You can customize this based on your API key validation logic |
| 233 | + # api_key = auth_header[7:] |
| 234 | + # if api_key: # Add your API key validation logic here |
| 235 | + # return web.json_response({ |
| 236 | + # "userId": "1", |
| 237 | + # "userName": "apiuser", |
| 238 | + # "displayName": "API User", |
| 239 | + # "profileUrl": "", |
| 240 | + # "authProvider": "apikey" |
| 241 | + # }) |
| 242 | + |
| 243 | + # Not authenticated - return error in expected format |
| 244 | + return web.json_response(g_app.error_auth_required, status=401) |
| 245 | + |
| 246 | + ctx.add_get("/auth", auth_handler) |
| 247 | + ctx.add_get("/auth/github", github_auth_handler) |
| 248 | + ctx.add_get("/auth/github/callback", github_callback_handler) |
| 249 | + ctx.add_get("/auth/github/callback{tail:.*}", github_callback_handler) |
| 250 | + ctx.add_get("/auth/session", session_handler) |
| 251 | + ctx.add_post("/auth/logout", logout_handler) |
| 252 | + |
| 253 | + |
| 254 | +__install__ = install |
0 commit comments