Skip to content

Commit ea7dc65

Browse files
Merge pull request #876 from roboflow/fix-aggregator-init-params
Fix model_id bug with InferenceAggregator block
2 parents ebffba0 + b785224 commit ea7dc65

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

inference/core/workflows/core_steps/sinks/roboflow/model_monitoring_inference_aggregator/v1.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ def __init__(
214214
api_key: Optional[str],
215215
background_tasks: Optional[BackgroundTasks],
216216
thread_pool_executor: Optional[ThreadPoolExecutor],
217-
model_id: str,
218217
):
219218
if api_key is None:
220219
raise ValueError(
@@ -228,7 +227,6 @@ def __init__(
228227
self._background_tasks = background_tasks
229228
self._thread_pool_executor = thread_pool_executor
230229
self._predictions_aggregator = PredictionsAggregator()
231-
self._model_id = model_id
232230

233231
@classmethod
234232
def get_init_parameters(cls) -> List[str]:
@@ -244,10 +242,11 @@ def run(
244242
predictions: Union[sv.Detections, dict],
245243
frequency: int,
246244
unique_aggregator_key: str,
245+
model_id: str,
247246
) -> BlockResult:
248247
self._last_report_time_cache_key = f"workflows:steps_cache:roboflow_core/model_monitoring_inference_aggregator@v1:{unique_aggregator_key}:last_report_time"
249248
if predictions:
250-
self._predictions_aggregator.collect(predictions, self._model_id)
249+
self._predictions_aggregator.collect(predictions, model_id)
251250
if not self._is_in_reporting_range(frequency):
252251
return {
253252
"error_status": False,

tests/workflows/unit_tests/core_steps/sinks/roboflow/test_model_monitoring_inference_aggregator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,13 @@ def test_run_not_in_reporting_range_success(
5454
api_key="my_api_key",
5555
background_tasks=None,
5656
thread_pool_executor=None,
57-
model_id="my_model_id",
5857
)
5958
result = block.run(
6059
fire_and_forget=True,
6160
frequency=10,
6261
predictions=predictions,
6362
unique_aggregator_key=unique_aggregator_key,
63+
model_id="my_model_id",
6464
)
6565

6666
# then
@@ -121,13 +121,13 @@ def test_run_in_reporting_range_success_with_object_detection(
121121
api_key=api_key,
122122
background_tasks=None,
123123
thread_pool_executor=None,
124-
model_id="construction-safety/10",
125124
)
126125
result = block.run(
127126
fire_and_forget=False,
128127
frequency=10,
129128
predictions=predictions,
130129
unique_aggregator_key=unique_aggregator_key,
130+
model_id="construction-safety/10",
131131
)
132132

133133
# then
@@ -217,13 +217,13 @@ def test_run_in_reporting_range_success_with_single_label_classification(
217217
api_key=api_key,
218218
background_tasks=None,
219219
thread_pool_executor=None,
220-
model_id="pills-classification/1",
221220
)
222221
result = block.run(
223222
fire_and_forget=False,
224223
frequency=10,
225224
predictions=predictions,
226225
unique_aggregator_key=unique_aggregator_key,
226+
model_id="pills-classification/1",
227227
)
228228

229229
# then
@@ -313,13 +313,13 @@ def test_run_in_reporting_range_success_with_multi_label_classification(
313313
api_key=api_key,
314314
background_tasks=None,
315315
thread_pool_executor=None,
316-
model_id="animals/32",
317316
)
318317
result = block.run(
319318
fire_and_forget=False,
320319
frequency=10,
321320
predictions=predictions,
322321
unique_aggregator_key=unique_aggregator_key,
322+
model_id="animals/32",
323323
)
324324

325325
# then
@@ -415,13 +415,13 @@ def test_send_inference_results_to_model_monitoring_failure(
415415
api_key=api_key,
416416
background_tasks=None,
417417
thread_pool_executor=None,
418-
model_id="my_model_id",
419418
)
420419
result = block.run(
421420
fire_and_forget=False,
422421
frequency=1,
423422
predictions=predictions,
424423
unique_aggregator_key=unique_aggregator_key,
424+
model_id="my_model_id",
425425
)
426426

427427
# then
@@ -479,13 +479,13 @@ def test_run_when_not_in_reporting_range(
479479
api_key=api_key,
480480
background_tasks=None,
481481
thread_pool_executor=None,
482-
model_id="my_model_id",
483482
)
484483
result = block.run(
485484
fire_and_forget=False,
486485
frequency=10,
487486
predictions=predictions,
488487
unique_aggregator_key=unique_aggregator_key,
488+
model_id="my_model_id",
489489
)
490490

491491
# then
@@ -545,13 +545,13 @@ def test_run_when_fire_and_forget_with_background_tasks(
545545
api_key=api_key,
546546
background_tasks=background_tasks,
547547
thread_pool_executor=None,
548-
model_id="my_model_id",
549548
)
550549
result = block.run(
551550
fire_and_forget=True,
552551
frequency=10,
553552
predictions=predictions,
554553
unique_aggregator_key=unique_aggregator_key,
554+
model_id="my_model_id",
555555
)
556556

557557
# then
@@ -609,13 +609,13 @@ def test_run_when_fire_and_forget_with_thread_pool(
609609
api_key=api_key,
610610
background_tasks=None,
611611
thread_pool_executor=thread_pool_executor,
612-
model_id="my_model_id",
613612
)
614613
result = block.run(
615614
fire_and_forget=True,
616615
frequency=10,
617616
predictions=predictions,
618617
unique_aggregator_key=unique_aggregator_key,
618+
model_id="my_model_id",
619619
)
620620

621621
# then

0 commit comments

Comments
 (0)