Skip to content

Commit ba1bd8c

Browse files
authored
Support DataFrame model outputs for predict and predict_proba (#334)
1 parent 6950db9 commit ba1bd8c

File tree

7 files changed

+86 -6 lines changed

7 files changed

+86 -6 lines changed

RELEASE_NOTES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
### Bug Fixes
77
- Allow FeatureInputComponent (what-if inputs) to customize numeric ranges and rounding, and apply min/max/step to inputs.
8+
- Improve compatibility with AutoGluon/custom wrappers by coercing pandas `DataFrame` outputs from `predict_proba`/`predict` to numpy arrays before indexing in classifier/regression helper paths.
9+
- Harden one-vs-all scorer handling so `make_one_vs_all_scorer` also accepts classifiers whose `predict_proba` returns a pandas `DataFrame`.
810

911
## Version 0.5.6:
1012

TODO.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
- [M][Explainers][#273] categorical columns with NaNs: sorting and column preservation.
1010

1111
**Next**
12-
- [S][Explainers][#270] Autogluon integration (coerce predict_proba to ndarray).
1312
- [M][Hub][#269] add_dashboard endpoint fails after first request (Flask blueprint lifecycle).
1413
- [M/L][Components][#262] add filters for random transaction selection in whatif tab.
1514
- [S][Methods][#220] get_contrib_df accepts list/array input.

explainerdashboard/explainer_methods.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,8 @@ def _scorer(clf, X, y):
829829
warnings.filterwarnings("ignore", category=UserWarning)
830830
y_pred = clf.predict_proba(X)
831831
warnings.filterwarnings("default", category=UserWarning)
832+
y_pred = _ensure_numeric_predictions(y_pred)
833+
y_pred = np.asarray(y_pred)
832834
score = sign * partial_metric(y, y_pred)
833835
return score
834836

explainerdashboard/explainers.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -914,15 +914,19 @@ def get_col_value_plus_prediction(
914914
if self.is_classifier:
915915
if pos_label is None:
916916
pos_label = self.pos_label
917-
pred_probas_raw = self.model.predict_proba(model_input)[0]
917+
pred_probas_raw = self.model.predict_proba(model_input)
918918
pred_probas_raw = _ensure_numeric_predictions(pred_probas_raw)
919-
prediction = np.asarray(pred_probas_raw)[pos_label].squeeze()
919+
pred_probas = np.asarray(pred_probas_raw).squeeze()
920+
if pred_probas.ndim > 1:
921+
pred_probas = pred_probas[0]
922+
prediction = pred_probas[pos_label].squeeze()
920923
if self.model_output == "probability":
921924
prediction = 100 * prediction
922925
elif self.is_regression:
923-
pred_raw = self.model.predict(model_input)[0]
926+
pred_raw = self.model.predict(model_input)
924927
pred_raw = _ensure_numeric_predictions(pred_raw)
925-
prediction = np.asarray(pred_raw).squeeze()
928+
pred_array = np.asarray(pred_raw).squeeze()
929+
prediction = pred_array.flat[0]
926930
return col_value, prediction
927931
else:
928932
raise ValueError("You need to pass either index or X_row!")
@@ -3793,9 +3797,11 @@ def prediction_result_df(
37933797
X_row = X_cats_to_X(X_row, self.onehot_dict, self.X.columns)
37943798
if self.shap == "skorch":
37953799
X_row = X_row.values.astype("float32")
3796-
pred_probas_raw = self.model.predict_proba(X_row)[0, :]
3800+
pred_probas_raw = self.model.predict_proba(X_row)
37973801
pred_probas_raw = _ensure_numeric_predictions(pred_probas_raw)
37983802
pred_probas = np.asarray(pred_probas_raw).squeeze()
3803+
if pred_probas.ndim > 1:
3804+
pred_probas = pred_probas[0]
37993805

38003806
preds_df = pd.DataFrame(dict(label=self.labels, probability=pred_probas))
38013807
if logodds and all(preds_df.probability < 1 - np.finfo(np.float64).eps):

tests/test_classifier_base.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,21 @@
1414
from explainerdashboard.explainer_methods import IndexNotFoundError
1515

1616

17+
class DataFramePredictProbaWrapper:
18+
def __init__(self, model):
19+
self.model = model
20+
self.classes_ = model.classes_
21+
22+
def predict(self, X):
23+
return self.model.predict(X)
24+
25+
def predict_proba(self, X):
26+
probas = self.model.predict_proba(X)
27+
return pd.DataFrame(
28+
probas, columns=self.classes_, index=getattr(X, "index", None)
29+
)
30+
31+
1732
def test_explainer_with_dataframe_y(fitted_rf_classifier_model, classifier_data):
1833
_, _, X_test, y_test = classifier_data
1934
explainer = ClassifierExplainer(
@@ -332,6 +347,20 @@ def test_prediction_result_df(precalculated_rf_classifier_explainer):
332347
assert isinstance(df, pd.DataFrame)
333348

334349

350+
def test_prediction_result_df_with_dataframe_predict_proba(
351+
fitted_rf_classifier_model, classifier_data
352+
):
353+
_, _, X_test, y_test = classifier_data
354+
wrapped_model = DataFramePredictProbaWrapper(fitted_rf_classifier_model)
355+
explainer = ClassifierExplainer(wrapped_model, X_test.head(50), y_test.head(50))
356+
357+
df = explainer.prediction_result_df(0)
358+
_, prediction = explainer.get_col_value_plus_prediction("Age", index=0)
359+
360+
assert isinstance(df, pd.DataFrame)
361+
assert np.isscalar(prediction)
362+
363+
335364
def test_pdp_df(precalculated_rf_classifier_explainer):
336365
assert isinstance(precalculated_rf_classifier_explainer.pdp_df("Age"), pd.DataFrame)
337366
assert isinstance(

tests/test_dtype_alignment.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
import pandas as pd
33

44
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
5+
from sklearn.metrics import roc_auc_score
56
from sklearn.metrics import r2_score
67

78
from explainerdashboard.explainer_methods import (
89
align_categorical_dtypes,
910
get_pdp_df,
11+
make_one_vs_all_scorer,
1012
permutation_importances,
1113
)
1214

@@ -37,6 +39,12 @@ def predict_proba(self, X):
3739
return np.tile(np.array([0.2, 0.8]), (len(X), 1))
3840

3941

42+
class DataFrameProbaClassifier(ClassifierMixin, BaseEstimator):
43+
def predict_proba(self, X):
44+
probs = np.tile(np.array([0.2, 0.8]), (len(X), 1))
45+
return pd.DataFrame(probs, index=getattr(X, "index", None), columns=[0, 1])
46+
47+
4048
def test_permutation_importances_preserves_categorical_dtypes():
4149
X = pd.DataFrame(
4250
{
@@ -80,3 +88,14 @@ def test_align_categorical_dtypes_matches_reference():
8088
aligned = align_categorical_dtypes(target, reference)
8189

8290
assert aligned["cat"].dtype == reference["cat"].dtype
91+
92+
93+
def test_make_one_vs_all_scorer_accepts_dataframe_predict_proba():
94+
scorer = make_one_vs_all_scorer(roc_auc_score, pos_label=1)
95+
clf = DataFrameProbaClassifier()
96+
X = pd.DataFrame({"feature": [0, 1, 2, 3]})
97+
y = np.array([0, 1, 0, 1])
98+
99+
score = scorer(clf, X, y)
100+
101+
assert isinstance(score, float)

tests/test_regression_base.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,17 @@
44

55
import plotly.graph_objects as go
66

7+
from explainerdashboard import RegressionExplainer
8+
9+
10+
class DataFramePredictWrapper:
11+
def __init__(self, model):
12+
self.model = model
13+
14+
def predict(self, X):
15+
preds = self.model.predict(X)
16+
return pd.DataFrame({"prediction": preds}, index=getattr(X, "index", None))
17+
718

819
def test_explainer_len(precalculated_rf_regression_explainer, testlen):
920
assert len(precalculated_rf_regression_explainer) == testlen
@@ -55,6 +66,18 @@ def test_prediction_result_df(precalculated_rf_regression_explainer):
5566
assert isinstance(df, pd.DataFrame)
5667

5768

69+
def test_get_col_value_plus_prediction_with_dataframe_predict(
70+
fitted_rf_regression_model, regression_data
71+
):
72+
_, _, X_test, y_test = regression_data
73+
wrapped_model = DataFramePredictWrapper(fitted_rf_regression_model)
74+
explainer = RegressionExplainer(wrapped_model, X_test.head(50), y_test.head(50))
75+
76+
_, prediction = explainer.get_col_value_plus_prediction("Age", index=0)
77+
78+
assert np.isscalar(prediction)
79+
80+
5881
def test_preds(precalculated_rf_regression_explainer):
5982
assert isinstance(precalculated_rf_regression_explainer.preds, np.ndarray)
6083

0 commit comments

Comments
 (0)