Skip to content

Commit e095d9f

Browse files
kkollsga and claude authored
Fix silent data corruption when writing dask arrays to sharded zarr stores (#11117)
When writing dask-backed arrays to zarr with sharding enabled, the chunk alignment validation was checking against zarr's internal chunk size instead of the shard size. This allowed configurations where dask chunk boundaries didn't align with shard boundaries, causing data corruption during parallel writes.

This fix uses shard boundaries (when shards are specified) for both grid_rechunk() and validate_grid_chunks_alignment(), ensuring that parallel writes don't cross shard boundaries.

Fixes #10831

Co-authored-by: Claude <noreply@anthropic.com>
1 parent 519e5d1 commit e095d9f

File tree

3 files changed

+49
-4
lines changed

3 files changed

+49
-4
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ Deprecations
2626
Bug Fixes
2727
~~~~~~~~~
2828

29+
- Fix silent data corruption when writing dask arrays to sharded Zarr stores.
30+
Dask chunk boundaries must now align with shard boundaries, not just internal
31+
Zarr chunk boundaries (:issue:`10831`).
32+
2933

3034
Documentation
3135
~~~~~~~~~~~~~

xarray/backends/zarr.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,16 +1240,22 @@ def set_variables(
12401240
zarr_format=3 if is_zarr_v3_format else 2,
12411241
)
12421242

1243-
if self._align_chunks and isinstance(encoding["chunks"], tuple):
1243+
# When shards are specified, dask chunks must align with shard boundaries
1244+
# (not just zarr chunk boundaries) to avoid data corruption during
1245+
# parallel writes. See https://github.com/pydata/xarray/issues/10831
1246+
effective_write_chunks = encoding.get("shards") or encoding["chunks"]
1247+
1248+
if self._align_chunks and isinstance(effective_write_chunks, tuple):
12441249
v = grid_rechunk(
12451250
v=v,
1246-
enc_chunks=encoding["chunks"],
1251+
enc_chunks=effective_write_chunks,
12471252
region=region,
12481253
)
12491254

1250-
if self._safe_chunks and isinstance(encoding["chunks"], tuple):
1255+
if self._safe_chunks and isinstance(effective_write_chunks, tuple):
12511256
# the hard case
12521257
# DESIGN CHOICE: do not allow multiple dask chunks on a single zarr chunk
1258+
# (or shard, when sharding is enabled)
12531259
# this avoids the need to get involved in zarr synchronization / locking
12541260
# From zarr docs:
12551261
# "If each worker in a parallel computation is writing to a
@@ -1260,7 +1266,7 @@ def set_variables(
12601266
shape = zarr_shape or v.shape
12611267
validate_grid_chunks_alignment(
12621268
nd_v_chunks=v.chunks,
1263-
enc_chunks=encoding["chunks"],
1269+
enc_chunks=effective_write_chunks,
12641270
region=region,
12651271
allow_partial_chunks=self._mode != "r+",
12661272
name=name,

xarray/tests/test_backends.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2869,6 +2869,41 @@ def test_shard_encoding(self) -> None:
28692869
with self.roundtrip(data) as actual:
28702870
pass
28712871

2872+
@requires_dask
2873+
def test_shard_encoding_with_dask(self) -> None:
2874+
# Test that dask chunks must align with shard boundaries.
2875+
# See https://github.com/pydata/xarray/issues/10831
2876+
if not (has_zarr_v3 and zarr.config.config["default_zarr_format"] == 3):
2877+
pytest.skip("sharding requires zarr v3 format")
2878+
2879+
ds = xr.DataArray(np.arange(12), dims="x", name="var1").to_dataset()
2880+
2881+
# Case 1: Dask chunks equal to shards should work
2882+
# (zarr chunk=3, shard=6, dask chunk=6)
2883+
ds1 = ds.chunk({"x": 6})
2884+
ds1["var1"].encoding = {"chunks": (3,), "shards": (6,)}
2885+
with self.roundtrip(ds1) as actual:
2886+
assert_identical(ds, actual)
2887+
2888+
# Case 2: Dask chunks that are multiples of shards should work
2889+
# (zarr chunk=1, shard=3, dask chunk=6)
2890+
ds2 = ds.chunk({"x": 6})
2891+
ds2["var1"].encoding = {"chunks": (1,), "shards": (3,)}
2892+
with self.roundtrip(ds2) as actual:
2893+
assert_identical(ds, actual)
2894+
2895+
# Case 3: Dask chunks smaller than shards should fail
2896+
# (zarr chunk=2, shard=4, dask chunk=3) - dask chunk doesn't align with shard
2897+
ds3 = ds.chunk({"x": 3})
2898+
ds3["var1"].encoding = {"chunks": (2,), "shards": (4,)}
2899+
with pytest.raises(ValueError, match=r"would overlap"):
2900+
with self.roundtrip(ds3) as actual:
2901+
pass
2902+
2903+
# Case 4: Can bypass with safe_chunks=False (but data may be corrupted)
2904+
with self.roundtrip(ds3, save_kwargs={"safe_chunks": False}) as actual:
2905+
pass
2906+
28722907
@requires_dask
28732908
@pytest.mark.skipif(
28742909
ON_WINDOWS,

0 commit comments

Comments
 (0)