|
2 | 2 | import pandas as pd |
3 | 3 |
|
4 | 4 | from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin |
| 5 | +from sklearn.metrics import roc_auc_score |
5 | 6 | from sklearn.metrics import r2_score |
6 | 7 |
|
7 | 8 | from explainerdashboard.explainer_methods import ( |
8 | 9 | align_categorical_dtypes, |
9 | 10 | get_pdp_df, |
| 11 | + make_one_vs_all_scorer, |
10 | 12 | permutation_importances, |
11 | 13 | ) |
12 | 14 |
|
@@ -37,6 +39,12 @@ def predict_proba(self, X): |
37 | 39 | return np.tile(np.array([0.2, 0.8]), (len(X), 1)) |
38 | 40 |
|
39 | 41 |
|
| 42 | +class DataFrameProbaClassifier(ClassifierMixin, BaseEstimator): |
| 43 | + def predict_proba(self, X): |
| 44 | + probs = np.tile(np.array([0.2, 0.8]), (len(X), 1)) |
| 45 | + return pd.DataFrame(probs, index=getattr(X, "index", None), columns=[0, 1]) |
| 46 | + |
| 47 | + |
40 | 48 | def test_permutation_importances_preserves_categorical_dtypes(): |
41 | 49 | X = pd.DataFrame( |
42 | 50 | { |
@@ -80,3 +88,14 @@ def test_align_categorical_dtypes_matches_reference(): |
80 | 88 | aligned = align_categorical_dtypes(target, reference) |
81 | 89 |
|
82 | 90 | assert aligned["cat"].dtype == reference["cat"].dtype |
| 91 | + |
| 92 | + |
| 93 | +def test_make_one_vs_all_scorer_accepts_dataframe_predict_proba(): |
| 94 | + scorer = make_one_vs_all_scorer(roc_auc_score, pos_label=1) |
| 95 | + clf = DataFrameProbaClassifier() |
| 96 | + X = pd.DataFrame({"feature": [0, 1, 2, 3]}) |
| 97 | + y = np.array([0, 1, 0, 1]) |
| 98 | + |
| 99 | + score = scorer(clf, X, y) |
| 100 | + |
| 101 | + assert isinstance(score, float) |
0 commit comments