@@ -253,12 +253,12 @@ def _get_joined_table_indices(
253253 mask = table_instance_key_column .isin (element_indices )
254254 if joined_indices is None :
255255 if match_rows == "left" :
256- joined_indices = _match_rows (table_instance_key_column , mask , element_indices , match_rows )
256+ _ , joined_indices = _match_rows (table_instance_key_column , mask , element_indices , match_rows )
257257 else :
258258 joined_indices = table_instance_key_column [mask ].index
259259 else :
260260 if match_rows == "left" :
261- add_indices = _match_rows (table_instance_key_column , mask , element_indices , match_rows )
261+ _ , add_indices = _match_rows (table_instance_key_column , mask , element_indices , match_rows )
262262 joined_indices = joined_indices .append (add_indices )
263263 # in place append does not work with pd.Index
264264 else :
@@ -294,8 +294,14 @@ def _get_masked_element(
294294 mask = table_instance_key_column .isin (element_indices )
295295 masked_table_instance_key_column = table_instance_key_column [mask ]
296296 mask_values = mask_values if len (mask_values := masked_table_instance_key_column .values ) != 0 else None
297- if match_rows == "right" :
298- mask_values = _match_rows (table_instance_key_column , mask , element_indices , match_rows )
297+ if match_rows in ["left" , "right" ]:
298+ left_index , _ = _match_rows (table_instance_key_column , mask , element_indices , match_rows )
299+
300+ if mask_values is not None and len (left_index ) != len (mask_values ):
301+ mask = left_index .isin (mask_values )
302+ mask_values = left_index [mask ]
303+ else :
304+ mask_values = left_index
299305
300306 if isinstance (element , DaskDataFrame ):
301307 return element .map_partitions (lambda df : df .loc [mask_values ], meta = element )
@@ -404,6 +410,9 @@ def _inner_join_spatialelement_table(
404410 element_dict [element_type ][name ] = None
405411 continue
406412
413+ if joined_indices is not None :
414+ joined_indices = joined_indices .dropna () if any (joined_indices .isna ()) else joined_indices
415+
407416 joined_table = table [joined_indices , :].copy () if joined_indices is not None else None
408417 _inplace_fix_subset_categorical_obs (subset_adata = joined_table , original_adata = table )
409418 return element_dict , joined_table
@@ -483,22 +492,24 @@ def _match_rows(
483492 mask : pd .Series ,
484493 element_indices : pd .RangeIndex ,
485494 match_rows : str ,
486- ) -> pd .Index :
495+ ) -> tuple [ pd .Index , pd . Index ] :
487496 instance_id_df = pd .DataFrame (
488497 {"instance_id" : table_instance_key_column [mask ].values , "index_right" : table_instance_key_column [mask ].index }
489498 )
490499 element_index_df = pd .DataFrame ({"index_left" : element_indices })
491- index_col = "index_left" if match_rows == "right" else "index_right"
492500
493- merged_df = pd .merge (
494- element_index_df , instance_id_df , left_on = "index_left" , right_on = "instance_id" , how = match_rows
495- )[ index_col ]
501+ merged_df = pd .merge (element_index_df , instance_id_df , left_on = "index_left" , right_on = "instance_id" , how = match_rows )
502+ index_left = merged_df [ "index_left" ]
503+ index_right = merged_df [ "index_right" ]
496504
497505 # With labels it can be that index 0 is NaN
498- if isinstance (merged_df .iloc [0 ], float ) and math .isnan (merged_df .iloc [0 ]):
499- merged_df = merged_df .iloc [1 :]
506+ if isinstance (index_left .iloc [0 ], float ) and math .isnan (index_left .iloc [0 ]):
507+ index_left = index_left .iloc [1 :]
508+
509+ if isinstance (index_right .iloc [0 ], float ) and math .isnan (index_right .iloc [0 ]):
510+ index_right = index_right .iloc [1 :]
500511
501- return pd .Index (merged_df )
512+ return pd .Index (index_left ), pd . Index ( index_right )
502513
503514
504515class JoinTypes (Enum ):
0 commit comments