@@ -81,6 +81,7 @@ class AuthSamlProvider(models.Model):
8181 "auth.saml.attribute.mapping" ,
8282 "provider_id" ,
8383 string = "Attribute Mapping" ,
84+ copy = True ,
8485 )
8586 active = fields .Boolean (default = True )
8687 sequence = fields .Integer (index = True )
@@ -136,6 +137,21 @@ class AuthSamlProvider(models.Model):
136137 default = True ,
137138 help = "Whether metadata should be signed or not" ,
138139 )
140+ # User creation fields
141+ create_user = fields .Boolean (
142+ default = False ,
143+ help = "Create user if not found. The login and name will defaults to the SAML "
144+ "user matching attribute. Use the mapping attributes to change the value "
145+ "used. If a deactivated user has a matching saml uid, activate it rather than"
146+ "create a new one." ,
147+ )
148+ create_user_template_id = fields .Many2one (
149+ comodel_name = "res.users" ,
150+ # Template users, like base.default_user, are disabled by default so allow them
151+ domain = "[('active', 'in', (True, False))]" ,
152+ default = lambda self : self .env .ref ("base.default_user" ),
153+ help = "When creating user, this user is used as a template" ,
154+ )
139155
140156 @api .model
141157 def _sig_alg_selection (self ):
@@ -256,9 +272,7 @@ def _get_auth_request(self, extra_state=None, url_root=None):
256272 }
257273 state .update (extra_state )
258274
259- sig_alg = ds .SIG_RSA_SHA1
260- if self .sig_alg :
261- sig_alg = getattr (ds , self .sig_alg )
275+ sig_alg = getattr (ds , self .sig_alg )
262276
263277 saml_client = self ._get_client_for_provider (url_root )
264278 reqid , info = saml_client .prepare_for_authenticate (
@@ -272,6 +286,7 @@ def _get_auth_request(self, extra_state=None, url_root=None):
272286 for key , value in info ["headers" ]:
273287 if key == "Location" :
274288 redirect_url = value
289+ break
275290
276291 self ._store_outstanding_request (reqid )
277292
@@ -287,27 +302,15 @@ def _validate_auth_response(self, token: str, base_url: str = None):
287302 saml2 .entity .BINDING_HTTP_POST ,
288303 self ._get_outstanding_requests_dict (),
289304 )
290- matching_value = None
291-
292- if self .matching_attribute == "subject.nameId" :
293- matching_value = response .name_id .text
294- else :
295- attrs = response .get_identity ()
296-
297- for k , v in attrs .items ():
298- if k == self .matching_attribute :
299- matching_value = v
300- break
301-
302- if not matching_value :
303- raise Exception (
304- f"Matching attribute { self .matching_attribute } not found "
305- f"in user attrs: { attrs } "
306- )
307-
308- if matching_value and isinstance (matching_value , list ):
309- matching_value = next (iter (matching_value ), None )
310-
305+ try :
306+ matching_value = self ._get_attribute_value (
307+ response , self .matching_attribute
308+ )
309+ except KeyError :
310+ raise KeyError (
311+ f"Matching attribute { self .matching_attribute } not found "
312+ f"in user attrs: { response .get_identity ()} "
313+ ) from None
311314 if isinstance (matching_value , str ) and self .matching_attribute_to_lower :
312315 matching_value = matching_value .lower ()
313316
@@ -349,24 +352,59 @@ def _metadata_string(self, valid=None, base_url: str = None):
349352 sign = self .sign_metadata ,
350353 )
351354
355+ @staticmethod
356+ def _get_attribute_value (response , attribute_name : str ):
357+ """
358+
359+ :raise: KeyError if attribute is not in the response
360+ :param response:
361+ :param attribute_name:
362+ :return: value of the attribute. if the value is an empty list, return None
363+ otherwise return the first element of the list
364+ """
365+ if attribute_name == "subject.nameId" :
366+ return response .name_id .text
367+ attrs = response .get_identity ()
368+ attribute_value = attrs [attribute_name ]
369+ if isinstance (attribute_value , list ):
370+ attribute_value = next (iter (attribute_value ), None )
371+ return attribute_value
372+
352373 def _hook_validate_auth_response (self , response , matching_value ):
353374 self .ensure_one ()
354375 vals = {}
355- attrs = response .get_identity ()
356376
357377 for attribute in self .attribute_mapping_ids :
358- if attribute .attribute_name not in attrs :
359- _logger .debug (
378+ try :
379+ vals [attribute .field_name ] = self ._get_attribute_value (
380+ response , attribute .attribute_name
381+ )
382+ except KeyError :
383+ _logger .warning (
360384 "SAML attribute '%s' not found in response %s" ,
361385 attribute .attribute_name ,
362- attrs ,
386+ response . get_identity () ,
363387 )
364- continue
365388
366- attribute_value = attrs [attribute .attribute_name ]
367- if isinstance (attribute_value , list ):
368- attribute_value = attribute_value [0 ]
389+ return {"mapped_attrs" : vals }
369390
370- vals [attribute .field_name ] = attribute_value
391+ def _user_copy_defaults (self , validation ):
392+ """
393+ Returns defaults when copying the template user.
371394
372- return {"mapped_attrs" : vals }
395+ Can be overridden with extra information.
396+ :param validation: validation result
397+ :return: a dictionary for copying template user, empty to avoid copying
398+ """
399+ self .ensure_one ()
400+ if not self .create_user :
401+ return {}
402+ saml_uid = validation ["user_id" ]
403+ return {
404+ "name" : saml_uid ,
405+ "login" : saml_uid ,
406+ "active" : True ,
407+ # if signature is not provided by mapped_attrs, it will be computed
408+ # due to call to compute method in calling method.
409+ "signature" : None ,
410+ } | validation .get ("mapped_attrs" , {})
0 commit comments