Skip to content

Commit 6aeb994

Browse files
kkollsgaclaude
andcommitted
Fix sortby descending order placing NaNs at beginning instead of end
Use duck_array_ops.notnull as additional sort keys to ensure null values sort to the end in descending order. This is cleaner than the previous approach of manually tracking NaN positions. Fixes #7358 Co-Authored-By: Claude <noreply@anthropic.com>
1 parent e095d9f commit 6aeb994

File tree

3 files changed

+59
-2
lines changed

3 files changed

+59
-2
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ Bug Fixes
3030
Dask chunk boundaries must now align with shard boundaries, not just internal
3131
Zarr chunk boundaries (:issue:`10831`).
3232

33+
- Fix :py:meth:`Dataset.sortby` and :py:meth:`DataArray.sortby` placing NaN values
34+
at the beginning instead of the end when using ``ascending=False`` (:issue:`7358`).
35+
By `Kristian Kollsgård <https://github.com/kkollsga>`_.
3336

3437
Documentation
3538
~~~~~~~~~~~~~

xarray/core/dataset.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8129,8 +8129,21 @@ def sortby(
81298129

81308130
indices = {}
81318131
for key, arrays in vars_by_dim.items():
8132-
order = np.lexsort(tuple(reversed(arrays)))
8133-
indices[key] = order if ascending else order[::-1]
8132+
if ascending:
8133+
indices[key] = np.lexsort(tuple(reversed(arrays)))
8134+
else:
8135+
# For descending order, we need to keep NaNs at the end.
8136+
# By adding notnull(arr) as additional sort keys, null values
8137+
# sort to the beginning (False=0 < True=1), then reversing
8138+
# puts them at the end. See https://github.com/pydata/xarray/issues/7358
8139+
indices[key] = np.lexsort(
8140+
tuple(
8141+
[
8142+
*reversed(arrays),
8143+
*[duck_array_ops.notnull(arr) for arr in reversed(arrays)],
8144+
]
8145+
)
8146+
)[::-1]
81348147
return aligned_self.isel(indices)
81358148

81368149
def quantile(

xarray/tests/test_dataset.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7236,6 +7236,47 @@ def test_sortby(self) -> None:
72367236
actual = ds.sortby(["x", "y"], ascending=False)
72377237
assert_equal(actual, ds)
72387238

7239+
def test_sortby_descending_nans(self) -> None:
7240+
# Regression test for https://github.com/pydata/xarray/issues/7358
7241+
# NaN values should remain at the end when sorting in descending order
7242+
ds = Dataset({"var": ("x", [3.0, np.nan, 4.0, 2.0, np.nan])})
7243+
7244+
# Ascending: NaNs at end
7245+
result_asc = ds.sortby("var", ascending=True)
7246+
assert_array_equal(result_asc["var"].values[:3], [2.0, 3.0, 4.0])
7247+
assert np.all(np.isnan(result_asc["var"].values[3:]))
7248+
7249+
# Descending: NaNs should also be at end (not beginning)
7250+
result_desc = ds.sortby("var", ascending=False)
7251+
assert_array_equal(result_desc["var"].values[:3], [4.0, 3.0, 2.0])
7252+
assert np.all(np.isnan(result_desc["var"].values[3:]))
7253+
7254+
def test_sortby_descending_nans_multi_key(self) -> None:
7255+
# Test sortby with multiple keys where one has NaN values
7256+
# Regression test for https://github.com/pydata/xarray/issues/7358
7257+
ds = Dataset(
7258+
{
7259+
"A": (("x", "y"), [[1, 2, 3], [4, 5, 6]]),
7260+
"B": (("x", "y"), [[7, 8, 9], [10, 11, 12]]),
7261+
},
7262+
coords={"x": ["b", "a"], "y": [np.nan, 1, 0]},
7263+
)
7264+
7265+
# Sort by multiple keys in descending order
7266+
result = ds.sortby(["x", "y"], ascending=False)
7267+
7268+
# x should be sorted descending: ["b", "a"]
7269+
assert_array_equal(result["x"].values, ["b", "a"])
7270+
7271+
# y should be sorted descending with NaN at end: [1, 0, nan]
7272+
assert_array_equal(result["y"].values[:2], [1, 0])
7273+
assert np.isnan(result["y"].values[2])
7274+
7275+
# Verify data is reordered correctly
7276+
# Original y=[nan, 1, 0] -> sorted y=[1, 0, nan] means columns reordered [1, 2, 0]
7277+
assert_array_equal(result["A"].values, [[2, 3, 1], [5, 6, 4]])
7278+
assert_array_equal(result["B"].values, [[8, 9, 7], [11, 12, 10]])
7279+
72397280
def test_attribute_access(self) -> None:
72407281
ds = create_test_data(seed=1)
72417282
for key in ["var1", "var2", "var3", "time", "dim1", "dim2", "dim3", "numbers"]:

0 commit comments

Comments
 (0)