Skip to content

Commit c429fc0

Browse files
strengthen xlim and max_num_features tests
1 parent 51c13ce commit c429fc0

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

tests/python_package_test/test_plotting.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,18 @@ def test_plot_importance(params, breast_cancer_split, train_data):
122122
lgb.plot_importance(gbm0, title=None, xlabel=None, ylabel=None, figsize="not a tuple")
123123

124124
# test max_num_features parameter
125+
total_features = len(gbm0.feature_importance())
126+
assert total_features > 5, "model must have more than 5 features to test max_num_features"
125127
ax7 = lgb.plot_importance(gbm0, max_num_features=5)
126128
assert isinstance(ax7, matplotlib.axes.Axes)
127129
assert len(ax7.patches) == 5
130+
# verify the 5 displayed features are the top 5 by importance
131+
importance = gbm0.feature_importance()
132+
feature_names = gbm0.feature_name()
133+
sorted_pairs = sorted(zip(feature_names, importance), key=lambda x: x[1])
134+
top5_names = [name for name, _ in sorted_pairs[-5:]]
135+
displayed_labels = [label.get_text() for label in ax7.get_yticklabels()]
136+
assert displayed_labels == top5_names
128137

129138
gbm2 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, verbose=-1, importance_type="gain")
130139
gbm2.fit(X_train, y_train)
@@ -193,10 +202,16 @@ def test_plot_split_value_histogram(params, breast_cancer_split, train_data):
193202
assert ax2.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # b
194203

195204
# test xlim parameter
196-
ax3 = lgb.plot_split_value_histogram(gbm0, 27, xlim=(0, 100))
205+
ax3 = lgb.plot_split_value_histogram(gbm0, 27, xlim=(0, 100), title=None, xlabel=None, ylabel=None)
197206
assert isinstance(ax3, matplotlib.axes.Axes)
207+
assert ax3.get_title() == ""
208+
assert ax3.get_xlabel() == ""
209+
assert ax3.get_ylabel() == ""
198210
assert ax3.get_xlim() == (0, 100)
199211

212+
with pytest.raises(TypeError, match="xlim must be a tuple of 2 elements."):
213+
lgb.plot_split_value_histogram(gbm0, 27, xlim="not a tuple")
214+
200215
with pytest.raises(
201216
ValueError, match="Cannot plot split value histogram, because feature 0 was not used in splitting"
202217
):
@@ -569,6 +584,12 @@ def test_plot_metrics(params, breast_cancer_split, train_data):
569584
assert legend_items[0].get_text() == "valid_0"
570585

571586
# test xlim parameter
572-
ax5 = lgb.plot_metric(evals_result0, metric="binary_logloss", xlim=(0, 15))
587+
ax5 = lgb.plot_metric(evals_result0, metric="binary_logloss", xlim=(0, 15), title=None, xlabel=None, ylabel=None)
573588
assert isinstance(ax5, matplotlib.axes.Axes)
589+
assert ax5.get_title() == ""
590+
assert ax5.get_xlabel() == ""
591+
assert ax5.get_ylabel() == ""
574592
assert ax5.get_xlim() == (0, 15)
593+
594+
with pytest.raises(TypeError, match="xlim must be a tuple of 2 elements."):
595+
lgb.plot_metric(evals_result0, metric="binary_logloss", xlim="not a tuple")

0 commit comments

Comments
 (0)