Skip to content

Commit 7bce868

Browse files
fix join non matching table (#813)
* fix join non matching table
* add test non matching element
* removed unused comment

---------

Co-authored-by: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com>
1 parent 7fc5eb4 commit 7bce868

File tree

2 files changed

+81
-12
lines changed

2 files changed

+81
-12
lines changed

src/spatialdata/_core/query/relational_query.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -253,12 +253,12 @@ def _get_joined_table_indices(
253253
mask = table_instance_key_column.isin(element_indices)
254254
if joined_indices is None:
255255
if match_rows == "left":
256-
joined_indices = _match_rows(table_instance_key_column, mask, element_indices, match_rows)
256+
_, joined_indices = _match_rows(table_instance_key_column, mask, element_indices, match_rows)
257257
else:
258258
joined_indices = table_instance_key_column[mask].index
259259
else:
260260
if match_rows == "left":
261-
add_indices = _match_rows(table_instance_key_column, mask, element_indices, match_rows)
261+
_, add_indices = _match_rows(table_instance_key_column, mask, element_indices, match_rows)
262262
joined_indices = joined_indices.append(add_indices)
263263
# in place append does not work with pd.Index
264264
else:
@@ -294,8 +294,14 @@ def _get_masked_element(
294294
mask = table_instance_key_column.isin(element_indices)
295295
masked_table_instance_key_column = table_instance_key_column[mask]
296296
mask_values = mask_values if len(mask_values := masked_table_instance_key_column.values) != 0 else None
297-
if match_rows == "right":
298-
mask_values = _match_rows(table_instance_key_column, mask, element_indices, match_rows)
297+
if match_rows in ["left", "right"]:
298+
left_index, _ = _match_rows(table_instance_key_column, mask, element_indices, match_rows)
299+
300+
if mask_values is not None and len(left_index) != len(mask_values):
301+
mask = left_index.isin(mask_values)
302+
mask_values = left_index[mask]
303+
else:
304+
mask_values = left_index
299305

300306
if isinstance(element, DaskDataFrame):
301307
return element.map_partitions(lambda df: df.loc[mask_values], meta=element)
@@ -404,6 +410,9 @@ def _inner_join_spatialelement_table(
404410
element_dict[element_type][name] = None
405411
continue
406412

413+
if joined_indices is not None:
414+
joined_indices = joined_indices.dropna() if any(joined_indices.isna()) else joined_indices
415+
407416
joined_table = table[joined_indices, :].copy() if joined_indices is not None else None
408417
_inplace_fix_subset_categorical_obs(subset_adata=joined_table, original_adata=table)
409418
return element_dict, joined_table
@@ -483,22 +492,24 @@ def _match_rows(
483492
mask: pd.Series,
484493
element_indices: pd.RangeIndex,
485494
match_rows: str,
486-
) -> pd.Index:
495+
) -> tuple[pd.Index, pd.Index]:
487496
instance_id_df = pd.DataFrame(
488497
{"instance_id": table_instance_key_column[mask].values, "index_right": table_instance_key_column[mask].index}
489498
)
490499
element_index_df = pd.DataFrame({"index_left": element_indices})
491-
index_col = "index_left" if match_rows == "right" else "index_right"
492500

493-
merged_df = pd.merge(
494-
element_index_df, instance_id_df, left_on="index_left", right_on="instance_id", how=match_rows
495-
)[index_col]
501+
merged_df = pd.merge(element_index_df, instance_id_df, left_on="index_left", right_on="instance_id", how=match_rows)
502+
index_left = merged_df["index_left"]
503+
index_right = merged_df["index_right"]
496504

497505
# With labels it can be that index 0 is NaN
498-
if isinstance(merged_df.iloc[0], float) and math.isnan(merged_df.iloc[0]):
499-
merged_df = merged_df.iloc[1:]
506+
if isinstance(index_left.iloc[0], float) and math.isnan(index_left.iloc[0]):
507+
index_left = index_left.iloc[1:]
508+
509+
if isinstance(index_right.iloc[0], float) and math.isnan(index_right.iloc[0]):
510+
index_right = index_right.iloc[1:]
500511

501-
return pd.Index(merged_df)
512+
return pd.Index(index_left), pd.Index(index_right)
502513

503514

504515
class JoinTypes(Enum):

tests/core/query/test_relational_query.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,64 @@ def test_left_exclusive_and_right_join(sdata_query_aggregation):
318318
)
319319

320320

321+
def test_match_rows_inner_join_non_matching_element(sdata_query_aggregation):
322+
sdata = sdata_query_aggregation
323+
sdata["values_circles"] = sdata["values_circles"][4:]
324+
original_index = sdata["values_circles"].index
325+
reversed_instance_id = [3, 5, 8, 7, 6, 4, 1, 2, 0] + list(reversed(range(12)))
326+
sdata["table"].obs["instance_id"] = reversed_instance_id
327+
328+
element_dict, table = join_spatialelement_table(
329+
sdata=sdata,
330+
spatial_element_names="values_circles",
331+
table_name="table",
332+
how="inner",
333+
match_rows="left",
334+
)
335+
assert all(table.obs["instance_id"].values == original_index)
336+
337+
element_dict, table = join_spatialelement_table(
338+
sdata=sdata,
339+
spatial_element_names="values_circles",
340+
table_name="table",
341+
how="inner",
342+
match_rows="right",
343+
)
344+
345+
assert all(element_dict["values_circles"].index == [5, 8, 7, 6, 4])
346+
347+
348+
def test_match_rows_inner_join_non_matching_table(sdata_query_aggregation):
349+
sdata = sdata_query_aggregation
350+
table = sdata["table"][3:]
351+
original_instance_id = table.obs["instance_id"]
352+
reversed_instance_id = [6, 7, 8, 3, 4, 5] + list(reversed(range(12)))
353+
table.obs["instance_id"] = reversed_instance_id
354+
sdata["table"] = table
355+
356+
element_dict, table = join_spatialelement_table(
357+
sdata=sdata,
358+
spatial_element_names=["values_circles", "values_polygons"],
359+
table_name="table",
360+
how="inner",
361+
match_rows="left",
362+
)
363+
364+
assert all(table.obs["instance_id"].values == original_instance_id.values)
365+
366+
element_dict, table = join_spatialelement_table(
367+
sdata=sdata,
368+
spatial_element_names=["values_circles", "values_polygons"],
369+
table_name="table",
370+
how="inner",
371+
match_rows="right",
372+
)
373+
374+
indices = element_dict["values_circles"].index.append(element_dict["values_polygons"].index)
375+
376+
assert all(indices == reversed_instance_id)
377+
378+
321379
# TODO: there is a lot of duplicate code, simplify with a function that tests both the case sdata=None and sdata=sdata
322380
def test_match_rows_join(sdata_query_aggregation):
323381
sdata = sdata_query_aggregation

0 commit comments

Comments
 (0)