@@ -81,6 +81,7 @@ class AuthSamlProvider(models.Model):
8181 "auth.saml.attribute.mapping" ,
8282 "provider_id" ,
8383 string = "Attribute Mapping" ,
84+ copy = True ,
8485 )
8586 active = fields .Boolean (default = True )
8687 sequence = fields .Integer (index = True )
@@ -136,6 +137,20 @@ class AuthSamlProvider(models.Model):
136137 default = True ,
137138 help = "Whether metadata should be signed or not" ,
138139 )
140+ # User creation fields
141+ create_user = fields .Boolean (
142+ default = False ,
143+ help = "Create user if not found. The login and name will defaults to the SAML "
144+ "user matching attribute. Use the mapping attributes to change the value "
145+ "used." ,
146+ )
147+ create_user_template_id = fields .Many2one (
148+ comodel_name = "res.users" ,
149+ # Template users, like base.default_user, are disabled by default so allow them
150+ domain = "[('active', 'in', (True, False))]" ,
151+ default = lambda self : self .env .ref ("base.default_user" ),
152+ help = "When creating user, this user is used as a template" ,
153+ )
139154
140155 @api .model
141156 def _sig_alg_selection (self ):
@@ -256,9 +271,7 @@ def _get_auth_request(self, extra_state=None, url_root=None):
256271 }
257272 state .update (extra_state )
258273
259- sig_alg = ds .SIG_RSA_SHA1
260- if self .sig_alg :
261- sig_alg = getattr (ds , self .sig_alg )
274+ sig_alg = getattr (ds , self .sig_alg )
262275
263276 saml_client = self ._get_client_for_provider (url_root )
264277 reqid , info = saml_client .prepare_for_authenticate (
@@ -272,6 +285,7 @@ def _get_auth_request(self, extra_state=None, url_root=None):
272285 for key , value in info ["headers" ]:
273286 if key == "Location" :
274287 redirect_url = value
288+ break
275289
276290 self ._store_outstanding_request (reqid )
277291
@@ -287,27 +301,15 @@ def _validate_auth_response(self, token: str, base_url: str = None):
287301 saml2 .entity .BINDING_HTTP_POST ,
288302 self ._get_outstanding_requests_dict (),
289303 )
290- matching_value = None
291-
292- if self .matching_attribute == "subject.nameId" :
293- matching_value = response .name_id .text
294- else :
295- attrs = response .get_identity ()
296-
297- for k , v in attrs .items ():
298- if k == self .matching_attribute :
299- matching_value = v
300- break
301-
302- if not matching_value :
303- raise Exception (
304- f"Matching attribute { self .matching_attribute } not found "
305- f"in user attrs: { attrs } "
306- )
307-
308- if matching_value and isinstance (matching_value , list ):
309- matching_value = next (iter (matching_value ), None )
310-
304+ try :
305+ matching_value = self ._get_attribute_value (
306+ response , self .matching_attribute
307+ )
308+ except KeyError :
309+ raise KeyError (
310+ f"Matching attribute { self .matching_attribute } not found "
311+ f"in user attrs: { response .get_identity ()} "
312+ ) from None
311313 if isinstance (matching_value , str ) and self .matching_attribute_to_lower :
312314 matching_value = matching_value .lower ()
313315
@@ -349,24 +351,59 @@ def _metadata_string(self, valid=None, base_url: str = None):
349351 sign = self .sign_metadata ,
350352 )
351353
354+ @staticmethod
355+ def _get_attribute_value (response , attribute_name : str ):
356+ """
357+
358+ :raise: KeyError if attribute is not in the response
359+ :param response:
360+ :param attribute_name:
361+ :return: value of the attribut. if the value is an empty list, return None
362+ otherwise return the first element of the list
363+ """
364+ if attribute_name == "subject.nameId" :
365+ return response .name_id .text
366+ attrs = response .get_identity ()
367+ attribute_value = attrs [attribute_name ]
368+ if isinstance (attribute_value , list ):
369+ attribute_value = next (iter (attribute_value ), None )
370+ return attribute_value
371+
352372 def _hook_validate_auth_response (self , response , matching_value ):
353373 self .ensure_one ()
354374 vals = {}
355- attrs = response .get_identity ()
356375
357376 for attribute in self .attribute_mapping_ids :
358- if attribute .attribute_name not in attrs :
359- _logger .debug (
377+ try :
378+ vals [attribute .field_name ] = self ._get_attribute_value (
379+ response , attribute .attribute_name
380+ )
381+ except KeyError :
382+ _logger .warning (
360383 "SAML attribute '%s' found in response %s" ,
361384 attribute .attribute_name ,
362- attrs ,
385+ response . get_identity () ,
363386 )
364- continue
365387
366- attribute_value = attrs [attribute .attribute_name ]
367- if isinstance (attribute_value , list ):
368- attribute_value = attribute_value [0 ]
388+ return {"mapped_attrs" : vals }
369389
370- vals [attribute .field_name ] = attribute_value
390+ def _user_copy_defaults (self , validation ):
391+ """
392+ Returns defaults when copying the template user.
371393
372- return {"mapped_attrs" : vals }
394+ Can be overridden with extra information.
395+ :param validation: validation result
396+ :return: a dictionary for copying template user, empty to avoid copying
397+ """
398+ self .ensure_one ()
399+ if not self .create_user :
400+ return {}
401+ saml_uid = validation ["user_id" ]
402+ return {
403+ "name" : saml_uid ,
404+ "login" : saml_uid ,
405+ "active" : True ,
406+ # if signature is not provided by mapped_attrs, it will be computed
407+ # due to call to compute method in calling method.
408+ "signature" : None ,
409+ } | validation .get ("mapped_attrs" , {})
0 commit comments