11"""Authentication module - Support for Basic, Token, PAT and other auth methods"""
22
33import base64
4- from typing import Dict , Optional , Tuple , Union
54from abc import ABC , abstractmethod
5+ from typing import Dict , Optional , Tuple , Union
66
77import aiohttp
88
1212
1313class AuthProvider (ABC ):
1414 """Authentication provider abstract base class"""
15-
15+
1616 @abstractmethod
1717 async def get_headers (self ) -> Dict [str , str ]:
1818 """Get authentication headers"""
1919 pass
20-
20+
2121 @abstractmethod
2222 async def refresh_if_needed (self ) -> bool :
2323 """If needed, refresh the authentication. Returns whether it was refreshed"""
2424 pass
25-
25+
2626 @abstractmethod
2727 async def is_valid (self ) -> bool :
2828 """Check if the authentication is valid"""
@@ -31,85 +31,85 @@ async def is_valid(self) -> bool:
3131
3232class BasicAuthProvider (AuthProvider ):
3333 """Basic authentication provider"""
34-
34+
3535 def __init__ (self , username : str , password : str ):
3636 self .username = username
3737 self .password = password
3838 self ._auth_header = self ._encode_basic_auth (username , password )
39-
39+
4040 @staticmethod
4141 def _encode_basic_auth (username : str , password : str ) -> str :
4242 """Encode Basic authentication"""
4343 credentials = f"{ username } :{ password } "
4444 encoded = base64 .b64encode (credentials .encode ('utf-8' )).decode ('ascii' )
4545 return f"Basic { encoded } "
46-
46+
4747 async def get_headers (self ) -> Dict [str , str ]:
4848 """Get authentication headers"""
4949 return {"Authorization" : self ._auth_header }
50-
50+
5151 async def refresh_if_needed (self ) -> bool :
5252 """Basic auth does not need to be refreshed"""
5353 return False
54-
54+
5555 async def is_valid (self ) -> bool :
5656 """Basic auth is always valid (assuming credentials are correct)"""
5757 return True
5858
5959
6060class TokenAuthProvider (AuthProvider ):
6161 """Token authentication provider"""
62-
62+
6363 def __init__ (self , token : str , token_type : str = "Bearer" ):
6464 self .token = token
6565 self .token_type = token_type
6666 self ._auth_header = f"{ token_type } { token } "
67-
67+
6868 async def get_headers (self ) -> Dict [str , str ]:
6969 """Get authentication headers"""
7070 return {"Authorization" : self ._auth_header }
71-
71+
7272 async def refresh_if_needed (self ) -> bool :
7373 """Token auth does not need to be refreshed (simple implementation)"""
7474 return False
75-
75+
7676 async def is_valid (self ) -> bool :
7777 """Token auth is always valid (assuming token is correct)"""
7878 return True
7979
8080
8181class PATAuthProvider (AuthProvider ):
8282 """Personal Access Token authentication provider"""
83-
83+
8484 def __init__ (self , pat_token : str ):
8585 self .pat_token = pat_token
8686 self ._auth_header = f"Bearer { pat_token } "
87-
87+
8888 async def get_headers (self ) -> Dict [str , str ]:
8989 """Get authentication headers"""
9090 return {"Authorization" : self ._auth_header }
91-
91+
9292 async def refresh_if_needed (self ) -> bool :
9393 """PAT auth does not need to be refreshed"""
9494 return False
95-
95+
9696 async def is_valid (self ) -> bool :
9797 """PAT auth is always valid (assuming token is correct)"""
9898 return True
9999
100100
101101class SessionAuthProvider (AuthProvider ):
102102 """Session authentication provider (supports JSESSIONID, etc.)"""
103-
103+
104104 def __init__ (self , session : aiohttp .ClientSession , base_url : str ):
105105 self .session = session
106106 self .base_url = base_url
107107 self ._authenticated = False
108-
108+
109109 async def login (self , username : str , password : str ) -> None :
110110 """Login to get a session"""
111111 login_url = f"{ self .base_url } /dhis-web-commons-security/login.action"
112-
112+
113113 async with self .session .post (
114114 login_url ,
115115 data = {
@@ -121,11 +121,11 @@ async def login(self, username: str, password: str) -> None:
121121 self ._authenticated = True
122122 else :
123123 raise AuthenticationError (f"Login failed with status { response .status } " )
124-
124+
125125 async def get_headers (self ) -> Dict [str , str ]:
126126 """Get authentication headers (session auth relies on cookies)"""
127127 return {}
128-
128+
129129 async def refresh_if_needed (self ) -> bool :
130130 """Check if session needs to be refreshed"""
131131 # Simple implementation: check the /api/me endpoint
@@ -138,7 +138,7 @@ async def refresh_if_needed(self) -> bool:
138138 except Exception :
139139 self ._authenticated = False
140140 return False
141-
141+
142142 async def is_valid (self ) -> bool :
143143 """Check if the session is valid"""
144144 return self ._authenticated
@@ -151,49 +151,49 @@ def create_auth_provider(
151151 base_url : Optional [str ] = None
152152) -> AuthProvider :
153153 """Factory function: create an authentication provider based on configuration"""
154-
154+
155155 if auth_method == AuthMethod .BASIC :
156156 if not isinstance (auth , tuple ) or len (auth ) != 2 :
157157 raise ValueError ("Basic authentication requires a (username, password) tuple" )
158158 return BasicAuthProvider (auth [0 ], auth [1 ])
159-
159+
160160 elif auth_method == AuthMethod .TOKEN :
161161 if not isinstance (auth , str ):
162162 raise ValueError ("Token authentication requires a string token" )
163163 return TokenAuthProvider (auth )
164-
164+
165165 elif auth_method == AuthMethod .PAT :
166166 if not isinstance (auth , str ):
167167 raise ValueError ("PAT authentication requires a string token" )
168168 return PATAuthProvider (auth )
169-
169+
170170 else :
171171 raise ValueError (f"Unsupported authentication method: { auth_method } " )
172172
173173
174174class AuthManager :
175175 """Authentication manager - manages authentication providers and refresh logic"""
176-
176+
177177 def __init__ (self , auth_provider : AuthProvider ):
178178 self .auth_provider = auth_provider
179179 self ._last_refresh_check = 0
180180 self ._refresh_interval = 300 # Check every 5 minutes
181-
181+
182182 async def get_auth_headers (self ) -> Dict [str , str ]:
183183 """Get authentication headers, refreshing if necessary"""
184184 import time
185-
185+
186186 current_time = time .time ()
187187 if current_time - self ._last_refresh_check > self ._refresh_interval :
188188 await self .auth_provider .refresh_if_needed ()
189189 self ._last_refresh_check = current_time
190-
190+
191191 return await self .auth_provider .get_headers ()
192-
192+
193193 async def validate_auth (self ) -> bool :
194194 """Validate if the authentication is valid"""
195195 return await self .auth_provider .is_valid ()
196-
196+
197197 async def force_refresh (self ) -> bool :
198198 """Force a refresh of the authentication"""
199199 return await self .auth_provider .refresh_if_needed ()
0 commit comments