
Commit 803a66e

Fix CI (dask pinning, test with rasterize, mypy) (#809)
passing tests and pre-commits
1 parent: 02bc276

File tree: 13 files changed (+108, −50 lines)


pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ dependencies = [
     "anndata>=0.9.1",
     "click",
     "dask-image",
-    "dask>=2024.4.1",
+    "dask>=2024.4.1,<=2024.11.2",
     "fsspec",
     "geopandas>=0.14",
     "multiscale_spatial_image>=2.0.2",

src/spatialdata/_core/operations/rasterize.py

Lines changed: 18 additions & 5 deletions

@@ -217,9 +217,9 @@ def rasterize(
         The table optionally containing the `value_key` and the name of the table in the returned `SpatialData` object.
         Must be `None` when `data` is a `SpatialData` object, otherwise it assumes the default value of `'table'`.
     return_regions_as_labels
-        By default, single-scale images of shape `(c, y, x)` are returned. If `True`, returns labels and shapes as
-        labels of shape `(y, x)` as opposed to an image of shape `(c, y, x)`. Points and images are always returned
-        as images, and multiscale raster data is always returned as single-scale data.
+        By default, single-scale images of shape `(c, y, x)` are returned. If `True`, returns labels, shapes and points
+        as labels of shape `(y, x)` as opposed to an image of shape `(c, y, x)`. Images are always returned as images,
+        and multiscale raster data is always returned as single-scale data.
     agg_func
         Available only when rasterizing points and shapes. A reduction function from datashader (its name, or a
         `Callable`). See the notes for more details on the default behavior.
@@ -234,6 +234,11 @@ def rasterize(
     into a `DataArray` (not a `DataTree`). So if a `SpatialData` object with elements is passed, a `SpatialData` object
     with single-scale images and labels will be returned.

+    When `return_regions_as_labels` is `True`, the returned `DataArray` object will have an attribute called
+    `label_index_to_category` that maps the label index to the category name. You can access it via
+    `returned_data.attrs["label_index_to_category"]`. The returned labels will start from 1 (0 is reserved for the
+    background), and will be contiguous.
+
     Notes
     -----
     For images and labels, the parameters `value_key`, `table_name`, `agg_func`, and `return_single_channel` are not
@@ -587,7 +592,7 @@ def rasterize_images_labels(
     )
     assert isinstance(transformed_dask, DaskArray)
     channels = xdata.coords["c"].values if schema in (Image2DModel, Image3DModel) else None
-    transformed_data = schema.parse(transformed_dask, dims=xdata.dims, c_coords=channels)  # type: ignore[call-arg,arg-type]
+    transformed_data = schema.parse(transformed_dask, dims=xdata.dims, c_coords=channels)  # type: ignore[call-arg]

     if target_coordinate_system != "global":
         remove_transformation(transformed_data, "global")
@@ -650,7 +655,7 @@ def rasterize_shapes_points(
     if value_key is not None:
         kwargs = {"sdata": sdata, "element_name": element_name} if element_name is not None else {"element": data}
         data[VALUES_COLUMN] = get_values(value_key, table_name=table_name, **kwargs).iloc[:, 0]  # type: ignore[arg-type, union-attr]
-    elif isinstance(data, GeoDataFrame):
+    elif isinstance(data, GeoDataFrame) or isinstance(data, DaskDataFrame) and return_regions_as_labels is True:
         value_key = VALUES_COLUMN
         data[VALUES_COLUMN] = data.index.astype("category")
     else:
@@ -706,6 +711,14 @@ def rasterize_shapes_points(
     agg = agg.fillna(0)

     if return_regions_as_labels:
+        if label_index_to_category is not None:
+            max_label = next(iter(reversed(label_index_to_category.keys())))
+        else:
+            max_label = int(agg.max().values)
+        max_uint16 = np.iinfo(np.uint16).max
+        if max_label > max_uint16:
+            raise ValueError(f"Maximum label index is {max_label}. Values higher than {max_uint16} are not supported.")
+        agg = agg.astype(np.uint16)
         return Labels2DModel.parse(agg, transformations=transformations)

     agg = agg.expand_dims(dim={"c": 1}).transpose("c", "y", "x")
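As a usage illustration of the docstring added above, here is a minimal sketch of rasterizing a shapes element to labels and reading back the category mapping. The element name, extents, and pixel resolution are hypothetical placeholders, not values from this commit:

# Minimal sketch, assuming `sdata` is a SpatialData object with a shapes
# element "cells" in the "global" coordinate system; argument values are
# illustrative only.
from spatialdata import rasterize

labels = rasterize(
    sdata["cells"],
    axes=("x", "y"),
    min_coordinate=[0, 0],
    max_coordinate=[1000, 1000],
    target_coordinate_system="global",
    target_unit_to_pixels=1.0,
    return_regions_as_labels=True,
)
# Labels start at 1 (0 is background) and, per the new check, must fit in uint16.
mapping = labels.attrs["label_index_to_category"]  # {label index -> category name}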

src/spatialdata/_core/operations/transform.py

Lines changed: 7 additions & 7 deletions

@@ -52,7 +52,7 @@ def _transform_raster(
     c_shape: tuple[int, ...]
     c_shape = (data.shape[0],) if "c" in axes else ()
     new_spatial_shape = tuple(
-        int(np.max(new_v[:, i]) - np.min(new_v[:, i])) for i in range(len(c_shape), n_spatial_dims + len(c_shape))  # type: ignore[operator]
+        int(np.max(new_v[:, i]) - np.min(new_v[:, i])) for i in range(len(c_shape), n_spatial_dims + len(c_shape))
     )
     output_shape = c_shape + new_spatial_shape
     translation_vector = np.min(new_v[:, :-1], axis=0)
@@ -86,8 +86,8 @@ def _transform_raster(
         # min_y_inverse = np.min(new_v_inverse[:, 1])

         if "c" in axes:
-            plt.imshow(da.moveaxis(transformed_dask, 0, 2), origin="lower", alpha=0.5)  # type: ignore[attr-defined]
-            plt.imshow(da.moveaxis(im, 0, 2), origin="lower", alpha=0.5)  # type: ignore[attr-defined]
+            plt.imshow(da.moveaxis(transformed_dask, 0, 2), origin="lower", alpha=0.5)
+            plt.imshow(da.moveaxis(im, 0, 2), origin="lower", alpha=0.5)
         else:
             plt.imshow(transformed_dask, origin="lower", alpha=0.5)
             plt.imshow(im, origin="lower", alpha=0.5)
@@ -322,7 +322,7 @@ def _(
     )
     c_coords = data.indexes["c"].values if "c" in data.indexes else None
     # mypy thinks that schema could be ShapesModel, PointsModel, ...
-    transformed_data = schema.parse(transformed_dask, dims=axes, c_coords=c_coords)  # type: ignore[call-arg,arg-type]
+    transformed_data = schema.parse(transformed_dask, dims=axes, c_coords=c_coords)  # type: ignore[call-arg]
     assert isinstance(transformed_data, DataArray)
     old_transformations = get_transformation(data, get_all=True)
     assert isinstance(old_transformations, dict)
@@ -448,7 +448,7 @@ def _(
     for ax in axes:
         indices = xtransformed["dim"] == ax
         new_ax = xtransformed[:, indices]
-        transformed[ax] = new_ax.data.flatten()  # type: ignore[attr-defined]
+        transformed[ax] = new_ax.data.flatten()

     old_transformations = get_transformation(data, get_all=True)
     assert isinstance(old_transformations, dict)
@@ -481,9 +481,9 @@ def _(
     )
     # TODO: nitpick, mypy expects a listof literals and here we have a list of strings.
     # I ignored but we may want to fix this
-    affine = transformation.to_affine(axes, axes)  # type: ignore[arg-type]
+    affine = transformation.to_affine(axes, axes)
     matrix = affine.matrix
-    shapely_notation = matrix[:-1, :-1].ravel().tolist() + matrix[:-1, -1].tolist()
+    shapely_notation = matrix[:-1, :-1].ravel().tolist() + matrix[:-1, -1].tolist()  # type: ignore[operator]
    transformed_geometry = data.geometry.affine_transform(shapely_notation)
    transformed_data = data.copy(deep=True)
    transformed_data.attrs[TRANSFORM_KEY] = {DEFAULT_COORDINATE_SYSTEM: Identity()}
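The `shapely_notation` line in the last hunk flattens a homogeneous affine matrix into the 6-element `[a, b, d, e, xoff, yoff]` form that shapely's `affine_transform` expects for 2D geometries. A small standalone sketch of that conversion, with made-up values and independent of spatialdata:

# Standalone sketch of the matrix -> shapely notation conversion (example values only).
import numpy as np
from shapely.affinity import affine_transform
from shapely.geometry import Point

# Homogeneous 2D affine: scale by 2, translate by (10, 20).
matrix = np.array(
    [
        [2.0, 0.0, 10.0],
        [0.0, 2.0, 20.0],
        [0.0, 0.0, 1.0],
    ]
)
# [a, b, d, e] from the linear part, then [xoff, yoff] from the last column.
shapely_notation = matrix[:-1, :-1].ravel().tolist() + matrix[:-1, -1].tolist()
print(affine_transform(Point(1.0, 1.0), shapely_notation))  # POINT (12 22)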

src/spatialdata/_core/query/relational_query.py

Lines changed: 6 additions & 2 deletions

@@ -214,7 +214,7 @@ def _filter_table_by_elements(
     # some instances have not a corresponding row in the table
     instances = np.setdiff1d(instances, n0)
     assert np.sum(to_keep) == len(instances)
-    assert sorted(set(instances.tolist())) == sorted(set(table.obs[instance_key].tolist()))
+    assert sorted(set(instances.tolist())) == sorted(set(table.obs[instance_key].tolist()))  # type: ignore[type-var]
     table_df = pd.DataFrame({instance_key: table.obs[instance_key], "position": np.arange(len(instances))})
     merged = pd.merge(table_df, pd.DataFrame(index=instances), left_on=instance_key, right_index=True, how="right")
     matched_positions = merged["position"].to_numpy()
@@ -467,7 +467,11 @@ def _left_join_spatialelement_table(
            )
            continue

-        joined_indices = joined_indices.dropna() if joined_indices is not None else None
+        if joined_indices is not None:
+            joined_indices = joined_indices.dropna()
+            # if nan were present, the dtype would have been changed to float
+            if joined_indices.dtype == float:
+                joined_indices = joined_indices.astype(int)
         joined_table = table[joined_indices, :].copy() if joined_indices is not None else None
         _inplace_fix_subset_categorical_obs(subset_adata=joined_table, original_adata=table)
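The `astype(int)` added above compensates for a pandas behavior: once NaN enters an integer index or series, pandas upcasts it to float64, and `dropna()` does not restore the integer dtype. A small standalone illustration of that behavior (toy data, not commit code):

# Standalone illustration of the dtype issue handled above.
import numpy as np
import pandas as pd

idx = pd.Index([0, 1, np.nan, 3])   # the NaN forces a float64 index
print(idx.dtype)                    # float64
cleaned = idx.dropna()
print(cleaned.dtype)                # still float64 after dropna
print(cleaned.astype(int).dtype)    # back to an integer dtype, safe for positional indexing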

src/spatialdata/_core/query/spatial_query.py

Lines changed: 2 additions & 2 deletions

@@ -700,8 +700,8 @@ def _(
     bounding_box_mask = _bounding_box_mask_points(
         points=points_query_coordinate_system,
         axes=axes,
-        min_coordinate=min_c,
-        max_coordinate=max_c,
+        min_coordinate=min_c,  # type: ignore[arg-type]
+        max_coordinate=max_c,  # type: ignore[arg-type]
     )
     if len(bounding_box_mask) == 1:
         bounding_box_mask = bounding_box_mask[0]

src/spatialdata/_io/io_points.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 from pathlib import Path

 import zarr
-from dask.dataframe import DataFrame as DaskDataFrame  # type: ignore[attr-defined]
+from dask.dataframe import DataFrame as DaskDataFrame
 from dask.dataframe import read_parquet
 from ome_zarr.format import Format

src/spatialdata/_types.py

Lines changed: 5 additions & 9 deletions

@@ -1,18 +1,14 @@
+from typing import Any
+
 import numpy as np
 from xarray import DataArray, DataTree

 __all__ = ["ArrayLike", "ColorLike", "DTypeLike", "Raster_T"]

-try:
-    from numpy.typing import DTypeLike, NDArray
-
-    ArrayLike = NDArray[np.float64]
-    IntArrayLike = NDArray[np.int64]  # or any np.integer
+from numpy.typing import DTypeLike, NDArray

-except (ImportError, TypeError):
-    ArrayLike = np.ndarray  # type: ignore[misc]
-    IntArrayLike = np.ndarray  # type: ignore[misc]
-    DTypeLike = np.dtype  # type: ignore[misc, assignment]
+ArrayLike = NDArray[np.floating[Any]]
+IntArrayLike = NDArray[np.integer[Any]]

 Raster_T = DataArray | DataTree
 ColorLike = tuple[float, ...] | str
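With the try/except fallback removed, the aliases are plain numpy.typing annotations that accept any float or integer dtype. A hypothetical sketch of how such aliases type-check (the function and data below are illustrative, not from the codebase):

# Hypothetical usage of the refactored aliases (names and data are illustrative).
from typing import Any

import numpy as np
from numpy.typing import NDArray

ArrayLike = NDArray[np.floating[Any]]    # any float dtype (float32, float64, ...)
IntArrayLike = NDArray[np.integer[Any]]  # any integer dtype (int16, uint16, ...)

def scale_by_counts(values: ArrayLike, counts: IntArrayLike) -> ArrayLike:
    # mypy accepts float32/float64 arrays for `values` and any integer array for `counts`.
    return values / counts.astype(np.float64)

print(scale_by_counts(np.ones(3, dtype=np.float32), np.array([1, 2, 4], dtype=np.uint16)))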

src/spatialdata/_utils.py

Lines changed: 1 addition & 1 deletion

@@ -80,7 +80,7 @@ def _compute_paddings(data: DataArray, axis: str) -> tuple[int, int]:
     others = list(data.dims)
     others.remove(axis)
     # mypy (luca's pycharm config) can't see the isclose method of dask array
-    s = da.isclose(data.sum(dim=others), 0)  # type: ignore[attr-defined]
+    s = da.isclose(data.sum(dim=others), 0)
     # TODO: rewrite this to use dask array; can't get it to work with it
     x = s.compute()
     non_zero = np.where(x == 0)[0]

src/spatialdata/dataloader/datasets.py

Lines changed: 1 addition & 1 deletion

@@ -144,7 +144,7 @@ def __init__(
                **dict(rasterize_kwargs),
            )
            if rasterize
-            else bounding_box_query  # type: ignore[assignment]
+            else bounding_box_query
        )
        self._return = self._get_return(return_annotations, table_name)
        self.transform = transform

src/spatialdata/datasets.py

Lines changed: 1 addition & 1 deletion

@@ -182,7 +182,7 @@ def _image_blobs(
        masks = []
        for i in range(n_channels):
            mask = self._generate_blobs(length=length, seed=i)
-            mask = (mask - mask.min()) / np.ptp(mask)  # type: ignore[attr-defined]
+            mask = (mask - mask.min()) / np.ptp(mask)
            masks.append(mask)

        x = np.stack(masks, axis=0)
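The line touched above is a standard min-max normalization: np.ptp ("peak to peak") is max - min, so the result is rescaled into [0, 1]. A standalone illustration with toy data (not from the test suite):

# Standalone illustration of the min-max normalization used above (toy data).
import numpy as np

mask = np.array([2.0, 5.0, 8.0])
normalized = (mask - mask.min()) / np.ptp(mask)  # np.ptp(mask) = max - min = 6.0
print(normalized)  # [0.  0.5 1. ]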
