@@ -281,3 +281,72 @@ def test_proximity_distance_against_qgis(raster, qgis_proximity_distance_target_
281281
282282 general_output_checks (input_raster , xrspatial_result )
283283 np .testing .assert_allclose (xrspatial_result .data , qgis_result .data , rtol = 1e-05 , equal_nan = True )
284+
285+
@pytest.mark.skipif(da is None, reason="dask is not installed")
def test_proximity_dask_coord_arrays_are_lazy():
    """
    Regression test: no full-size numpy coordinate grids for dask input.

    A 100 x 120 dask-backed raster with 'lat'/'lon' coordinates is passed to
    ``proximity`` while ``np.tile`` and ``np.repeat`` are replaced by
    wrappers. Each wrapper delegates to the real function and records the
    call when the result holds at least ``height * width`` elements. The
    test fails if any such call was recorded during the ``proximity`` call.

    Afterwards the result is checked to be dask-backed, and a few computed
    values are spot-checked: zero at the target pixels, positive at a
    non-target corner.
    """
    from unittest import mock

    n_rows, n_cols = 100, 120
    full_grid_size = n_rows * n_cols

    # Zero background with three target pixels valued 1.0, 2.0 and 3.0.
    target_cells = ((10, 10), (50, 60), (90, 100))
    values = np.zeros((n_rows, n_cols), dtype=np.float64)
    for target_value, cell in enumerate(target_cells, start=1):
        values[cell] = float(target_value)

    raster = xr.DataArray(values, dims=['lat', 'lon'])
    raster['lon'] = np.linspace(-180, 180, n_cols)
    raster['lat'] = np.linspace(90, -90, n_rows)
    # Swap in a chunked dask array so proximity sees dask-backed data.
    raster.data = da.from_array(values, chunks=(25, 30))

    # Keep handles on the real functions; the wrappers delegate to them.
    real_tile = np.tile
    real_repeat = np.repeat
    oversized_calls = []

    def record_if_full_grid(label, array):
        # An array with as many elements as the raster counts as a
        # materialised coordinate grid.
        if array.size >= full_grid_size:
            oversized_calls.append((label, array.shape))
        return array

    def tracking_tile(A, reps):
        return record_if_full_grid('tile', real_tile(A, reps))

    def tracking_repeat(a, repeats, axis=None):
        return record_if_full_grid(
            'repeat', real_repeat(a, repeats, axis=axis)
        )

    with mock.patch.object(np, 'tile', tracking_tile), \
            mock.patch.object(np, 'repeat', tracking_repeat):
        result = proximity(raster, x='lon', y='lat')

    # Nothing grid-sized may have come out of np.tile / np.repeat.
    assert not oversized_calls, (
        f"Large numpy arrays were created for coordinates: {oversized_calls}. "
        "For dask inputs, coordinate arrays should be created using dask operations."
    )

    # The output must still be dask-backed.
    assert isinstance(result.data, da.Array), "Result should be a dask array"

    # Spot-check values after computing the result.
    computed = result.compute()
    for cell in target_cells:
        # Target pixels are at distance zero.
        assert computed.data[cell] == 0.0
    # A non-target pixel has a positive distance.
    assert computed.data[0, 0] > 0.0
0 commit comments