Skip to content

Commit 33b13b2

Browse files
committed
feat: add verify=true|false or remove to allow modifying the verified attribute
1 parent ae5f529 commit 33b13b2

File tree

3 files changed

+51
-19
lines changed

3 files changed

+51
-19
lines changed

src/app/main.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ async def assign_label_by_id(
229229
return {"message": f"Error: {ex}"}
230230

231231
@app.post("/label/cluster/{label}",
232-
summary="Assign a label to a localization by cluster name",
232+
summary="Assign a label to a localization by cluster name. Set the verified attribute to true (default), false, or leave off verified=true|false leave verified attribute as-is",
233233
status_code=status.HTTP_200_OK)
234234
async def assign_label_by_cluster(
235235
label: str, model: LocClusterFilterModel, background_tasks: BackgroundTasks
@@ -261,14 +261,22 @@ async def assign_label_by_cluster(
261261

262262
# Clear the kwargs and add the media name filter
263263
kwargs.clear()
264-
kwargs["attribute"] = [f"cluster::{model.cluster_name}", "verified::False"]
265264
if version_id:
266265
kwargs["version"] = [version_id]
267-
num_boxes = await get_localization_count(api, spec, **kwargs)
266+
counts = {}
267+
for verified in ("True", "False"):
268+
kwargs["attribute"] = [f"cluster::{model.cluster_name}", f"verified::{verified}"]
269+
counts[verified] = await get_localization_count(api, spec, **kwargs)
270+
if model.verify is not None:
271+
kwargs["attribute"] = [f"cluster::{model.cluster_name}", f"verified::{str(bool(model.verify))}"]
272+
else:
273+
kwargs["attribute"] = [f"cluster::{model.cluster_name}"]
268274

269275
if model.dry_run:
276+
num_verified = counts["True"]
277+
num_unverified = counts["False"]
270278
return {
271-
"message": f'{num_boxes} unverified localizations in '
279+
"message": f'{num_unverified} unverified {num_verified} verified localizations in '
272280
f'cluster {model.cluster_name} and '
273281
f'{model.version_name if version_id else "all versions"} in {num_media} medias'
274282
}
@@ -279,6 +287,7 @@ async def assign_label_by_cluster(
279287
return {
280288
"message": f"Queued modification of localizations in cluster {model.cluster_name} and "
281289
f'{model.version_name if version_id else "all versions"} to label {label}'
290+
f' and verify {model.verify if model.verify is not None else "unchanged"}'
282291
}
283292
except Exception as ex:
284293
return {"message": f"Error: {ex}"}
@@ -401,7 +410,7 @@ async def media_count_by_media_filename(item: MediaNameFilterModelBase):
401410

402411

403412
@app.delete("/localizations/filename",
404-
summary="Delete localizations by media filename and filter type Includes/Equals",
413+
summary="Delete localizations by media filename and filter type Includes/Equals. ONLY deletes unverified localizations",
405414
status_code=status.HTTP_200_OK)
406415
async def localizations_by_media_filename(item: MediaNameFilterModel, background_tasks: BackgroundTasks):
407416
model = MediaNameFilterModel(**jsonable_encoder(item)) # Convert to a model
@@ -457,7 +466,7 @@ async def localizations_by_media_filename(item: MediaNameFilterModel, background
457466
return {"message": f"Queued deletion of localizations in medias by filename {model.media_name}"}
458467

459468
@app.delete("/localizations/filename_label",
460-
summary="Delete localizations by media filename Includes/Equals and label",
469+
summary="Delete localizations by media filename Includes/Equals and label. ONLY deletes unverified localizations",
461470
status_code=status.HTTP_200_OK)
462471
async def delete_localizations_by_media_filename_and_label(
463472
model: LocLabelFilterModel, background_tasks: BackgroundTasks
@@ -511,7 +520,7 @@ async def delete_localizations_by_media_filename_and_label(
511520
return {"message": f"Error: {ex}"}
512521

513522
@app.delete("/localizations/filename_cluster",
514-
summary="Delete localizations by media filename Includes/Equals and cluster name",
523+
summary="Delete localizations by media filename Includes/Equals and cluster name. ONLY deletes unverified localizations",
515524
status_code=status.HTTP_200_OK)
516525
async def delete_localizations_by_media_filename_and_cluster(
517526
model: LocMediaClusterFilterModel, background_tasks: BackgroundTasks

src/app/ops/models.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
# Description: models for common bulk operations on tator
44

55
from enum import unique, Enum
6-
from pydantic import BaseModel
6+
from typing import Optional
7+
8+
from pydantic import BaseModel, field_validator
79
from app.conf import default_project
810

911

@@ -65,6 +67,13 @@ class LocClusterFilterModel(BaseModel):
6567
version_name: str | None = "Baseline"
6668
project_name: str | None = default_project
6769
dry_run: bool | None = True
70+
verify: Optional[bool] = None
71+
72+
@field_validator('verify', mode='before')
73+
def set_default_true_if_present(cls, v):
74+
if v is None:
75+
return None
76+
return bool(v) if v != '' else True
6877

6978
class LocMediaClusterFilterModel(BaseModel):
7079
filter_media: str | None = FilterType.Equals

src/app/ops/modifications.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,20 @@ async def assign_cluster_label(model: LocClusterFilterModel, label: str, api: ta
9797
num_modified += len(localizations)
9898
debug(f"Found {len(localizations)} localizations that include {model.cluster_name} ...")
9999

100-
# Bulk update boxes by IDs, set verified to True
100+
# Bulk update boxes by IDs
101101
params = {"type": spec.box_type}
102-
id_bulk_patch = {
103-
"attributes": {"Label": label, "verified": True},
104-
"ids": [l.id for l in localizations],
105-
"in_place": 1,
106-
}
102+
if model.verify is None:
103+
id_bulk_patch = {
104+
"attributes": {"Label": label},
105+
"ids": [l.id for l in localizations],
106+
"in_place": 1,
107+
}
108+
else:
109+
id_bulk_patch = {
110+
"attributes": {"Label": label, "verified": model.verify},
111+
"ids": [l.id for l in localizations],
112+
"in_place": 1,
113+
}
107114
try:
108115
info(id_bulk_patch)
109116
response = api.update_localization_list(project=spec.project_id, **params, localization_bulk_update=id_bulk_patch)
@@ -189,11 +196,18 @@ async def assign_cluster_media_label(model: LocMediaClusterFilterModel, label: s
189196

190197
# Bulk update boxes by IDs, set verified to True
191198
params = {"type": spec.box_type}
192-
id_bulk_patch = {
193-
"attributes": {"Label": label, "verified": True},
194-
"ids": [l.id for l in localizations],
195-
"in_place": 1,
196-
}
199+
if model.verify is not None:
200+
id_bulk_patch = {
201+
"attributes": {"Label": label, "verified": model.verify},
202+
"ids": [l.id for l in localizations],
203+
"in_place": 1,
204+
}
205+
else:
206+
id_bulk_patch = {
207+
"attributes": {"Label": label},
208+
"ids": [l.id for l in localizations],
209+
"in_place": 1,
210+
}
197211
try:
198212
info(id_bulk_patch)
199213
response = api.update_localization_list(project=spec.project_id, **params, localization_bulk_update=id_bulk_patch)

0 commit comments

Comments
 (0)