diff --git a/rfdetr/engine.py b/rfdetr/engine.py index cb589dfa..693bcfbf 100644 --- a/rfdetr/engine.py +++ b/rfdetr/engine.py @@ -185,7 +185,7 @@ def coco_extended_metrics(coco_eval): iou_thrs, rec_thrs = coco_eval.params.iouThrs, coco_eval.params.recThrs iou50_idx, area_idx, maxdet_idx = ( - int(np.argwhere(np.isclose(iou_thrs, 0.50))), 0, 2) + int(np.argmax(np.isclose(iou_thrs, 0.50))), 0, 2) P = coco_eval.eval["precision"] S = coco_eval.eval["scores"] @@ -338,4 +338,4 @@ def evaluate(model, criterion, postprocess, data_loader, base_ds, device, args=N if "segm" in iou_types: results_json = coco_extended_metrics(coco_evaluator.coco_eval["segm"]) stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist() - return stats, coco_evaluator \ No newline at end of file + return stats, coco_evaluator