Skip to content

Commit c0f76ab

Browse files
Merge pull request #556 from MaienM/oidc-claims-and-roles
Implement settings for OIDC claims & roles
2 parents 84d07d9 + d4f22eb commit c0f76ab

File tree

6 files changed

+336
-80
lines changed

6 files changed

+336
-80
lines changed

clients/clientapi.py

Lines changed: 110 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3586,6 +3586,12 @@ class OIDCProviderValues(BaseModel):
35863586
button_color: Optional[str] = "#000000"
35873587
button_text_color: Optional[str] = "#000000"
35883588
icon_svg: Optional[str] = None
3589+
name_claim: Optional[str] = None
3590+
email_claim: Optional[str] = None
3591+
username_claim: Optional[str] = None
3592+
roles_claim: Optional[str] = None
3593+
user_role: Optional[str] = None
3594+
admin_role: Optional[str] = None
35893595

35903596
@app.post("/api/data/add_oidc_provider")
35913597
async def api_add_oidc_provider(
@@ -3605,7 +3611,13 @@ async def api_add_oidc_provider(
36053611
provider_values.scope,
36063612
provider_values.button_color,
36073613
provider_values.button_text_color,
3608-
provider_values.icon_svg
3614+
provider_values.icon_svg,
3615+
provider_values.name_claim,
3616+
provider_values.email_claim,
3617+
provider_values.username_claim,
3618+
provider_values.roles_claim,
3619+
provider_values.user_role,
3620+
provider_values.admin_role
36093621
))
36103622
if not provider_id:
36113623
raise HTTPException(
@@ -6434,7 +6446,7 @@ async def oidc_callback(
64346446
)
64356447

64366448
# Unpack provider details
6437-
provider_id, client_id, client_secret, token_url, userinfo_url = provider
6449+
provider_id, client_id, client_secret, token_url, userinfo_url, name_claim, email_claim, username_claim, roles_claim, user_role, admin_role = provider
64386450

64396451
# Exchange authorization code for access token
64406452
async with httpx.AsyncClient() as client:
@@ -6479,7 +6491,7 @@ async def oidc_callback(
64796491

64806492
user_info = userinfo_response.json()
64816493
print(f"User info response: {user_info}")
6482-
email = user_info.get("email")
6494+
email = user_info.get(email_claim or "email")
64836495

64846496
parsed_url = urlparse(userinfo_url)
64856497
if not email and parsed_url.hostname == 'api.github.com':
@@ -6515,23 +6527,87 @@ async def oidc_callback(
65156527
url=f"{frontend_base}/oauth/callback?error=network_error"
65166528
)
65176529

6530+
# Verify access.
6531+
if roles_claim and user_role:
6532+
roles = user_info.get(roles_claim)
6533+
if not isinstance(roles, list):
6534+
print(f'Claim {roles_claim} should be a list of strings, but it is {roles}.')
6535+
return RedirectResponse(
6536+
url=f"{frontend_base}/oauth/callback?error=no_access&details=invalid_roles"
6537+
)
6538+
if user_role not in roles and not (admin_role and admin_role in roles):
6539+
print(f"User user role {user_role} {f'and admin role {admin_role}' if admin_role else ''} not in user's roles ({roles}), denying access.")
6540+
return RedirectResponse(
6541+
url=f"{frontend_base}/oauth/callback?error=no_access"
6542+
)
6543+
65186544
# Check if user exists
65196545
user = database_functions.functions.get_user_by_email(cnx, database_type, email)
65206546

65216547
# In your OIDC callback function, replace the user creation section with:
65226548

6549+
# Determine the user's information
6550+
fullname = user_info.get(name_claim or "name", "")
6551+
if username_claim and username_claim not in user_info:
6552+
print(f"Unable to determine username for user, username claim {username_claim} not present")
6553+
return RedirectResponse(
6554+
url=f"{frontend_base}/oauth/callback?error=user_creation_failed&details=username_claim_missing"
6555+
)
6556+
username = user_info.get(username_claim or "preferred_username")
6557+
65236558
if not user:
65246559
# Create new user
65256560
print(f"User with email {email} not found, creating new user")
6526-
fullname = user_info.get("name", "")
6527-
username = email.split("@")[0].lower()
6528-
base_username = username
6529-
counter = 1
6530-
max_attempts = 10
65316561

6532-
while counter <= max_attempts:
6562+
if username is None:
6563+
username = email.split("@")[0].lower()
6564+
base_username = username
6565+
counter = 1
6566+
max_attempts = 10
6567+
6568+
while counter <= max_attempts:
6569+
try:
6570+
print(f"Attempt {counter} to create user with base username: {base_username}")
6571+
user_id = database_functions.functions.create_oidc_user(
6572+
cnx, database_type, email, fullname, username
6573+
)
6574+
print(f"User created successfully with ID: {user_id}")
6575+
6576+
if not user_id:
6577+
print(f"ERROR: Invalid user_id returned: {user_id}")
6578+
return RedirectResponse(
6579+
url=f"{frontend_base}/oauth/callback?error=invalid_user_id"
6580+
)
6581+
6582+
print(f"Creating API key for user_id: {user_id}")
6583+
api_key = database_functions.functions.create_api_key(cnx, database_type, user_id)
6584+
print(f"API key created: {api_key[:5]}... (truncated for security)")
6585+
break
6586+
except UniqueViolation:
6587+
print(f"Username conflict with {username}, trying next variation")
6588+
username = f"{base_username}{counter}"
6589+
counter += 1
6590+
if counter > max_attempts:
6591+
print(f"Failed to create user after {max_attempts} attempts due to username conflicts")
6592+
return RedirectResponse(
6593+
url=f"{frontend_base}/oauth/callback?error=username_conflict"
6594+
)
6595+
except Exception as e:
6596+
print(f"Error during user creation: {str(e)}")
6597+
import traceback
6598+
print(f"Traceback: {traceback.format_exc()}")
6599+
return RedirectResponse(
6600+
url=f"{frontend_base}/oauth/callback?error=user_creation_failed&details={str(e)[:50]}"
6601+
)
6602+
else:
6603+
print("Failed to create user after maximum attempts")
6604+
return RedirectResponse(
6605+
url=f"{frontend_base}/oauth/callback?error=user_creation_failed"
6606+
)
6607+
6608+
else:
65336609
try:
6534-
print(f"Attempt {counter} to create user with base username: {base_username}")
6610+
print(f"Attempt to create user with username: {username}")
65356611
user_id = database_functions.functions.create_oidc_user(
65366612
cnx, database_type, email, fullname, username
65376613
)
@@ -6546,28 +6622,18 @@ async def oidc_callback(
65466622
print(f"Creating API key for user_id: {user_id}")
65476623
api_key = database_functions.functions.create_api_key(cnx, database_type, user_id)
65486624
print(f"API key created: {api_key[:5]}... (truncated for security)")
6549-
break
65506625
except UniqueViolation:
6551-
print(f"Username conflict with {username}, trying next variation")
6552-
username = f"{base_username}{counter}"
6553-
counter += 1
6554-
if counter > max_attempts:
6555-
print(f"Failed to create user after {max_attempts} attempts due to username conflicts")
6556-
return RedirectResponse(
6557-
url=f"{frontend_base}/oauth/callback?error=username_conflict"
6558-
)
6626+
print("Failed to create user due to username conflicts")
6627+
return RedirectResponse(
6628+
url=f"{frontend_base}/oauth/callback?error=username_conflict"
6629+
)
65596630
except Exception as e:
65606631
print(f"Error during user creation: {str(e)}")
65616632
import traceback
65626633
print(f"Traceback: {traceback.format_exc()}")
65636634
return RedirectResponse(
65646635
url=f"{frontend_base}/oauth/callback?error=user_creation_failed&details={str(e)[:50]}"
65656636
)
6566-
else:
6567-
print("Failed to create user after maximum attempts")
6568-
return RedirectResponse(
6569-
url=f"{frontend_base}/oauth/callback?error=user_creation_failed"
6570-
)
65716637

65726638
else:
65736639
# Existing user - retrieve their API key
@@ -6581,6 +6647,26 @@ async def oidc_callback(
65816647

65826648
print(f"API key retrieved: {api_key[:5]}... (truncated for security)")
65836649

6650+
# Update user info based on OIDC information.
6651+
database_functions.functions.set_fullname(cnx, database_type, user_id, fullname)
6652+
6653+
current_username = user[2] if isinstance(user, tuple) else user['username']
6654+
if username_claim and username != current_username:
6655+
if database_functions.functions.check_usernames(cnx, database_type, username):
6656+
print(f'Unable to update username for user {user_id} to match the username specified by the OIDC provider ({username}) as this is already in use by another user.')
6657+
else:
6658+
database_functions.functions.set_username(cnx, database_type, user_id, username)
6659+
6660+
# Update admin role based on OIDC roles.
6661+
if roles_claim and admin_role:
6662+
roles = user_info.get(roles_claim)
6663+
if not isinstance(roles, list):
6664+
print(f'Claim {roles_claim} should be a list of strings, but it is {roles}.')
6665+
return RedirectResponse(
6666+
url=f"{frontend_base}/oauth/callback?error=no_access&details=invalid_roles"
6667+
)
6668+
database_functions.functions.set_isadmin(cnx, database_type, user_id, admin_role in roles)
6669+
65846670
# Success case - redirect with API key
65856671
return RedirectResponse(url=f"{frontend_base}/oauth/callback?api_key={api_key}")
65866672

database_functions/functions.py

Lines changed: 58 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -672,17 +672,21 @@ def add_oidc_provider(cnx, database_type, provider_values):
672672
INSERT INTO "OIDCProviders"
673673
(ProviderName, ClientID, ClientSecret, AuthorizationURL,
674674
TokenURL, UserInfoURL, ButtonText, Scope,
675-
ButtonColor, ButtonTextColor, IconSVG)
676-
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
675+
ButtonColor, ButtonTextColor, IconSVG, NameClaim, EmailClaim,
676+
UsernameClaim, RolesClaim, UserRole, AdminRole)
677+
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
678+
%s, %s, %s)
677679
RETURNING ProviderID
678680
"""
679681
else: # MySQL
680682
add_provider_query = """
681683
INSERT INTO OIDCProviders
682684
(ProviderName, ClientID, ClientSecret, AuthorizationURL,
683685
TokenURL, UserInfoURL, ButtonText, Scope,
684-
ButtonColor, ButtonTextColor, IconSVG)
685-
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
686+
ButtonColor, ButtonTextColor, IconSVG, NameClaim, EmailClaim,
687+
UsernameClaim, RolesClaim, UserRole, AdminRole)
688+
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
689+
%s, %s, %s)
686690
"""
687691
cursor.execute(add_provider_query, provider_values)
688692

@@ -734,16 +738,18 @@ def list_oidc_providers(cnx, database_type):
734738
if database_type == "postgresql":
735739
list_query = """
736740
SELECT ProviderID, ProviderName, ClientID, AuthorizationURL,
737-
TokenURL, UserInfoURL, ButtonText,
738-
Scope, ButtonColor, ButtonTextColor, IconSVG, Enabled, Created, Modified
741+
TokenURL, UserInfoURL, ButtonText, Scope, ButtonColor,
742+
ButtonTextColor, IconSVG, NameClaim, EmailClaim, UsernameClaim,
743+
RolesClaim, UserRole, AdminRole, Enabled, Created, Modified
739744
FROM "OIDCProviders"
740745
ORDER BY ProviderName
741746
"""
742747
else:
743748
list_query = """
744749
SELECT ProviderID, ProviderName, ClientID, AuthorizationURL,
745-
TokenURL, UserInfoURL, ButtonText,
746-
Scope, ButtonColor, ButtonTextColor, IconSVG, Enabled, Created, Modified
750+
TokenURL, UserInfoURL, ButtonText, Scope, ButtonColor,
751+
ButtonTextColor, IconSVG, NameClaim, EmailClaim, UsernameClaim,
752+
RolesClaim, UserRole, AdminRole, Enabled, Created, Modified
747753
FROM OIDCProviders
748754
ORDER BY ProviderName
749755
"""
@@ -777,6 +783,18 @@ def list_oidc_providers(cnx, database_type):
777783
normalized["button_text_color"] = value
778784
elif normalized_key == "iconsvg":
779785
normalized["icon_svg"] = value
786+
elif normalized_key == "nameclaim":
787+
normalized["name_claim"] = value
788+
elif normalized_key == "emailclaim":
789+
normalized["email_claim"] = value
790+
elif normalized_key == "usernameclaim":
791+
normalized["username_claim"] = value
792+
elif normalized_key == "rolesclaim":
793+
normalized["roles_claim"] = value
794+
elif normalized_key == "userrole":
795+
normalized["user_role"] = value
796+
elif normalized_key == "adminrole":
797+
normalized["admin_role"] = value
780798
else:
781799
normalized[normalized_key] = value
782800
providers.append(normalized)
@@ -794,9 +812,15 @@ def list_oidc_providers(cnx, database_type):
794812
'button_color': row[8],
795813
'button_text_color': row[9],
796814
'icon_svg': row[10],
797-
'enabled': row[11],
798-
'created': row[12],
799-
'modified': row[13]
815+
'name_claim': row[11],
816+
'email_claim': row[12],
817+
'username_claim': row[13],
818+
'roles_claim': row[14],
819+
'user_role': row[15],
820+
'admin_role': row[16],
821+
'enabled': row[17],
822+
'created': row[18],
823+
'modified': row[19]
800824
})
801825
else:
802826
columns = [col[0] for col in cursor.description]
@@ -827,6 +851,18 @@ def list_oidc_providers(cnx, database_type):
827851
normalized["button_text_color"] = value
828852
elif normalized_key == "iconsvg":
829853
normalized["icon_svg"] = value
854+
elif normalized_key == "nameclaim":
855+
normalized["name_claim"] = value
856+
elif normalized_key == "emailclaim":
857+
normalized["email_claim"] = value
858+
elif normalized_key == "usernameclaim":
859+
normalized["username_claim"] = value
860+
elif normalized_key == "rolesclaim":
861+
normalized["roles_claim"] = value
862+
elif normalized_key == "userrole":
863+
normalized["user_role"] = value
864+
elif normalized_key == "adminrole":
865+
normalized["admin_role"] = value
830866
elif normalized_key == "enabled":
831867
# Convert MySQL TINYINT to boolean
832868
normalized["enabled"] = bool(value)
@@ -14664,13 +14700,13 @@ def get_oidc_provider(cnx, database_type, client_id):
1466414700
try:
1466514701
if database_type == "postgresql":
1466614702
query = """
14667-
SELECT ProviderID, ClientID, ClientSecret, TokenURL, UserInfoURL
14703+
SELECT ProviderID, ClientID, ClientSecret, TokenURL, UserInfoURL, NameClaim, EmailClaim, UsernameClaim, RolesClaim, UserRole, AdminRole
1466814704
FROM "OIDCProviders"
1466914705
WHERE ClientID = %s AND Enabled = true
1467014706
"""
1467114707
else:
1467214708
query = """
14673-
SELECT ProviderID, ClientID, ClientSecret, TokenURL, UserInfoURL
14709+
SELECT ProviderID, ClientID, ClientSecret, TokenURL, UserInfoURL, NameClaim, EmailClaim, UsernameClaim, RolesClaim, UserRole, AdminRole
1467414710
FROM OIDCProviders
1467514711
WHERE ClientID = %s AND Enabled = true
1467614712
"""
@@ -14683,7 +14719,13 @@ def get_oidc_provider(cnx, database_type, client_id):
1468314719
result['clientid'],
1468414720
result['clientsecret'],
1468514721
result['tokenurl'],
14686-
result['userinfourl']
14722+
result['userinfourl'],
14723+
result['nameclaim'],
14724+
result['emailclaim'],
14725+
result['usernameclaim'],
14726+
result['rolesclaim'],
14727+
result['userrole'],
14728+
result['adminrole']
1468714729
)
1468814730
return result
1468914731
return None
@@ -14721,49 +14763,10 @@ def get_user_by_email(cnx, database_type, email):
1472114763
finally:
1472214764
cursor.close()
1472314765

14724-
def create_oidc_user(cnx, database_type, email, fullname, base_username):
14766+
def create_oidc_user(cnx, database_type, email, fullname, username):
1472514767
cursor = cnx.cursor()
1472614768
try:
14727-
print(f"Starting create_oidc_user for email: {email}, fullname: {fullname}, base_username: {base_username}")
14728-
# Check if username exists and find a unique one
14729-
username = base_username
14730-
counter = 1
14731-
while True:
14732-
# Check if username exists
14733-
check_query = """
14734-
SELECT COUNT(*) FROM "Users" WHERE Username = %s
14735-
""" if database_type == "postgresql" else """
14736-
SELECT COUNT(*) FROM Users WHERE Username = %s
14737-
"""
14738-
print(f"Checking if username '{username}' exists")
14739-
cursor.execute(check_query, (username,))
14740-
result = cursor.fetchone()
14741-
print(f"Username check result: {result}, type: {type(result)}")
14742-
14743-
count = 0
14744-
if isinstance(result, tuple):
14745-
count = result[0]
14746-
elif isinstance(result, dict):
14747-
count = result.get('count', 0)
14748-
else:
14749-
# Try to extract the count value safely
14750-
try:
14751-
count = int(result)
14752-
except (TypeError, ValueError):
14753-
print(f"Unable to extract count from result: {result}")
14754-
count = 1 # Assume username exists to be safe
14755-
14756-
print(f"Username count: {count}")
14757-
if count == 0:
14758-
print(f"Username '{username}' is unique, proceeding")
14759-
break # Username is unique
14760-
14761-
# Try with incremented counter
14762-
print(f"Username '{username}' already exists, trying next")
14763-
username = f"{base_username}{counter}"
14764-
counter += 1
14765-
if counter > 10: # Limit attempts
14766-
raise Exception("Could not find a unique username")
14769+
print(f"Starting create_oidc_user for email: {email}, fullname: {fullname}, username: {username}")
1476714770

1476814771
# Create a random salt using base64 (which is what Argon2 expects)
1476914772
salt = base64.b64encode(secrets.token_bytes(16)).decode('utf-8')

0 commit comments

Comments
 (0)