Commit 9ebd9a8

feat: use as_dict in all Check CRUD methods
Parent: d02c992

File tree

9 files changed: +83 -66 lines changed

tests/conftest.py
Lines changed: 1 addition & 1 deletion

@@ -284,7 +284,7 @@ async def _fake_check(
         "geojson_url": "https://example.org/file.geojson" if pmtiles_url else None,
         "geojson_size": 1024 if geojson_url else None,
     }
-    check: dict = await Check.insert(data=data, returning="*")
+    check: dict = await Check.insert(data=data, returning="*", as_dict=True)
     data["id"] = check["id"]
     if check.get("dataset_id"):
         data["dataset_id"] = check["dataset_id"]

tests/test_crawl/test_crawl.py
Lines changed: 2 additions & 2 deletions

@@ -687,10 +687,10 @@ async def test_wrong_url_in_catalog(
     if url_changed:
         r = await Resource.get(resource_id=RESOURCE_ID, column_name="url")
         assert r["url"] == new_url
-        check = await Check.get_by_resource_id(RESOURCE_ID)
+        check = await Check.get_by_resource_id(RESOURCE_ID, as_dict=True)
         assert check.get("parsing_finished_at")
     else:
-        check = await Check.get_by_resource_id(RESOURCE_ID)
+        check = await Check.get_by_resource_id(RESOURCE_ID, as_dict=True)
         assert check["status"] == 404

udata_hydra/analysis/csv.py
Lines changed: 3 additions & 3 deletions

@@ -114,7 +114,7 @@ async def analyse_csv(
     timer.mark("download-file")

     check = await Check.update(
-        check["id"], {"parsing_started_at": datetime.now(timezone.utc)}, return_as_dict=True
+        check["id"], {"parsing_started_at": datetime.now(timezone.utc)}, as_dict=True
     )  # type: ignore

     # Launch csv-detective against given file
@@ -158,7 +158,7 @@ async def analyse_csv(
         resource_id=resource_id,
         debug_insert=debug_insert,
     )
-    check = await Check.update(check["id"], {"parsing_table": table_name}, return_as_dict=True)  # type: ignore
+    check = await Check.update(check["id"], {"parsing_table": table_name}, as_dict=True)  # type: ignore
     timer.mark("csv-to-db")

     try:
@@ -202,7 +202,7 @@ async def analyse_csv(
            {
                "parsing_finished_at": datetime.now(timezone.utc),
            },
-            return_as_dict=True,
+            as_dict=True,
        )  # type: ignore
        await csv_to_db_index(table_name, csv_inspection, check)

udata_hydra/analysis/geojson.py
Lines changed: 2 additions & 2 deletions

@@ -68,7 +68,7 @@ async def analyse_geojson(
     timer.mark("download-file")

     check = await Check.update(
-        check["id"], {"parsing_started_at": datetime.now(timezone.utc)}, return_as_dict=True
+        check["id"], {"parsing_started_at": datetime.now(timezone.utc)}, as_dict=True
     )  # type: ignore

     # Convert to PMTiles
@@ -95,7 +95,7 @@ async def analyse_geojson(
                "pmtiles_url": pmtiles_url,
                "pmtiles_size": pmtiles_size,
            },
-            return_as_dict=True,
+            as_dict=True,
        )  # type: ignore

     except (ParseException, IOException) as e:

udata_hydra/cli.py
Lines changed: 13 additions & 15 deletions

@@ -185,7 +185,7 @@ async def check_resource(resource_id: str, method: str = "get", force_analysis:
 @cli(name="analyse-resource")
 async def analyse_resource_cli(resource_id: str):
     """Trigger a resource analysis, mainly useful for local debug (with breakpoints)"""
-    check: Record | None = await Check.get_by_resource_id(resource_id)
+    check: Record | None = await Check.get_by_resource_id(resource_id)  # type: ignore
     if not check:
         log.error("Could not find a check linked to the specified resource ID")
         return
@@ -210,21 +210,19 @@ async def analyse_csv_cli(

     # Try to get check from check_id
     if check_id:
-        record = await Check.get_by_id(int(check_id), with_deleted=True)
-        check = dict(record) if record else None
+        check: Record | None = await Check.get_by_id(int(check_id), with_deleted=True)  # type: ignore

     # Try to get check from URL
     if not check and url:
-        records = await Check.get_by_url(url)
-        if records:
-            if len(records) > 1:
+        checks: list[Record] | None = await Check.get_by_url(url)  # type: ignore
+        if checks:
+            if len(checks) > 1:
                 log.warning(f"Multiple checks found for URL {url}, using the latest one")
-            check = dict(records[0])
+            check = checks[0]

     # Try to get check from resource_id
     if not check and resource_id:
-        record = await Check.get_by_resource_id(resource_id)
-        check = dict(record) if record else None
+        check: Record | None = await Check.get_by_resource_id(resource_id)  # type: ignore

     # We cannot get a check, it's an external URL analysis, we need to create a temporary check
     if not check and url:
@@ -239,13 +237,13 @@ async def analyse_csv_cli(
                "timeout": False,
            },
            returning="*",
-        )
+        )  # type: ignore

     elif not check:
         log.error("Could not find a check for the specified parameters")
         return

-    await analyse_csv(check=check, debug_insert=debug_insert)
+    await analyse_csv(check=dict(check), debug_insert=debug_insert)
     log.info("CSV analysis completed")

     if url and tmp_resource_id:
@@ -259,7 +257,7 @@ async def analyse_csv_cli(
         await csv_pool.execute(f"DELETE FROM tables_index WHERE parsing_table='{table_hash}'")

         # Clean up the temporary resource and temporary check from catalog
-        check = await Check.get_by_resource_id(tmp_resource_id)
+        check: Record | None = await Check.get_by_resource_id(tmp_resource_id)  # type: ignore
         if check:
             await Check.delete(check["id"])
         await Resource.delete(resource_id=tmp_resource_id, hard_delete=True)
@@ -284,14 +282,14 @@ async def analyse_geojson_cli(
     assert check_id or url or resource_id
     check = None
     if check_id:
-        check: Record | None = await Check.get_by_id(int(check_id), with_deleted=True)
+        check: Record | None = await Check.get_by_id(int(check_id), with_deleted=True)  # type: ignore
     if not check and url:
-        checks: list[Record] | None = await Check.get_by_url(url)
+        checks: list[Record] | None = await Check.get_by_url(url)  # type: ignore
         if checks and len(checks) > 1:
             log.warning(f"Multiple checks found for URL {url}, using the latest one")
         check = checks[0] if checks else None
     if not check and resource_id:
-        check: Record | None = await Check.get_by_resource_id(resource_id)
+        check: Record | None = await Check.get_by_resource_id(resource_id)  # type: ignore
     if not check:
         if check_id:
             log.error("Could not retrieve the specified check")
udata_hydra/crawl/preprocess_check_data.py
Lines changed: 5 additions & 7 deletions

@@ -1,8 +1,6 @@
 import json
 from datetime import datetime, timezone

-from asyncpg import Record
-
 from udata_hydra.crawl.calculate_next_check import calculate_next_check_date
 from udata_hydra.crawl.helpers import get_content_type_from_header, is_valid_status
 from udata_hydra.db.check import Check
@@ -26,14 +24,14 @@ async def preprocess_check_data(dataset_id: str, check_data: dict) -> tuple[dict

     check_data["resource_id"] = str(check_data["resource_id"])

-    last_check: dict | None = None
-    last_check_record: Record | None = await Check.get_by_resource_id(check_data["resource_id"])
-    if last_check_record:
-        last_check = dict(last_check_record)
+    last_check: dict | None = await Check.get_by_resource_id(
+        check_data["resource_id"], as_dict=True
+    )  # type: ignore

     has_changed: bool = await has_check_changed(check_data, last_check)
     check_data["next_check_at"] = calculate_next_check_date(has_changed, last_check, None)
-    new_check: dict = await Check.insert(data=check_data, returning="*")
+
+    new_check: dict = await Check.insert(data=check_data, returning="*", as_dict=True)  # type: ignore

     if has_changed:
         queue.enqueue(
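
The # type: ignore comments exist because the CRUD methods now declare a union return type (Record | dict | None) while call sites annotate the narrower type they know as_dict=True produces. As an aside, typing.cast can express the same narrowing without silencing the checker; this is only a sketch of the alternative, not what the commit does:

    from typing import cast

    from udata_hydra.db.check import Check


    async def load_last_check(resource_id: str) -> dict | None:
        # as_dict=True guarantees dict | None at runtime, but the declared
        # return type stays the wider union; cast() records that narrowing
        # for the type checker instead of suppressing it.
        return cast(
            "dict | None",
            await Check.get_by_resource_id(resource_id, as_dict=True),
        )
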

udata_hydra/db/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ def compute_update_query(table_name: str, data: dict, returning: str = "*") -> s
     """


-async def update_table_record(table_name: str, record_id: int, data: dict) -> Record:
+async def update_table_record(table_name: str, record_id: int, data: dict) -> Record | None:
     data = convert_dict_values_to_json(data)
     q = compute_update_query(table_name, data)
     pool = await context.pool()
udata_hydra/db/check.py
Lines changed: 46 additions & 25 deletions

@@ -13,8 +13,24 @@
 class Check:
     """Represents a check in the "checks" DB table"""

+    @staticmethod
+    def _convert_to_dict_if_needed(result: Record | None, as_dict: bool) -> Record | dict | None:
+        if as_dict and result:
+            return dict(result)
+        return result
+
+    @staticmethod
+    def _convert_list_to_dict_if_needed(
+        results: list[Record], as_dict: bool
+    ) -> list[Record] | list[dict]:
+        if as_dict:
+            return [dict(result) for result in results]
+        return results
+
     @classmethod
-    async def get_by_id(cls, check_id: int, with_deleted: bool = False) -> Record | None:
+    async def get_by_id(
+        cls, check_id: int, with_deleted: bool = False, as_dict: bool = False
+    ) -> Record | dict | None:
         pool = await context.pool()
         async with pool.acquire() as connection:
             q = """
@@ -24,12 +40,13 @@ async def get_by_id(cls, check_id: int, with_deleted: bool = False) -> Record |
             """
             if not with_deleted:
                 q += " AND catalog.deleted = FALSE"
-            return await connection.fetchrow(q, check_id)
+            result = await connection.fetchrow(q, check_id)
+            return cls._convert_to_dict_if_needed(result, as_dict)

     @classmethod
     async def get_by_resource_id(
-        cls, resource_id: str, with_deleted: bool = False
-    ) -> Record | None:
+        cls, resource_id: str, with_deleted: bool = False, as_dict: bool = False
+    ) -> Record | dict | None:
         pool = await context.pool()
         async with pool.acquire() as connection:
             q = """
@@ -39,23 +56,25 @@ async def get_by_resource_id(
             """
             if not with_deleted:
                 q += " AND catalog.deleted = FALSE"
-            return await connection.fetchrow(q, resource_id)
+            result = await connection.fetchrow(q, resource_id)
+            return cls._convert_to_dict_if_needed(result, as_dict)

     @classmethod
-    async def get_by_url(cls, url: str) -> list[Record]:
+    async def get_by_url(cls, url: str, as_dict: bool = False) -> list[Record] | list[dict]:
         pool = await context.pool()
         async with pool.acquire() as connection:
             q = """
                 SELECT * FROM checks
                 WHERE url = $1
                 ORDER BY created_at DESC
             """
-            return await connection.fetch(q, url)
+            results = await connection.fetch(q, url)
+            return cls._convert_list_to_dict_if_needed(results, as_dict)

     @classmethod
     async def get_latest(
-        cls, url: str | None = None, resource_id: str | None = None
-    ) -> Record | None:
+        cls, url: str | None = None, resource_id: str | None = None, as_dict: bool = False
+    ) -> Record | dict | None:
         column: str = "url" if url else "resource_id"
         pool = await context.pool()
         async with pool.acquire() as connection:
@@ -66,10 +85,13 @@ async def get_latest(
                 WHERE catalog.{column} = $1
                 AND checks.id = catalog.last_check
             """
-            return await connection.fetchrow(q, url or resource_id)
+            result = await connection.fetchrow(q, url or resource_id)
+            return cls._convert_to_dict_if_needed(result, as_dict)

     @classmethod
-    async def get_all(cls, url: str | None = None, resource_id: str | None = None) -> list[Record]:
+    async def get_all(
+        cls, url: str | None = None, resource_id: str | None = None, as_dict: bool = False
+    ) -> list[Record] | list[dict]:
         column: str = "url" if url else "resource_id"
         pool = await context.pool()
         async with pool.acquire() as connection:
@@ -81,12 +103,13 @@ async def get_all(cls, url: str | None = None, resource_id: str | None = None) -
                AND catalog.{column} = checks.{column}
                ORDER BY created_at DESC
             """
-            return await connection.fetch(q, url or resource_id)
+            results = await connection.fetch(q, url or resource_id)
+            return cls._convert_list_to_dict_if_needed(results, as_dict)

     @classmethod
     async def get_group_by_for_date(
-        cls, column: str, date: date, page_size: int = 20
-    ) -> list[Record]:
+        cls, column: str, date: date, page_size: int = 20, as_dict: bool = False
+    ) -> list[Record] | list[dict]:
         pool = await context.pool()
         async with pool.acquire() as connection:
             q = f"""
@@ -97,22 +120,20 @@ async def get_group_by_for_date(
                ORDER BY count desc
                LIMIT $2
             """
-            return await connection.fetch(q, date, page_size)
+            results = await connection.fetch(q, date, page_size)
+            return cls._convert_list_to_dict_if_needed(results, as_dict)

     @classmethod
-    async def insert(cls, data: dict, returning: str = "id") -> dict:
+    async def insert(cls, data: dict, returning: str = "id", as_dict: bool = True) -> Record | dict:
         """
         Insert a new check in DB, associate it with the resource and return the check dict, optionally associated with the resource dataset_id.
         This uses the info from the last check of the same resource.
-
-        Note: Returns dict instead of Record because this method performs additional operations beyond simple insertion (joins with catalog table, adds dataset_id).
         """
         json_data = convert_dict_values_to_json(data)
         q1: str = compute_insert_query(table_name="checks", data=json_data, returning=returning)
         pool = await context.pool()
         async with pool.acquire() as connection:
             last_check: Record = await connection.fetchrow(q1, *json_data.values())
-            last_check_dict = dict(last_check)
             q2 = (
                 """UPDATE catalog SET last_check = $1 WHERE resource_id = $2 RETURNING dataset_id"""
             )
@@ -121,17 +142,17 @@ async def insert(cls, data: dict, returning: str = "id") -> dict:
             )
             # Add the dataset_id arg to the check response, if we can, and if it's asked
             if returning in ["*", "dataset_id"] and updated_resource:
+                last_check_dict = dict(last_check)
                 last_check_dict["dataset_id"] = updated_resource["dataset_id"]
-                return last_check_dict
+                return last_check_dict if as_dict else last_check
+            return dict(last_check) if as_dict else last_check

     @classmethod
-    async def update(cls, check_id: int, data: dict, return_as_dict: bool = False) -> Record | dict:
-        check: Record = await update_table_record(
+    async def update(cls, check_id: int, data: dict, as_dict: bool = False) -> Record | dict | None:
+        check: Record | None = await update_table_record(
             table_name="checks", record_id=check_id, data=data
         )
-        if return_as_dict:
-            return dict(check)
-        return check
+        return cls._convert_to_dict_if_needed(check, as_dict)

     @classmethod
     async def delete(cls, check_id: int) -> None:
udata_hydra/routes/checks.py
Lines changed: 10 additions & 10 deletions

@@ -16,22 +16,22 @@
 async def get_latest_check(request: web.Request) -> web.Response:
     """Get the latest check for a given URL or resource_id"""
     url, resource_id = get_request_params(request, params_names=["url", "resource_id"])
-    data: Record | None = await Check.get_latest(url, resource_id)
+    data: dict | None = await Check.get_latest(url, resource_id, as_dict=True)  # type: ignore
     if not data:
         raise web.HTTPNotFound()
     if data["deleted"]:
         raise web.HTTPGone()

-    return web.json_response(CheckSchema().dump(dict(data)))
+    return web.json_response(CheckSchema().dump(data))


 async def get_all_checks(request: web.Request) -> web.Response:
     url, resource_id = get_request_params(request, params_names=["url", "resource_id"])
-    data: list | None = await Check.get_all(url, resource_id)
-    if not data:
+    checks: list | None = await Check.get_all(url, resource_id, as_dict=True)
+    if not checks:
         raise web.HTTPNotFound()

-    return web.json_response([CheckSchema().dump(dict(r)) for r in data])
+    return web.json_response([CheckSchema().dump(c) for c in checks])


 async def get_checks_aggregate(request: web.Request) -> web.Response:
@@ -49,11 +49,11 @@ async def get_checks_aggregate(request: web.Request) -> web.Response:
     column: str = request.query.get("group_by")
     if not column:
         raise web.HTTPBadRequest(text="Missing mandatory 'group_by' param.")
-    data: list | None = await Check.get_group_by_for_date(column, created_at_date)
-    if not data:
+    checks: list | None = await Check.get_group_by_for_date(column, created_at_date, as_dict=True)
+    if not checks:
         raise web.HTTPNotFound()

-    return web.json_response([CheckGroupBy().dump(dict(r)) for r in data])
+    return web.json_response([CheckGroupBy().dump(c) for c in checks])


 async def create_check(request: web.Request) -> web.Response:
@@ -89,8 +89,8 @@ async def create_check(request: web.Request) -> web.Response:
     )
     context.monitor().refresh(status)

-    check: Record | None = await Check.get_latest(url, resource_id)
+    check: dict | None = await Check.get_latest(url, resource_id, as_dict=True)  # type: ignore
     if not check:
         raise web.HTTPBadRequest(text=f"Check not created, status: {status}")

-    return web.json_response(CheckSchema().dump(dict(check)), status=201)
+    return web.json_response(CheckSchema().dump(check), status=201)
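
marshmallow's Schema.dump accepts a mapping directly, which is why the per-item dict(...) copies can go once the CRUD layer hands back dicts. A minimal stand-in illustration (the schema below is hypothetical; the real CheckSchema lives in udata_hydra and has many more fields):

    from marshmallow import Schema, fields


    class MiniCheckSchema(Schema):
        # Reduced stand-in for the real CheckSchema, for illustration only.
        id = fields.Int()
        url = fields.Str()


    # dump() reads keys straight off the mapping: no dict(record) copy needed.
    payload = MiniCheckSchema().dump({"id": 1, "url": "https://example.org/data.csv"})
    print(payload)  # {'id': 1, 'url': 'https://example.org/data.csv'}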
