@@ -265,6 +265,9 @@ def _merge_keypoint_detections(
265265 "class_id" : group [0 ]["class_id" ],
266266 "mask" : None ,
267267 "keypoint_data" : merged_kp_data ,
268+ "detection_data" : group [0 ].get (
269+ "detection_data" , {}
270+ ), # Preserve first detection's metadata
268271 }
269272 )
270273
@@ -408,6 +411,23 @@ def merge_crop_predictions(
408411 "keypoints_confidence"
409412 ][j ]
410413
414+ # Collect per-detection data fields to preserve individual detection metadata
415+ # This is crucial for preserving class_name and other fields when multiple
416+ # detections have the same class_id but different values
417+ detection_data = {}
418+ for key in child_pred .data .keys ():
419+ if key not in [
420+ "detection_id" ,
421+ "parent_id" ,
422+ "inference_id" ,
423+ "keypoints_xy" ,
424+ "keypoints_class_name" ,
425+ "keypoints_class_id" ,
426+ "keypoints_confidence" ,
427+ ]:
428+ if j < len (child_pred .data [key ]):
429+ detection_data [key ] = child_pred .data [key ][j ]
430+
411431 if has_masks and child_pred .mask is not None :
412432 # Instance segmentation - transform mask
413433 mask = child_pred .mask [j ]
@@ -426,6 +446,7 @@ def merge_crop_predictions(
426446 "class_id" : class_id ,
427447 "bbox" : None , # Will compute from mask
428448 "keypoint_data" : keypoint_data ,
449+ "detection_data" : detection_data , # Store per-detection metadata
429450 }
430451 )
431452 else :
@@ -446,6 +467,7 @@ def merge_crop_predictions(
446467 "class_id" : class_id ,
447468 "mask" : None ,
448469 "keypoint_data" : keypoint_data ,
470+ "detection_data" : detection_data , # Store per-detection metadata
449471 }
450472 )
451473
@@ -616,8 +638,11 @@ def merge_crop_predictions(
616638 # Prediction type should be 'instance-segmentation'
617639 merged_data [key ].append ("instance-segmentation" )
618640 else :
619- # For other fields like class_name, use the value associated with this class_id
620- if (
641+ # For other fields like class_name, check pred["detection_data"] first (per-detection data)
642+ # then fall back to class_id_to_data (class-level defaults)
643+ if key in pred .get ("detection_data" , {}):
644+ merged_data [key ].append (pred ["detection_data" ][key ])
645+ elif (
621646 pred ["class_id" ] in class_id_to_data
622647 and key in class_id_to_data [pred ["class_id" ]]
623648 ):
@@ -789,6 +814,9 @@ def _merge_overlapping_masks(
789814 "polygon" : poly ,
790815 "confidence" : pred ["confidence" ],
791816 "class_id" : pred ["class_id" ],
817+ "detection_data" : pred .get (
818+ "detection_data" , {}
819+ ), # Preserve metadata
792820 }
793821 )
794822
@@ -829,6 +857,9 @@ def _merge_overlapping_masks(
829857 "mask" : mask ,
830858 "confidence" : merged_confidence ,
831859 "class_id" : class_id ,
860+ "detection_data" : group [0 ].get (
861+ "detection_data" , {}
862+ ), # Preserve first detection's metadata
832863 }
833864 )
834865 else :
@@ -839,6 +870,9 @@ def _merge_overlapping_masks(
839870 "mask" : mask ,
840871 "confidence" : merged_confidence ,
841872 "class_id" : class_id ,
873+ "detection_data" : group [0 ].get (
874+ "detection_data" , {}
875+ ), # Preserve first detection's metadata
842876 }
843877 )
844878
@@ -891,7 +925,14 @@ def _merge_overlapping_bboxes(
891925 merged_bbox = np .array ([min (x_mins ), min (y_mins ), max (x_maxs ), max (y_maxs )])
892926
893927 merged_results .append (
894- {"bbox" : merged_bbox , "confidence" : merged_confidence , "class_id" : class_id }
928+ {
929+ "bbox" : merged_bbox ,
930+ "confidence" : merged_confidence ,
931+ "class_id" : class_id ,
932+ "detection_data" : group [0 ].get (
933+ "detection_data" , {}
934+ ), # Preserve first detection's metadata
935+ }
895936 )
896937
897938 return merged_results
0 commit comments