@@ -217,9 +217,9 @@ def rasterize(
217217 The table optionally containing the `value_key` and the name of the table in the returned `SpatialData` object.
218218 Must be `None` when `data` is a `SpatialData` object, otherwise it assumes the default value of `'table'`.
219219 return_regions_as_labels
220- By default, single-scale images of shape `(c, y, x)` are returned. If `True`, returns labels and shapes as
221- labels of shape `(y, x)` as opposed to an image of shape `(c, y, x)`. Points and images are always returned
222- as images, and multiscale raster data is always returned as single-scale data.
220+ By default, single-scale images of shape `(c, y, x)` are returned. If `True`, returns labels, shapes and points
221+ as labels of shape `(y, x)` as opposed to an image of shape `(c, y, x)`. Images are always returned as images,
222+ and multiscale raster data is always returned as single-scale data.
223223 agg_func
224224 Available only when rasterizing points and shapes. A reduction function from datashader (its name, or a
225225 `Callable`). See the notes for more details on the default behavior.
@@ -234,6 +234,11 @@ def rasterize(
234234 into a `DataArray` (not a `DataTree`). So if a `SpatialData` object with elements is passed, a `SpatialData` object
235235 with single-scale images and labels will be returned.
236236
237+ When `return_regions_as_labels` is `True`, the returned `DataArray` object will have an attribute called
238+ `label_index_to_category` that maps the label index to the category name. You can access it via
239+ `returned_data.attrs["label_index_to_category"]`. The returned labels will start from 1 (0 is reserved for the
240+ background), and will be contiguous.
241+
237242 Notes
238243 -----
239244 For images and labels, the parameters `value_key`, `table_name`, `agg_func`, and `return_single_channel` are not
@@ -587,7 +592,7 @@ def rasterize_images_labels(
587592 )
588593 assert isinstance (transformed_dask , DaskArray )
589594 channels = xdata .coords ["c" ].values if schema in (Image2DModel , Image3DModel ) else None
590- transformed_data = schema .parse (transformed_dask , dims = xdata .dims , c_coords = channels ) # type: ignore[call-arg,arg-type ]
595+ transformed_data = schema .parse (transformed_dask , dims = xdata .dims , c_coords = channels ) # type: ignore[call-arg]
591596
592597 if target_coordinate_system != "global" :
593598 remove_transformation (transformed_data , "global" )
@@ -650,7 +655,7 @@ def rasterize_shapes_points(
650655 if value_key is not None :
651656 kwargs = {"sdata" : sdata , "element_name" : element_name } if element_name is not None else {"element" : data }
652657 data [VALUES_COLUMN ] = get_values (value_key , table_name = table_name , ** kwargs ).iloc [:, 0 ] # type: ignore[arg-type, union-attr]
653- elif isinstance (data , GeoDataFrame ):
658+ elif isinstance (data , GeoDataFrame ) or isinstance ( data , DaskDataFrame ) and return_regions_as_labels is True :
654659 value_key = VALUES_COLUMN
655660 data [VALUES_COLUMN ] = data .index .astype ("category" )
656661 else :
@@ -706,6 +711,14 @@ def rasterize_shapes_points(
706711 agg = agg .fillna (0 )
707712
708713 if return_regions_as_labels :
714+ if label_index_to_category is not None :
715+ max_label = next (iter (reversed (label_index_to_category .keys ())))
716+ else :
717+ max_label = int (agg .max ().values )
718+ max_uint16 = np .iinfo (np .uint16 ).max
719+ if max_label > max_uint16 :
720+ raise ValueError (f"Maximum label index is { max_label } . Values higher than { max_uint16 } are not supported." )
721+ agg = agg .astype (np .uint16 )
709722 return Labels2DModel .parse (agg , transformations = transformations )
710723
711724 agg = agg .expand_dims (dim = {"c" : 1 }).transpose ("c" , "y" , "x" )
0 commit comments