diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 3386f66..05361b7 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -1,6 +1,11 @@ # Release Notes +## Version 0.5.7: + +### Bug Fixes +- Allow FeatureInputComponent (what-if inputs) to customize numeric ranges and rounding, and apply min/max/step to inputs. + ## Version 0.5.6: ### Bug Fixes diff --git a/TODO.md b/TODO.md index 8328f8c..05e6fc8 100644 --- a/TODO.md +++ b/TODO.md @@ -6,10 +6,9 @@ - Rules: link an issue when possible; include size S/M/L; mark blockers. **Now** -- [S/M][Components][#277] whatif input range/rounding customization. +- [M][Explainers][#273] categorical columns with NaNs: sorting and column preservation. **Next** -- [M][Explainers][#273] categorical columns with NaNs: sorting and column preservation. - [S][Explainers][#270] Autogluon integration (coerce predict_proba to ndarray). - [M][Hub][#269] add_dashboard endpoint fails after first request (Flask blueprint lifecycle). - [M/L][Components][#262] add filters for random transaction selection in whatif tab. diff --git a/explainerdashboard/__init__.py b/explainerdashboard/__init__.py index 0204c0a..85e2349 100644 --- a/explainerdashboard/__init__.py +++ b/explainerdashboard/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.5.6" +__version__ = "0.5.7" import logging import sys diff --git a/explainerdashboard/dashboard_components/overview_components.py b/explainerdashboard/dashboard_components/overview_components.py index 1692f86..238cc6a 100644 --- a/explainerdashboard/dashboard_components/overview_components.py +++ b/explainerdashboard/dashboard_components/overview_components.py @@ -8,6 +8,7 @@ from math import ceil import numpy as np +import pandas as pd from pandas.api.types import is_bool_dtype from dash import html, dcc, Input, Output @@ -1091,6 +1092,8 @@ def __init__( n_input_cols=4, sort_features="shap", fill_row_first=True, + feature_input_ranges=None, + round=2, description=None, **kwargs, ): @@ -1116,6 +1119,9 @@ def __init__( is 'shap' to sort by mean absolute shap value. fill_row_first (bool, optional): if True most important features will be on top row, if False they will be in most left column. + feature_input_ranges (dict, optional): dict mapping feature names to + (min, max) numeric ranges for input fields. + round (int, optional): number of decimals to round numeric ranges to. description (str, optional): Tooltip to display when hover over component title. When None default text is shown. @@ -1131,6 +1137,8 @@ def __init__( explainer, name="feature-input-index-" + self.name, **kwargs ) self.index_name = "feature-input-index-" + self.name + self.feature_input_ranges = feature_input_ranges or {} + self.round = round self._feature_callback_inputs = [ Input("feature-input-" + feature + "-input-" + self.name, "value") @@ -1214,17 +1222,28 @@ def _generate_dash_input(self, col, onehot_cols, onehot_dict, cat_dict): ) else: col_values = self.explainer.X[col][lambda x: x != self.explainer.na_fill] - if is_bool_dtype(col_values): + if col in self.feature_input_ranges: + min_range, max_range = self.feature_input_ranges[col] + elif is_bool_dtype(col_values): min_range = int(col_values.min()) max_range = int(col_values.max()) else: - min_range = np.round(col_values.min(), 2) - max_range = np.round(col_values.max(), 2) + min_range = np.round(col_values.min(), self.round) + max_range = np.round(col_values.max(), self.round) + + if is_bool_dtype(col_values) or pd.api.types.is_integer_dtype(col_values): + step = 1 + else: + step = 10 ** (-self.round) return html.Div( [ dbc.Label(col), dbc.Input( - id="feature-input-" + col + "-input-" + self.name, type="number" + id="feature-input-" + col + "-input-" + self.name, + type="number", + min=min_range, + max=max_range, + step=step, ), dbc.FormText(f"Range: {min_range}-{max_range}") if not self.hide_range diff --git a/pyproject.toml b/pyproject.toml index f1042b9..9751d12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "explainerdashboard" -version = "0.5.6" +version = "0.5.7" description = "Quickly build Explainable AI dashboards that show the inner workings of so-called \"blackbox\" machine learning models." readme = "README.md" requires-python = ">=3.10" diff --git a/tests/test_feature_input_component.py b/tests/test_feature_input_component.py index d8b4d16..f113112 100644 --- a/tests/test_feature_input_component.py +++ b/tests/test_feature_input_component.py @@ -24,3 +24,30 @@ def test_feature_input_component_handles_bool_columns(classifier_data): layout = component.layout() assert layout is not None + + +def test_feature_input_component_respects_custom_range_and_rounding(classifier_data): + X_train, y_train, X_test, y_test = classifier_data + + model = RandomForestClassifier(n_estimators=5, max_depth=2) + model.fit(X_train, y_train) + + explainer = ClassifierExplainer(model, X_test, y_test) + component = FeatureInputComponent( + explainer, feature_input_ranges={"Age": (0, 50)}, round=1 + ) + + age_div = next( + div + for div in component._feature_inputs + if getattr(div.children[0], "children", None) == "Age" + ) + age_input = age_div.children[1] + range_text = age_div.children[2].children + + props = age_input.to_plotly_json()["props"] + + assert props.get("min") == 0 + assert props.get("max") == 50 + assert props.get("step") == 0.1 + assert range_text == "Range: 0-50"