From 6aeb9941138e1031c2d7200f4fac2e2fe183f03d Mon Sep 17 00:00:00 2001 From: kkollsga Date: Sat, 31 Jan 2026 22:21:29 +0100 Subject: [PATCH] Fix sortby descending order placing NaNs at beginning instead of end Use duck_array_ops.notnull as additional sort keys to ensure null values sort to the end in descending order. This is cleaner than the previous approach of manually tracking NaN positions. Fixes #7358 Co-Authored-By: Claude --- doc/whats-new.rst | 3 +++ xarray/core/dataset.py | 17 +++++++++++++-- xarray/tests/test_dataset.py | 41 ++++++++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 83fd0861408..e91c9da6997 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -30,6 +30,9 @@ Bug Fixes Dask chunk boundaries must now align with shard boundaries, not just internal Zarr chunk boundaries (:issue:`10831`). +- Fix :py:meth:`Dataset.sortby` and :py:meth:`DataArray.sortby` placing NaN values + at the beginning instead of the end when using ``ascending=False`` (:issue:`7358`). + By `Kristian KollsgÄrd `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 425a3dc19cb..e22c080f159 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8129,8 +8129,21 @@ def sortby( indices = {} for key, arrays in vars_by_dim.items(): - order = np.lexsort(tuple(reversed(arrays))) - indices[key] = order if ascending else order[::-1] + if ascending: + indices[key] = np.lexsort(tuple(reversed(arrays))) + else: + # For descending order, we need to keep NaNs at the end. + # By adding notnull(arr) as additional sort keys, null values + # sort to the beginning (False=0 < True=1), then reversing + # puts them at the end. See https://github.com/pydata/xarray/issues/7358 + indices[key] = np.lexsort( + tuple( + [ + *reversed(arrays), + *[duck_array_ops.notnull(arr) for arr in reversed(arrays)], + ] + ) + )[::-1] return aligned_self.isel(indices) def quantile( diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 2848ae5a7be..ec68cf1fc02 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -7236,6 +7236,47 @@ def test_sortby(self) -> None: actual = ds.sortby(["x", "y"], ascending=False) assert_equal(actual, ds) + def test_sortby_descending_nans(self) -> None: + # Regression test for https://github.com/pydata/xarray/issues/7358 + # NaN values should remain at the end when sorting in descending order + ds = Dataset({"var": ("x", [3.0, np.nan, 4.0, 2.0, np.nan])}) + + # Ascending: NaNs at end + result_asc = ds.sortby("var", ascending=True) + assert_array_equal(result_asc["var"].values[:3], [2.0, 3.0, 4.0]) + assert np.all(np.isnan(result_asc["var"].values[3:])) + + # Descending: NaNs should also be at end (not beginning) + result_desc = ds.sortby("var", ascending=False) + assert_array_equal(result_desc["var"].values[:3], [4.0, 3.0, 2.0]) + assert np.all(np.isnan(result_desc["var"].values[3:])) + + def test_sortby_descending_nans_multi_key(self) -> None: + # Test sortby with multiple keys where one has NaN values + # Regression test for https://github.com/pydata/xarray/issues/7358 + ds = Dataset( + { + "A": (("x", "y"), [[1, 2, 3], [4, 5, 6]]), + "B": (("x", "y"), [[7, 8, 9], [10, 11, 12]]), + }, + coords={"x": ["b", "a"], "y": [np.nan, 1, 0]}, + ) + + # Sort by multiple keys in descending order + result = ds.sortby(["x", "y"], ascending=False) + + # x should be sorted descending: ["b", "a"] + assert_array_equal(result["x"].values, ["b", "a"]) + + # y should be sorted descending with NaN at end: [1, 0, nan] + assert_array_equal(result["y"].values[:2], [1, 0]) + assert np.isnan(result["y"].values[2]) + + # Verify data is reordered correctly + # Original y=[nan, 1, 0] -> sorted y=[1, 0, nan] means columns reordered [1, 2, 0] + assert_array_equal(result["A"].values, [[2, 3, 1], [5, 6, 4]]) + assert_array_equal(result["B"].values, [[8, 9, 7], [11, 12, 10]]) + def test_attribute_access(self) -> None: ds = create_test_data(seed=1) for key in ["var1", "var2", "var3", "time", "dim1", "dim2", "dim3", "numbers"]: