@@ -250,6 +250,7 @@ def __init__(
250250 self .X_background = None
251251 if not hasattr (self , "model" ):
252252 self .model = model
253+ self .model_for_shap = unwrap_calibrated_classifier (self .model )
253254
254255 if safe_isinstance (model , "xgboost.core.Booster" ):
255256 raise ValueError (
@@ -337,7 +338,7 @@ def __init__(
337338 self .shap_kwargs = shap_kwargs or {}
338339
339340 if shap == "guess" :
340- shap_guess = guess_shap (self .model )
341+ shap_guess = guess_shap (self .model_for_shap )
341342 model_str = (
342343 str (type (self .model ))
343344 .replace ("'" , "" )
@@ -1230,13 +1231,14 @@ def shap_explainer(self):
12301231 if not hasattr (self , "_shap_explainer" ):
12311232 X_str = ", X_background" if self .X_background is not None else "X"
12321233 NoX_str = ", X_background" if self .X_background is not None else ""
1234+ model_for_shap = self .model_for_shap
12331235 if self .shap == "tree" :
12341236 logger .info (
12351237 "Generating self.shap_explainer = "
12361238 f"shap.TreeExplainer(model{ NoX_str } )"
12371239 )
12381240 # Fix XGBoost 3.1+ base_score string format before shap accesses it
1239- model_for_shap = self ._fix_xgboost_model_for_shap (self . model )
1241+ model_for_shap = self ._fix_xgboost_model_for_shap (model_for_shap )
12401242 self ._shap_explainer = shap .TreeExplainer (model_for_shap )
12411243 elif self .shap == "linear" :
12421244 if self .X_background is None :
@@ -1250,7 +1252,7 @@ def shap_explainer(self):
12501252 X_str ,
12511253 )
12521254 self ._shap_explainer = shap .LinearExplainer (
1253- self . model ,
1255+ model_for_shap ,
12541256 self .X_background if self .X_background is not None else self .X ,
12551257 )
12561258 elif self .shap == "deep" :
@@ -1263,7 +1265,7 @@ def shap_explainer(self):
12631265 UserWarning ,
12641266 )
12651267 self ._shap_explainer = shap .DeepExplainer (
1266- self . model ,
1268+ model_for_shap ,
12671269 self .X_background
12681270 if self .X_background is not None
12691271 else shap .sample (self .X , 5 ),
@@ -1280,7 +1282,7 @@ def shap_explainer(self):
12801282 import torch
12811283
12821284 self ._shap_explainer = shap .DeepExplainer (
1283- self . model .module_ ,
1285+ model_for_shap .module_ ,
12841286 torch .tensor (self .X_background .values )
12851287 if self .X_background is not None
12861288 else torch .tensor (shap .sample (self .X , 5 ).values ),
@@ -1327,7 +1329,7 @@ def model_predict(data_asarray):
13271329 "Please install a CUDA-enabled SHAP build that includes "
13281330 "GPUTree support."
13291331 )
1330- self ._shap_explainer = explainer_cls (self . model , X_data )
1332+ self ._shap_explainer = explainer_cls (model_for_shap , X_data )
13311333 return self ._shap_explainer
13321334
13331335 @insert_pos_label
@@ -2861,16 +2863,17 @@ def shap_explainer(self):
28612863 Taking into account model type and model_output
28622864 """
28632865 if not hasattr (self , "_shap_explainer" ):
2866+ model_for_shap = self .model_for_shap
28642867 model_str = (
2865- str (type (self . model ))
2868+ str (type (model_for_shap ))
28662869 .replace ("'" , "" )
28672870 .replace ("<" , "" )
28682871 .replace (">" , "" )
28692872 .split ("." )[- 1 ]
28702873 )
28712874 if self .shap == "tree" :
28722875 if safe_isinstance (
2873- self . model ,
2876+ model_for_shap ,
28742877 "XGBClassifier" ,
28752878 "LGBMClassifier" ,
28762879 "CatBoostClassifier" ,
@@ -2897,7 +2900,9 @@ def shap_explainer(self):
28972900 UserWarning ,
28982901 )
28992902 # Fix XGBoost 3.1+ base_score string format before shap accesses it
2900- model_for_shap = self ._fix_xgboost_model_for_shap (self .model )
2903+ model_for_shap = self ._fix_xgboost_model_for_shap (
2904+ model_for_shap
2905+ )
29012906 self ._shap_explainer = shap .TreeExplainer (
29022907 model_for_shap ,
29032908 self .X_background
@@ -2914,7 +2919,9 @@ def shap_explainer(self):
29142919 ", X_background" if self .X_background is not None else "" ,
29152920 )
29162921 # Fix XGBoost 3.1+ base_score string format before shap accesses it
2917- model_for_shap = self ._fix_xgboost_model_for_shap (self .model )
2922+ model_for_shap = self ._fix_xgboost_model_for_shap (
2923+ model_for_shap
2924+ )
29182925 self ._shap_explainer = shap .TreeExplainer (
29192926 model_for_shap , self .X_background
29202927 )
@@ -2929,7 +2936,7 @@ def shap_explainer(self):
29292936 ", X_background" if self .X_background is not None else "" ,
29302937 )
29312938 # Fix XGBoost 3.1+ base_score string format before shap accesses it
2932- model_for_shap = self ._fix_xgboost_model_for_shap (self . model )
2939+ model_for_shap = self ._fix_xgboost_model_for_shap (model_for_shap )
29332940 self ._shap_explainer = shap .TreeExplainer (
29342941 model_for_shap , self .X_background
29352942 )
@@ -2955,7 +2962,7 @@ def shap_explainer(self):
29552962 )
29562963
29572964 self ._shap_explainer = shap .LinearExplainer (
2958- self . model ,
2965+ model_for_shap ,
29592966 self .X_background if self .X_background is not None else self .X ,
29602967 )
29612968 elif self .shap == "deep" :
@@ -2968,7 +2975,7 @@ def shap_explainer(self):
29682975 UserWarning ,
29692976 )
29702977 self ._shap_explainer = shap .DeepExplainer (
2971- self . model ,
2978+ model_for_shap ,
29722979 self .X_background
29732980 if self .X_background is not None
29742981 else shap .sample (self .X , 5 ),
@@ -2985,7 +2992,7 @@ def shap_explainer(self):
29852992 UserWarning ,
29862993 )
29872994 self ._shap_explainer = shap .DeepExplainer (
2988- self . model .module_ ,
2995+ model_for_shap .module_ ,
29892996 torch .tensor (
29902997 self .X_background .values
29912998 if self .X_background is not None
0 commit comments