Skip to content

Commit 3cdf3d8

Browse files
Pins distributed; better categorical handing for points parser (#1061)
* improve handling of categoricals for feature_key in points * pin distributed; improve warning for categorical points
1 parent db34783 commit 3cdf3d8

File tree

2 files changed

+30
-15
lines changed

2 files changed

+30
-15
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ dependencies = [
2727
"click",
2828
"dask-image",
2929
"dask>=2025.2.0,<2026.1.2",
30+
"distributed<2026.1.2",
3031
"datashader",
3132
"fsspec[s3,http]",
3233
"geopandas>=0.14",

src/spatialdata/models/models.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ def parse(
242242
else:
243243
# Chunk single scale images
244244
if chunks is not None:
245+
if isinstance(chunks, tuple):
246+
chunks = {dim: chunks[index] for index, dim in enumerate(data.dims)}
245247
data = data.chunk(chunks=chunks)
246248
cls()._check_chunk_size_not_too_large(data)
247249
# recompute coordinates for (multiscale) spatial image
@@ -819,19 +821,23 @@ def _(
819821
# TODO: dask does not allow for setting divisions directly anymore. We have to decide on forcing the user.
820822
if feature_key is not None:
821823
feature_categ = dd.from_pandas(
822-
data[feature_key].astype(str).astype("category"),
824+
data[feature_key],
823825
sort=sort,
824826
**kwargs,
825827
)
826828
table[feature_key] = feature_categ
827829
elif isinstance(data, dd.DataFrame):
828830
table = data[[coordinates[ax] for ax in axes]]
829831
table.columns = axes
830-
if feature_key is not None:
831-
if data[feature_key].dtype.name == "category":
832-
table[feature_key] = data[feature_key]
833-
else:
834-
table[feature_key] = data[feature_key].astype(str).astype("category")
832+
833+
if feature_key is not None:
834+
if data[feature_key].dtype.name == "category":
835+
table[feature_key] = data[feature_key]
836+
else:
837+
# this will cause the categories to be unknown and trigger the warning (and performance slowdown) in
838+
# _add_metadata_and_validate()
839+
table[feature_key] = data[feature_key].astype(str).astype("category")
840+
835841
if instance_key is not None:
836842
table[instance_key] = data[instance_key]
837843
for c in [X, Y, Z]:
@@ -885,15 +891,20 @@ def _add_metadata_and_validate(
885891
assert instance_key in data.columns
886892
data.attrs[ATTRS_KEY][cls.INSTANCE_KEY] = instance_key
887893

888-
for c in data.columns:
889-
# Here we are explicitly importing the categories
890-
# but it is a convenient way to ensure that the categories are known.
891-
# It also just changes the state of the series, so it is not a big deal.
892-
if isinstance(data[c].dtype, CategoricalDtype) and not data[c].cat.known:
893-
try:
894-
data[c] = data[c].cat.set_categories(data[c].compute().cat.categories)
895-
except ValueError:
896-
logger.info(f"Column `{c}` contains unknown categories. Consider casting it.")
894+
if (
895+
feature_key is not None
896+
and isinstance(data[feature_key].dtype, CategoricalDtype)
897+
and not data[feature_key].cat.known
898+
):
899+
logger.warning(
900+
f"The `feature_key` column {feature_key} is categorical with unknown categories. "
901+
"Please ensure the categories are known before calling `PointsModel.parse()` to "
902+
"avoid significant performance implications due to the need for dask to compute "
903+
"the categories. If you did not use PointsModel.parse() explicitly in your code ("
904+
"e.g. this message is coming from a reader in `spatialdata_io`), please report "
905+
"this finding."
906+
)
907+
data[feature_key] = data[feature_key].cat.set_categories(data[feature_key].compute().cat.categories)
897908

898909
_parse_transformations(data, transformations)
899910
cls.validate(data)
@@ -1153,6 +1164,9 @@ def parse(
11531164
The parsed data.
11541165
"""
11551166
validate_table_attr_keys(adata)
1167+
# Convert view to actual copy to avoid ImplicitModificationWarning when modifying .uns
1168+
if adata.is_view:
1169+
adata = adata.copy()
11561170
# either all live in adata.uns or all be passed in as argument
11571171
n_args = sum([region is not None, region_key is not None, instance_key is not None])
11581172
if n_args == 0:

0 commit comments

Comments
 (0)