Skip to content

Commit 1a5fc74

Browse files
authored
Support CalibratedClassifierCV in SHAP (#332)
1 parent 96e19ea commit 1a5fc74

File tree

5 files changed

+59
-17
lines changed

5 files changed

+59
-17
lines changed

RELEASE_NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
- Fix FeatureInputComponent range calculation for boolean columns (avoid np.round on bools) and add a regression test.
88
- Ensure save_html includes custom tabs by providing a static-export fallback for tabs without a to_html implementation.
99
- Support string class labels in ClassifierExplainer by preserving label mappings and avoiding float casts.
10+
- Support CalibratedClassifierCV by using its fitted base estimator for SHAP (avoids falling back to kernel).
1011

1112
### Improvements
1213
- Replace print statements with standard logging and warnings; progress messages are now INFO-level and user-actionable guidance uses warnings. A one-time warning is emitted if logging is not configured, with instructions to call `enable_default_logging()`.

TODO.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66
- Rules: link an issue when possible; include size S/M/L; mark blockers.
77

88
**Now**
9-
- [M/L][Explainers][#279] support CalibratedClassifierCV (unwrap estimator for SHAP; update logic and tests).
9+
- [S/M][Components][#277] whatif input range/rounding customization.
1010

1111
**Next**
12-
- [S/M][Components][#277] whatif input range/rounding customization.
13-
- [S/M][Explainers][#274] support string labels without float casts.
1412
- [M][Explainers][#273] categorical columns with NaNs: sorting and column preservation.
1513
- [S][Explainers][#270] Autogluon integration (coerce predict_proba to ndarray).
1614
- [M][Hub][#269] add_dashboard endpoint fails after first request (Flask blueprint lifecycle).

explainerdashboard/explainer_methods.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"IndexNotFoundError",
33
"append_dict_to_df",
44
"safe_isinstance",
5+
"unwrap_calibrated_classifier",
56
"align_categorical_dtypes",
67
"guess_shap",
78
"mape_score",
@@ -296,6 +297,27 @@ def safe_isinstance(obj, *instance_str):
296297
return False
297298

298299

300+
def unwrap_calibrated_classifier(model):
301+
"""Return the fitted base estimator for a CalibratedClassifierCV model."""
302+
if not safe_isinstance(model, "CalibratedClassifierCV"):
303+
return model
304+
305+
calibrated_classifiers = getattr(model, "calibrated_classifiers_", None)
306+
if calibrated_classifiers:
307+
calibrated = calibrated_classifiers[0]
308+
for attr in ("estimator", "base_estimator"):
309+
estimator = getattr(calibrated, attr, None)
310+
if estimator is not None:
311+
return estimator
312+
313+
for attr in ("estimator", "base_estimator"):
314+
estimator = getattr(model, attr, None)
315+
if estimator is not None:
316+
return estimator
317+
318+
return model
319+
320+
299321
def align_categorical_dtypes(
300322
df_target: pd.DataFrame,
301323
df_reference: pd.DataFrame,
@@ -332,6 +354,8 @@ def guess_shap(model):
332354
Returns:
333355
str: {'tree', 'linear', None}
334356
"""
357+
model = unwrap_calibrated_classifier(model)
358+
335359
tree_models = [
336360
"RandomForestClassifier",
337361
"RandomForestRegressor",

explainerdashboard/explainers.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ def __init__(
250250
self.X_background = None
251251
if not hasattr(self, "model"):
252252
self.model = model
253+
self.model_for_shap = unwrap_calibrated_classifier(self.model)
253254

254255
if safe_isinstance(model, "xgboost.core.Booster"):
255256
raise ValueError(
@@ -337,7 +338,7 @@ def __init__(
337338
self.shap_kwargs = shap_kwargs or {}
338339

339340
if shap == "guess":
340-
shap_guess = guess_shap(self.model)
341+
shap_guess = guess_shap(self.model_for_shap)
341342
model_str = (
342343
str(type(self.model))
343344
.replace("'", "")
@@ -1230,13 +1231,14 @@ def shap_explainer(self):
12301231
if not hasattr(self, "_shap_explainer"):
12311232
X_str = ", X_background" if self.X_background is not None else "X"
12321233
NoX_str = ", X_background" if self.X_background is not None else ""
1234+
model_for_shap = self.model_for_shap
12331235
if self.shap == "tree":
12341236
logger.info(
12351237
"Generating self.shap_explainer = "
12361238
f"shap.TreeExplainer(model{NoX_str})"
12371239
)
12381240
# Fix XGBoost 3.1+ base_score string format before shap accesses it
1239-
model_for_shap = self._fix_xgboost_model_for_shap(self.model)
1241+
model_for_shap = self._fix_xgboost_model_for_shap(model_for_shap)
12401242
self._shap_explainer = shap.TreeExplainer(model_for_shap)
12411243
elif self.shap == "linear":
12421244
if self.X_background is None:
@@ -1250,7 +1252,7 @@ def shap_explainer(self):
12501252
X_str,
12511253
)
12521254
self._shap_explainer = shap.LinearExplainer(
1253-
self.model,
1255+
model_for_shap,
12541256
self.X_background if self.X_background is not None else self.X,
12551257
)
12561258
elif self.shap == "deep":
@@ -1263,7 +1265,7 @@ def shap_explainer(self):
12631265
UserWarning,
12641266
)
12651267
self._shap_explainer = shap.DeepExplainer(
1266-
self.model,
1268+
model_for_shap,
12671269
self.X_background
12681270
if self.X_background is not None
12691271
else shap.sample(self.X, 5),
@@ -1280,7 +1282,7 @@ def shap_explainer(self):
12801282
import torch
12811283

12821284
self._shap_explainer = shap.DeepExplainer(
1283-
self.model.module_,
1285+
model_for_shap.module_,
12841286
torch.tensor(self.X_background.values)
12851287
if self.X_background is not None
12861288
else torch.tensor(shap.sample(self.X, 5).values),
@@ -1327,7 +1329,7 @@ def model_predict(data_asarray):
13271329
"Please install a CUDA-enabled SHAP build that includes "
13281330
"GPUTree support."
13291331
)
1330-
self._shap_explainer = explainer_cls(self.model, X_data)
1332+
self._shap_explainer = explainer_cls(model_for_shap, X_data)
13311333
return self._shap_explainer
13321334

13331335
@insert_pos_label
@@ -2861,16 +2863,17 @@ def shap_explainer(self):
28612863
Taking into account model type and model_output
28622864
"""
28632865
if not hasattr(self, "_shap_explainer"):
2866+
model_for_shap = self.model_for_shap
28642867
model_str = (
2865-
str(type(self.model))
2868+
str(type(model_for_shap))
28662869
.replace("'", "")
28672870
.replace("<", "")
28682871
.replace(">", "")
28692872
.split(".")[-1]
28702873
)
28712874
if self.shap == "tree":
28722875
if safe_isinstance(
2873-
self.model,
2876+
model_for_shap,
28742877
"XGBClassifier",
28752878
"LGBMClassifier",
28762879
"CatBoostClassifier",
@@ -2897,7 +2900,9 @@ def shap_explainer(self):
28972900
UserWarning,
28982901
)
28992902
# Fix XGBoost 3.1+ base_score string format before shap accesses it
2900-
model_for_shap = self._fix_xgboost_model_for_shap(self.model)
2903+
model_for_shap = self._fix_xgboost_model_for_shap(
2904+
model_for_shap
2905+
)
29012906
self._shap_explainer = shap.TreeExplainer(
29022907
model_for_shap,
29032908
self.X_background
@@ -2914,7 +2919,9 @@ def shap_explainer(self):
29142919
", X_background" if self.X_background is not None else "",
29152920
)
29162921
# Fix XGBoost 3.1+ base_score string format before shap accesses it
2917-
model_for_shap = self._fix_xgboost_model_for_shap(self.model)
2922+
model_for_shap = self._fix_xgboost_model_for_shap(
2923+
model_for_shap
2924+
)
29182925
self._shap_explainer = shap.TreeExplainer(
29192926
model_for_shap, self.X_background
29202927
)
@@ -2929,7 +2936,7 @@ def shap_explainer(self):
29292936
", X_background" if self.X_background is not None else "",
29302937
)
29312938
# Fix XGBoost 3.1+ base_score string format before shap accesses it
2932-
model_for_shap = self._fix_xgboost_model_for_shap(self.model)
2939+
model_for_shap = self._fix_xgboost_model_for_shap(model_for_shap)
29332940
self._shap_explainer = shap.TreeExplainer(
29342941
model_for_shap, self.X_background
29352942
)
@@ -2955,7 +2962,7 @@ def shap_explainer(self):
29552962
)
29562963

29572964
self._shap_explainer = shap.LinearExplainer(
2958-
self.model,
2965+
model_for_shap,
29592966
self.X_background if self.X_background is not None else self.X,
29602967
)
29612968
elif self.shap == "deep":
@@ -2968,7 +2975,7 @@ def shap_explainer(self):
29682975
UserWarning,
29692976
)
29702977
self._shap_explainer = shap.DeepExplainer(
2971-
self.model,
2978+
model_for_shap,
29722979
self.X_background
29732980
if self.X_background is not None
29742981
else shap.sample(self.X, 5),
@@ -2985,7 +2992,7 @@ def shap_explainer(self):
29852992
UserWarning,
29862993
)
29872994
self._shap_explainer = shap.DeepExplainer(
2988-
self.model.module_,
2995+
model_for_shap.module_,
29892996
torch.tensor(
29902997
self.X_background.values
29912998
if self.X_background is not None

tests/test_classifier_base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import plotly.graph_objects as go
99

10+
from sklearn.calibration import CalibratedClassifierCV
1011
from sklearn.ensemble import RandomForestClassifier
1112

1213
from explainerdashboard import ClassifierExplainer, ExplainerDashboard
@@ -110,6 +111,17 @@ def test_string_labels_supported(classifier_data):
110111
assert {"pred_proba", "y"}.issubset(lift_df.columns)
111112

112113

114+
def test_calibrated_classifiercv_uses_tree_shap(classifier_data):
115+
X_train, y_train, X_test, y_test = classifier_data
116+
base_estimator = RandomForestClassifier(n_estimators=25, random_state=0)
117+
model = CalibratedClassifierCV(estimator=base_estimator, cv=2)
118+
model.fit(X_train, y_train)
119+
120+
explainer = ClassifierExplainer(model, X_test, y_test)
121+
122+
assert explainer.shap == "tree"
123+
124+
113125
def test_row_from_input(precalculated_rf_classifier_explainer):
114126
input_row = precalculated_rf_classifier_explainer.get_row_from_input(
115127
precalculated_rf_classifier_explainer.X.iloc[[0]].values.tolist()

0 commit comments

Comments
 (0)