Skip to content

Commit 0bd16b0

Browse files
committed
added test for dask array case in proximity
1 parent a27cb7c commit 0bd16b0

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

xrspatial/tests/test_proximity.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,3 +281,72 @@ def test_proximity_distance_against_qgis(raster, qgis_proximity_distance_target_
281281

282282
general_output_checks(input_raster, xrspatial_result)
283283
np.testing.assert_allclose(xrspatial_result.data, qgis_result.data, rtol=1e-05, equal_nan=True)
284+
285+
286+
@pytest.mark.skipif(da is None, reason="dask is not installed")
287+
def test_proximity_dask_coord_arrays_are_lazy():
288+
"""
289+
Test that coordinate arrays (xs, ys) are created as dask arrays
290+
when input is a dask array, avoiding memory issues with large rasters.
291+
292+
This is a regression test for the issue where xs and ys were created
293+
as numpy arrays before checking if the input was a dask array,
294+
causing memory issues for large datasets.
295+
"""
296+
from unittest.mock import patch
297+
298+
height, width = 100, 120
299+
data = np.zeros((height, width), dtype=np.float64)
300+
# Add some target pixels
301+
data[10, 10] = 1.0
302+
data[50, 60] = 2.0
303+
data[90, 100] = 3.0
304+
305+
_lon = np.linspace(-180, 180, width)
306+
_lat = np.linspace(90, -90, height)
307+
raster = xr.DataArray(data, dims=['lat', 'lon'])
308+
raster['lon'] = _lon
309+
raster['lat'] = _lat
310+
# Create dask-backed array with chunks
311+
raster.data = da.from_array(data, chunks=(25, 30))
312+
313+
# Track calls to np.tile and np.repeat with the full raster shape
314+
original_tile = np.tile
315+
original_repeat = np.repeat
316+
large_numpy_array_created = []
317+
318+
def tracking_tile(A, reps):
319+
result = original_tile(A, reps)
320+
# Check if result would be the size of the full coordinate grid
321+
if result.size >= height * width:
322+
large_numpy_array_created.append(('tile', result.shape))
323+
return result
324+
325+
def tracking_repeat(a, repeats, axis=None):
326+
result = original_repeat(a, repeats, axis=axis)
327+
# Check if result would be the size of the full coordinate grid
328+
if result.size >= height * width:
329+
large_numpy_array_created.append(('repeat', result.shape))
330+
return result
331+
332+
with patch.object(np, 'tile', tracking_tile):
333+
with patch.object(np, 'repeat', tracking_repeat):
334+
result = proximity(raster, x='lon', y='lat')
335+
336+
# Verify no large numpy coordinate arrays were created
337+
assert len(large_numpy_array_created) == 0, (
338+
f"Large numpy arrays were created for coordinates: {large_numpy_array_created}. "
339+
"For dask inputs, coordinate arrays should be created using dask operations."
340+
)
341+
342+
# Verify result is a dask array
343+
assert isinstance(result.data, da.Array), "Result should be a dask array"
344+
345+
# Verify correctness by computing and checking a few values
346+
computed = result.compute()
347+
# Check that target pixels have distance 0
348+
assert computed.data[10, 10] == 0.0
349+
assert computed.data[50, 60] == 0.0
350+
assert computed.data[90, 100] == 0.0
351+
# Check that non-target pixels have positive distance
352+
assert computed.data[0, 0] > 0.0

0 commit comments

Comments
 (0)