Skip to content

Commit bac8891

Browse files
committed
feat: Implement Shotgrid authentication with username/password
Signed-off-by: aviralgarg05 <gargaviral99@gmail.com>
1 parent b464c19 commit bac8891

File tree

6 files changed

+253
-27
lines changed

6 files changed

+253
-27
lines changed

backend/src/dna/prodtrack_providers/prodtrack_provider_base.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ def get_user_by_email(self, user_email: str) -> "User":
5151
"""
5252
raise NotImplementedError("Subclasses must implement this method.")
5353

54+
def get_user_by_login(self, login: str) -> "User":
55+
"""Get a user by their login/username."""
56+
raise NotImplementedError("Subclasses must implement this method.")
57+
5458
def get_projects_for_user(self, user_email: str) -> list["Project"]:
5559
"""Get projects accessible by a user.
5660
@@ -84,12 +88,17 @@ def get_versions_for_playlist(self, playlist_id: int) -> list["Version"]:
8488
"""
8589
raise NotImplementedError("Subclasses must implement this method.")
8690

91+
@staticmethod
92+
def authenticate_user(url: str, login: str, password: str) -> str:
93+
"""Authenticate a user and return a session token."""
94+
raise NotImplementedError("Subclasses must implement this method.")
95+
8796

88-
def get_prodtrack_provider() -> ProdtrackProviderBase:
97+
def get_prodtrack_provider(session_token: str | None = None) -> ProdtrackProviderBase:
8998
"""Get the production tracking provider."""
9099
from dna.prodtrack_providers.shotgrid import ShotgridProvider
91100

92101
provider_type = os.getenv("PRODTRACK_PROVIDER", "shotgrid")
93102
if provider_type == "shotgrid":
94-
return ShotgridProvider()
103+
return ShotgridProvider(session_token=session_token)
95104
raise ValueError(f"Unknown production tracking provider: {provider_type}")

backend/src/dna/prodtrack_providers/shotgrid.py

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def __init__(
122122
url: Optional[str] = None,
123123
script_name: Optional[str] = None,
124124
api_key: Optional[str] = None,
125+
session_token: Optional[str] = None,
125126
connect: bool = True,
126127
):
127128
"""Initialize the ShotGrid connection.
@@ -130,17 +131,22 @@ def __init__(
130131
url: ShotGrid server URL. Defaults to SHOTGRID_URL env var.
131132
script_name: API script name. Defaults to SHOTGRID_SCRIPT_NAME env var.
132133
api_key: API key for authentication. Defaults to SHOTGRID_API_KEY env var.
134+
session_token: Session token for user authentication.
133135
"""
134136
super().__init__()
135137

136138
self.url = url or os.getenv("SHOTGRID_URL")
137139
self.script_name = script_name or os.getenv("SHOTGRID_SCRIPT_NAME")
138140
self.api_key = api_key or os.getenv("SHOTGRID_API_KEY")
141+
self.session_token = session_token
139142

140-
if not all([self.url, self.script_name, self.api_key]):
143+
if not self.url:
144+
raise ValueError("ShotGrid URL not provided.")
145+
146+
if not self.session_token and not (self.script_name and self.api_key):
141147
raise ValueError(
142-
"ShotGrid credentials not provided. Set SHOTGRID_URL, "
143-
"SHOTGRID_SCRIPT_NAME, and SHOTGRID_API_KEY environment variables."
148+
"ShotGrid credentials not provided. Provide either session_token "
149+
"or (script_name and api_key)."
144150
)
145151

146152
self.sg = None
@@ -149,7 +155,11 @@ def __init__(
149155

150156
def _connect(self):
151157
"""Connect to ShotGrid."""
152-
self.sg = Shotgun(self.url, self.script_name, self.api_key)
158+
if self.session_token:
159+
# When using session token, we don't use script credentials
160+
self.sg = Shotgun(self.url, session_token=self.session_token)
161+
else:
162+
self.sg = Shotgun(self.url, self.script_name, self.api_key)
153163

154164
def _convert_sg_entity_to_dna_entity(
155165
self,
@@ -398,6 +408,35 @@ def get_user_by_email(self, user_email: str) -> User:
398408
sg_user, entity_mapping, "user", resolve_links=False
399409
)
400410

411+
def get_user_by_login(self, login: str) -> User:
412+
"""Get a user by their login/username.
413+
414+
Args:
415+
login: The login/username of the user
416+
417+
Returns:
418+
User entity
419+
420+
Raises:
421+
ValueError: If user is not found
422+
"""
423+
if not self.sg:
424+
raise ValueError("Not connected to ShotGrid")
425+
426+
sg_user = self.sg.find_one(
427+
"HumanUser",
428+
filters=[["login", "is", login]],
429+
fields=["id", "name", "email", "login"],
430+
)
431+
432+
if not sg_user:
433+
raise ValueError(f"User not found: {login}")
434+
435+
entity_mapping = FIELD_MAPPING["user"]
436+
return self._convert_sg_entity_to_dna_entity(
437+
sg_user, entity_mapping, "user", resolve_links=False
438+
)
439+
401440
def get_projects_for_user(self, user_email: str) -> list[Project]:
402441
"""Get projects accessible by a user.
403442
@@ -538,6 +577,41 @@ def get_versions_for_playlist(self, playlist_id: int) -> list[Version]:
538577
return versions
539578

540579

580+
@staticmethod
581+
def authenticate_user(url: str, login: str, password: str) -> str:
582+
"""Authenticate a user with ShotGrid and return a session token.
583+
584+
Args:
585+
url: ShotGrid server URL
586+
login: User login/username
587+
password: User password
588+
589+
Returns:
590+
Session token string
591+
592+
Raises:
593+
ValueError: If authentication fails
594+
"""
595+
try:
596+
# Shotgun.authenticate_human_user returns the user object, but we need the session_token.
597+
# However, the standard way to get a token is to just create a connection which validates creds.
598+
# But wait, Shotgun API structure specifically for auth:
599+
# We can use the simple authentication helper or instantiate to get token.
600+
# Actually, standard shotgun_api3 doesn't easily expose 'authenticate_human_user' to get a token string directly
601+
# without internals.
602+
# Let's instantiate a connection to verify and get session_token if available or standard auth flow.
603+
# The pattern usually is: sg = Shotgun(url, login=login, password=password) then sg.get_session_token().
604+
605+
# Note: shotgun_api3 v3.3.0+ supports `login` and `password` in constructor for script-based auth,
606+
# but for human user relying on session token:
607+
608+
sg = Shotgun(url, login=login, password=password)
609+
# This establishes connection. Now implementation detail: how to get the token?
610+
# The 'get_session_token()' method provides it.
611+
return sg.get_session_token()
612+
except Exception as e:
613+
raise ValueError(f"Authentication failed: {str(e)}")
614+
541615
def _get_dna_entity_type(sg_entity_type: str) -> str:
542616
"""Get the DNA entity type from the ShotGrid entity type."""
543617
for entity_type, entity_data in FIELD_MAPPING.items():

backend/src/main.py

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from functools import lru_cache
44
from typing import Annotated, cast
55

6-
from fastapi import Depends, FastAPI, HTTPException
6+
from fastapi import Depends, FastAPI, HTTPException, Header
77
from fastapi.middleware.cors import CORSMiddleware
8+
from pydantic import BaseModel
89

910
from dna.models import (
1011
Asset,
@@ -23,6 +24,13 @@
2324
ProdtrackProviderBase,
2425
get_prodtrack_provider,
2526
)
27+
from dna.prodtrack_providers.shotgrid import ShotgridProvider
28+
29+
30+
class LoginRequest(BaseModel):
31+
username: str
32+
password: str
33+
2634

2735
# API metadata for Swagger documentation
2836
API_TITLE = "DNA Backend"
@@ -47,6 +55,10 @@
4755

4856
# Define API tags for organizing endpoints
4957
tags_metadata = [
58+
{
59+
"name": "Auth",
60+
"description": "Authentication endpoints",
61+
},
5062
{
5163
"name": "Health",
5264
"description": "Health check and status endpoints",
@@ -126,17 +138,69 @@
126138
# -----------------------------------------------------------------------------
127139

128140

129-
@lru_cache
130-
def get_prodtrack_provider_cached() -> ProdtrackProviderBase:
131-
"""Get or create the production tracking provider singleton."""
132-
return get_prodtrack_provider()
141+
def get_token_header(authorization: Annotated[str | None, Header()] = None) -> str | None:
142+
"""Extract token from Authorization header."""
143+
if not authorization:
144+
return None
145+
if authorization.startswith("Bearer "):
146+
return authorization.split(" ")[1]
147+
return authorization
148+
149+
150+
def get_prodtrack_provider_dep(
151+
token: Annotated[str | None, Depends(get_token_header)],
152+
) -> ProdtrackProviderBase:
153+
"""Get the production tracking provider with user session."""
154+
return get_prodtrack_provider(session_token=token)
133155

134156

135157
ProdtrackProviderDep = Annotated[
136-
ProdtrackProviderBase, Depends(get_prodtrack_provider_cached)
158+
ProdtrackProviderBase, Depends(get_prodtrack_provider_dep)
137159
]
138160

139161

162+
# -----------------------------------------------------------------------------
163+
# Auth endpoints
164+
# -----------------------------------------------------------------------------
165+
166+
167+
@app.post(
168+
"/auth/login",
169+
tags=["Auth"],
170+
summary="Login to Production Tracking",
171+
description="Authenticate with username and password to get a session token.",
172+
)
173+
async def login(request: LoginRequest):
174+
"""Login to ShotGrid."""
175+
try:
176+
# We need a provider instance to access the static method if we want to keep it clean,
177+
# or just import the class. We imported ShotgridProvider above.
178+
# But we need the URL from the environment or default provider.
179+
# Let's instantiate a default provider to get config, or just use the class method
180+
# and assume env vars are set for URL if not passed?
181+
# The static method requires URL.
182+
183+
# Helper to get base URL
184+
import os
185+
url = os.getenv("SHOTGRID_URL")
186+
if not url:
187+
raise HTTPException(status_code=500, detail="SHOTGRID_URL not configured")
188+
189+
token = ShotgridProvider.authenticate_user(url, request.username, request.password)
190+
191+
# Create a provider with this token to fetch the user details (email)
192+
provider = ShotgridProvider(url=url, session_token=token)
193+
user = provider.get_user_by_login(request.username)
194+
195+
if not user.email:
196+
raise HTTPException(status_code=400, detail="User has no email address configured")
197+
198+
return {"token": token, "email": user.email}
199+
except ValueError as e:
200+
raise HTTPException(status_code=401, detail=str(e))
201+
202+
203+
140204
# -----------------------------------------------------------------------------
141205
# Health endpoints
142206
# -----------------------------------------------------------------------------

0 commit comments

Comments
 (0)