Skip to content

Commit 364361f

Browse files
iampelle and 0verhead authored
Release 4.0.2 (#159)
* Fixed categorical_order_by used with array_like (#157)
* Fix categorical_order_by check for scatter plot
* Fix categorical_order_by check for _construct_source
* Refactor category sorting in _construct_source
* Add tests for categorical_order_by
* Correct scatter test (#158)
* Update version in init
* Update HISTORY.rst
---------
Co-authored-by: Quoc Duong Bui <35042166+vanHekthor@users.noreply.github.com>
1 parent ca828ab commit 364361f

File tree

4 files changed

+127
-36
lines changed

4 files changed

+127
-36
lines changed

HISTORY.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@
22
History
33
=======
44

5+
4.0.2 (2023-03-30)
6+
------------------
7+
8+
* Fix categorical_order_by check for scatter plot
9+
* Fix categorical_order_by check for _construct_source
10+
* Refactor category sorting in _construct_source
11+
* Add tests for categorical_order_by
12+
* Fix scatter plot tests that used line plots
13+
514
4.0.1 (2023-03-24)
615
------------------
716

chartify/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
__author__ = """Chris Halpert"""
2525
__email__ = "chalpert@spotify.com"
26-
__version__ = "4.0.1"
26+
__version__ = "4.0.2"
2727

2828
_IPYTHON_INSTANCE = False
2929

chartify/_core/plot.py

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,49 @@ def _get_bar_width(factors):
947947
else:
948948
return 0.9
949949

950+
@staticmethod
951+
def _sort_categories_by_value(source, categorical_columns, categorical_order_ascending):
952+
# Recursively sort values within each level of the index.
953+
row_totals = source.sum(axis=1, numeric_only=True)
954+
row_totals.name = "sum"
955+
old_index = row_totals.index
956+
row_totals = row_totals.reset_index()
957+
row_totals.columns = ["_%s" % col for col in row_totals.columns]
958+
row_totals.index = old_index
959+
960+
hierarchical_sort_cols = categorical_columns[:]
961+
for i, _ in enumerate(hierarchical_sort_cols):
962+
row_totals["level_%s" % i] = row_totals.groupby(hierarchical_sort_cols[: i + 1])["_sum"].transform(
963+
"sum"
964+
)
965+
row_totals = row_totals.sort_values(
966+
by=["level_%s" % i for i, _ in enumerate(hierarchical_sort_cols)],
967+
ascending=categorical_order_ascending,
968+
)
969+
return source.reindex(row_totals.index)
970+
971+
@staticmethod
972+
def _sort_categories(
973+
source,
974+
categorical_columns,
975+
categorical_order_by,
976+
categorical_order_ascending
977+
):
978+
979+
is_string = isinstance(categorical_order_by, str)
980+
order_length = getattr(categorical_order_by, "__len__", None)
981+
# Sort the categories
982+
if is_string and categorical_order_by == "values":
983+
return PlotMixedTypeXY._sort_categories_by_value(
984+
source, categorical_columns, categorical_order_ascending)
985+
elif is_string and categorical_order_by == "labels":
986+
return source.sort_index(axis=0, ascending=categorical_order_ascending)
987+
# Manual sort
988+
elif not is_string and order_length is not None:
989+
return source.reindex(categorical_order_by, axis="index")
990+
991+
raise ValueError("""Must be 'values', 'labels', or a list of values.""")
992+
950993
def _construct_source(
951994
self,
952995
data_frame,
@@ -1014,34 +1057,7 @@ def _construct_source(
10141057
if normalize:
10151058
source = source.div(source.sum(axis=1), axis=0)
10161059

1017-
order_length = getattr(categorical_order_by, "__len__", None)
1018-
# Sort the categories
1019-
if categorical_order_by == "values":
1020-
# Recursively sort values within each level of the index.
1021-
row_totals = source.sum(axis=1, numeric_only=True)
1022-
row_totals.name = "sum"
1023-
old_index = row_totals.index
1024-
row_totals = row_totals.reset_index()
1025-
row_totals.columns = ["_%s" % col for col in row_totals.columns]
1026-
row_totals.index = old_index
1027-
1028-
heirarchical_sort_cols = categorical_columns[:]
1029-
for i, _ in enumerate(heirarchical_sort_cols):
1030-
row_totals["level_%s" % i] = row_totals.groupby(heirarchical_sort_cols[: i + 1])["_sum"].transform(
1031-
"sum"
1032-
)
1033-
row_totals = row_totals.sort_values(
1034-
by=["level_%s" % i for i, _ in enumerate(heirarchical_sort_cols)],
1035-
ascending=categorical_order_ascending,
1036-
)
1037-
source = source.reindex(row_totals.index)
1038-
elif categorical_order_by == "labels":
1039-
source = source.sort_index(axis=0, ascending=categorical_order_ascending)
1040-
# Manual sort
1041-
elif order_length is not None:
1042-
source = source.reindex(categorical_order_by, axis="index")
1043-
else:
1044-
raise ValueError("""Must be 'values', 'labels', or a list of values.""")
1060+
source = self._sort_categories(source, categorical_columns, categorical_order_by, categorical_order_ascending)
10451061

10461062
# Cast all categorical columns to strings
10471063
# Plotting functions will break with non-str types.
@@ -2003,13 +2019,14 @@ def scatter(
20032019

20042020
axis_factors = data_frame.groupby(categorical_columns).size()
20052021

2022+
is_string = isinstance(categorical_order_by, str)
20062023
order_length = getattr(categorical_order_by, "__len__", None)
2007-
if categorical_order_by == "labels":
2024+
if is_string and categorical_order_by == "labels":
20082025
axis_factors = axis_factors.sort_index(ascending=categorical_order_ascending).index
2009-
elif categorical_order_by == "count":
2026+
elif is_string and categorical_order_by == "count":
20102027
axis_factors = axis_factors.sort_values(ascending=categorical_order_ascending).index
20112028
# User-specified order.
2012-
elif order_length is not None:
2029+
elif not is_string and order_length is not None:
20132030
axis_factors = categorical_order_by
20142031
else:
20152032
raise ValueError("""Must be 'count', 'labels', or a list of values.""")

tests/test_plots.py

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,15 @@ def setup_method(self):
130130
})
131131

132132
def test_single_numeric_scatter(self):
133-
"""Single line test"""
133+
"""Single scatter test"""
134134
single_scatter = self.data[self.data['category1'] == 'a']
135135
ch = chartify.Chart()
136-
ch.plot.line(single_scatter, x_column='number1', y_column='number2')
136+
ch.plot.scatter(single_scatter, x_column='number1', y_column='number2')
137137
assert (np.array_equal(chart_data(ch, '')['number1'], [1., 2., 3.]))
138138
assert (np.array_equal(chart_data(ch, '')['number2'], [5, 10, 0]))
139139

140140
def test_multi_numeric_scatter(self):
141-
"""Single line test"""
141+
"""Multi scatter test"""
142142
ch = chartify.Chart()
143143
ch.plot.scatter(
144144
self.data,
@@ -151,7 +151,7 @@ def test_multi_numeric_scatter(self):
151151
assert (np.array_equal(chart_data(ch, 'b')['number2'], [4, -3, -10]))
152152

153153
def test_single_datetime_scatter(self):
154-
"""Single line test"""
154+
"""Single datetime scatter test"""
155155
data = pd.DataFrame({
156156
'number': [1, 10, -10, 0],
157157
'datetimes':
@@ -794,6 +794,71 @@ def test_grouped_histogram(self):
794794
assert (np.array_equal(chart_data(ch, 'b')['min_edge'], [2., 6.]))
795795

796796

797+
class TestCategoricalOrderBy:
798+
def _assert_order_by_array_like(self, chart):
799+
assert (np.array_equal(chart.figure.x_range.factors, ['b', 'd', 'a', 'c']))
800+
# check bar data
801+
assert (np.array_equal(chart_data(chart, '')['factors'], ['b', 'd', 'a', 'c']))
802+
assert (np.array_equal(chart_data(chart, '')['number1'], [3, 1, 4, 2]))
803+
# check scatter data
804+
assert (np.array_equal(chart_data(chart, 'number1')['factors'], ['a', 'b', 'c', 'd']))
805+
assert (np.array_equal(chart_data(chart, 'number1')['number1'], [4, 3, 2, 1]))
806+
807+
def setup_method(self):
808+
self.data1 = pd.DataFrame({
809+
'category1': ['a', 'b', 'c', 'd'],
810+
'number1': [4, 3, 2, 1],
811+
})
812+
813+
self.data2 = pd.DataFrame({
814+
'category2': ['b', 'a', 'b', 'b', 'a', 'c'],
815+
'number2': [1, 2, 3, 4, 5, 6]
816+
})
817+
818+
def test_order_by_labels(self):
819+
ch = chartify.Chart(x_axis_type='categorical')
820+
821+
ch.plot.bar(self.data1, ['category1'], 'number1', categorical_order_by='labels')
822+
assert (np.array_equal(ch.figure.x_range.factors, ['d', 'c', 'b', 'a']))
823+
assert (np.array_equal(chart_data(ch, '')['factors'], ['d', 'c', 'b', 'a']))
824+
assert (np.array_equal(chart_data(ch, '')['number1'], [1, 2, 3, 4]))
825+
826+
ch.plot.scatter(self.data1, ['category1'], 'number1', categorical_order_by='labels')
827+
assert (np.array_equal(ch.figure.x_range.factors, ['d', 'c', 'b', 'a']))
828+
assert (np.array_equal(chart_data(ch, 'number1')['factors'], ['a', 'b', 'c', 'd']))
829+
assert (np.array_equal(chart_data(ch, 'number1')['number1'], [4, 3, 2, 1]))
830+
831+
def test_order_by_values(self):
832+
ch = chartify.Chart(x_axis_type='categorical')
833+
ch.plot.bar(self.data1, ['category1'], 'number1', categorical_order_by='values')
834+
assert (np.array_equal(chart_data(ch, '')['factors'], ['a', 'b', 'c', 'd']))
835+
assert (np.array_equal(chart_data(ch, '')['number1'], [4, 3, 2, 1]))
836+
837+
def test_order_by_count(self):
838+
ch = chartify.Chart(x_axis_type='categorical')
839+
ch.plot.scatter(self.data2, ['category2'], 'number2', categorical_order_by='count')
840+
841+
assert (np.array_equal(ch.figure.x_range.factors, ['b', 'a', 'c']))
842+
assert (np.array_equal(chart_data(ch, 'number2')['factors'], ['b', 'a', 'b', 'b', 'a', 'c']))
843+
assert (np.array_equal(chart_data(ch, 'number2')['number2'], [1, 2, 3, 4, 5, 6]))
844+
845+
@pytest.mark.parametrize(
846+
'array_like', [['b', 'd', 'a', 'c'], np.array(['b', 'd', 'a', 'c']), pd.Series(['b', 'd', 'a', 'c'])])
847+
def test_order_by_array_like(self, array_like):
848+
ch = chartify.Chart(x_axis_type='categorical')
849+
ch.plot.scatter(self.data1, ['category1'], 'number1', categorical_order_by=array_like)
850+
ch.plot.bar(self.data1, ['category1'], 'number1', categorical_order_by=array_like)
851+
852+
self._assert_order_by_array_like(ch)
853+
854+
@pytest.mark.parametrize('plot_method,categorical_order_by', [('bar', 'count'), ('scatter', 'values')])
855+
def test_error(self, plot_method, categorical_order_by):
856+
ch = chartify.Chart(x_axis_type='categorical', y_axis_type='linear')
857+
with pytest.raises(ValueError):
858+
plot_method = getattr(ch.plot, plot_method)
859+
plot_method(self.data1, ['category1'], 'number1', categorical_order_by=categorical_order_by)
860+
861+
797862
def test_categorical_axis_type_casting():
798863
"""Categorical axis plotting breaks for non-str types.
799864
Test that type casting is performed correctly"""

0 commit comments

Comments
 (0)