Skip to content

Commit 6859d10

Browse files
refactor: use pydantic instead of marshmallow
1 parent 1793a08 commit 6859d10

File tree

12 files changed

+242
-103
lines changed

12 files changed

+242
-103
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
- Simplify getting Sentry info by loading pyproject.toml info in config [#138](https://github.com/datagouv/hydra/pull/138)
5353
- Add a `POST` `/api/checks/` route for force crawling [#118](https://github.com/datagouv/hydra/pull/118)
5454
- Update `csv-detective` to 0.7.2, which no longer includes a yanked version of `requests` [#142](https://github.com/datagouv/hydra/pull/142) and [#144](https://github.com/datagouv/hydra/pull/144)
55+
- Use [Pydantic](https://docs.pydantic.dev/) instead of [Marshmallow](https://marshmallow.readthedocs.io/en/stable/) for API validation [#149](https://github.com/datagouv/hydra/pull/149)
5556

5657
## 1.0.1 (2023-01-04)
5758

poetry.lock

Lines changed: 137 additions & 22 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ coloredlogs = "^15.0.1"
2424
csv-detective = "0.7.2"
2525
dateparser = "^1.1.7"
2626
humanfriendly = "^10.0"
27-
marshmallow = "^3.14.1"
2827
minicli = "^0.5.3"
2928
minio = "7.2.7"
3029
pyarrow = "16.1.0"
30+
pydantic = "^2.8.2"
3131
python-dateutil = "^2.8.2"
3232
python-magic = "^0.4.25"
3333
progressist = "^0.1.0"

udata_hydra/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44

55
import toml
6+
import tomllib
67

78
log = logging.getLogger("udata-hydra")
89

@@ -24,7 +25,8 @@ def configure(self) -> None:
2425
# override with local settings
2526
local_settings = os.environ.get("HYDRA_SETTINGS", Path.cwd() / "config.toml")
2627
if Path(local_settings).exists():
27-
configuration.update(toml.load(local_settings))
28+
with open(Path(local_settings), "rb") as f:
29+
configuration.update(tomllib.load(f))
2830

2931
self.configuration = configuration
3032
self.check()

udata_hydra/analysis/resource.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ async def process_resource(check_id: int, is_first_check: bool) -> None:
4040
4141
Will call udata if first check or changes found, and update check with optional infos
4242
"""
43-
check: dict = await Check.get(check_id)
43+
check: Optional[dict] = await Check.get(check_id)
4444
if not check:
4545
log.error(f"Check not found by id {check_id}")
4646
return

udata_hydra/db/check.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ async def get(cls, check_id: int) -> Optional[dict]:
2020
ON catalog.last_check = checks.id
2121
WHERE checks.id = $1;
2222
"""
23-
return await connection.fetchrow(q, check_id)
23+
record = await connection.fetchrow(q, check_id)
24+
if record:
25+
return dict(record)
26+
return None
2427

2528
@classmethod
2629
async def get_latest(
@@ -36,7 +39,10 @@ async def get_latest(
3639
WHERE checks.id = catalog.last_check
3740
AND catalog.{column} = $1
3841
"""
39-
return await connection.fetchrow(q, url or resource_id)
42+
record = await connection.fetchrow(q, url or resource_id)
43+
if record:
44+
return dict(record)
45+
return None
4046

4147
@classmethod
4248
async def get_all(

udata_hydra/db/resource.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1+
from typing import Optional
2+
13
from udata_hydra import context
24

35

46
class Resource:
57
"""Represents a resource in the "catalog" DB table"""
68

79
@classmethod
8-
async def get(cls, resource_id: str, column_name: str = "*") -> dict:
10+
async def get(cls, resource_id: str, column_name: str = "*") -> Optional[dict]:
911
pool = await context.pool()
1012
async with pool.acquire() as connection:
1113
q = f"""SELECT {column_name} FROM catalog WHERE resource_id = '{resource_id}';"""
12-
resource = await connection.fetchrow(q)
13-
return resource
14+
record = await connection.fetchrow(q)
15+
if record:
16+
return dict(record)
17+
return None
1418

1519
@classmethod
1620
async def insert(

udata_hydra/routes/checks.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import aiohttp
55
from aiohttp import web
6-
from marshmallow import ValidationError
6+
from pydantic import ValidationError
77

88
from udata_hydra import config, context
99
from udata_hydra.crawl import check_url
@@ -16,20 +16,22 @@
1616
async def get_latest_check(request: web.Request) -> web.Response:
1717
"""Get the latest check for a given URL or resource_id"""
1818
url, resource_id = get_request_params(request, params_names=["url", "resource_id"])
19-
data: Optional[dict] = await Check.get_latest(url, resource_id)
20-
if not data:
19+
20+
check: Optional[dict] = await Check.get_latest(url, resource_id)
21+
if not check:
2122
raise web.HTTPNotFound()
22-
if data["deleted"]:
23+
if check["deleted"]:
2324
raise web.HTTPGone()
24-
return web.json_response(CheckSchema().dump(dict(data)))
25+
26+
return web.Response(text=json.dumps(check, default=str), content_type="application/json")
2527

2628

2729
async def get_all_checks(request: web.Request) -> web.Response:
2830
url, resource_id = get_request_params(request, params_names=["url", "resource_id"])
2931
data: Optional[list] = await Check.get_all(url, resource_id)
3032
if not data:
3133
raise web.HTTPNotFound()
32-
return web.json_response([CheckSchema().dump(dict(r)) for r in data])
34+
return web.json_response([r for r in data])
3335

3436

3537
async def create_check(request: web.Request) -> web.Response:
@@ -40,7 +42,7 @@ async def create_check(request: web.Request) -> web.Response:
4042
payload: dict = await request.json()
4143
resource_id: str = payload["resource_id"]
4244
except ValidationError as err:
43-
raise web.HTTPBadRequest(text=json.dumps(err.messages))
45+
raise web.HTTPBadRequest(text=err.json())
4446
except KeyError as e:
4547
raise web.HTTPBadRequest(text=f"Missing key: {e}")
4648

@@ -65,4 +67,4 @@ async def create_check(request: web.Request) -> web.Response:
6567
if not check:
6668
raise web.HTTPBadRequest(text=f"Check not created, status: {status}")
6769

68-
return web.json_response(CheckSchema().dump(dict(check)))
70+
return web.Response(text=json.dumps(check, default=str), content_type="application/json")

udata_hydra/routes/resources.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import json
2+
from typing import Optional
23

34
from aiohttp import web
4-
from marshmallow import ValidationError
5+
from pydantic import ValidationError
56

67
from udata_hydra.db.resource import Resource
78
from udata_hydra.schemas import ResourceSchema
@@ -14,11 +15,11 @@ async def get_resource(request: web.Request) -> web.Response:
1415
If resource is not found, respond with a 404 status code
1516
"""
1617
[resource_id] = get_request_params(request, params_names=["resource_id"])
17-
resource: dict = await Resource.get(resource_id)
18+
resource: Optional[dict] = await Resource.get(resource_id)
1819
if not resource:
1920
raise web.HTTPNotFound()
2021

21-
return web.json_response(ResourceSchema().dump(dict(resource)))
22+
return web.Response(text=json.dumps(resource, default=str), content_type="application/json")
2223

2324

2425
async def create_resource(request: web.Request) -> web.Response:
@@ -29,21 +30,21 @@ async def create_resource(request: web.Request) -> web.Response:
2930
"""
3031
try:
3132
payload = await request.json()
32-
valid_payload: dict = ResourceSchema().load(payload)
33+
valid_payload = ResourceSchema.model_validate(payload)
3334
except ValidationError as err:
34-
raise web.HTTPBadRequest(text=json.dumps(err.messages))
35+
raise web.HTTPBadRequest(text=err.json())
3536

36-
resource: dict = valid_payload["document"]
37-
if not resource:
37+
document = valid_payload.document
38+
if not document:
3839
raise web.HTTPBadRequest(text="Missing document body")
3940

40-
dataset_id = valid_payload["dataset_id"]
41-
resource_id = valid_payload["resource_id"]
41+
dataset_id = valid_payload.dataset_id
42+
resource_id = valid_payload.resource_id
4243

4344
await Resource.insert(
4445
dataset_id=dataset_id,
45-
resource_id=resource_id,
46-
url=resource["url"],
46+
resource_id=str(resource_id),
47+
url=document.url,
4748
priority=True,
4849
)
4950

@@ -58,30 +59,30 @@ async def update_resource(request: web.Request) -> web.Response:
5859
"""
5960
try:
6061
payload = await request.json()
61-
valid_payload: dict = ResourceSchema().load(payload)
62+
valid_payload = ResourceSchema.model_validate(payload)
6263
except ValidationError as err:
63-
raise web.HTTPBadRequest(text=json.dumps(err.messages))
64+
raise web.HTTPBadRequest(text=err.json())
6465

65-
resource: dict = valid_payload["document"]
66-
if not resource:
66+
document = valid_payload.document
67+
if not document:
6768
raise web.HTTPBadRequest(text="Missing document body")
6869

69-
dataset_id: str = valid_payload["dataset_id"]
70-
resource_id: str = valid_payload["resource_id"]
70+
dataset_id: str = valid_payload.dataset_id
71+
resource_id: str = str(valid_payload.resource_id)
7172

72-
await Resource.update_or_insert(dataset_id, resource_id, resource["url"])
73+
await Resource.update_or_insert(dataset_id, resource_id, document.url)
7374

7475
return web.json_response({"message": "updated"})
7576

7677

7778
async def delete_resource(request: web.Request) -> web.Response:
7879
try:
7980
payload = await request.json()
80-
valid_payload: dict = ResourceSchema().load(payload)
81+
valid_payload = ResourceSchema.model_validate(payload)
8182
except ValidationError as err:
82-
raise web.HTTPBadRequest(text=json.dumps(err.messages))
83+
raise web.HTTPBadRequest(text=err.json())
8384

84-
resource_id: str = valid_payload["resource_id"]
85+
resource_id: str = str(valid_payload.resource_id)
8586

8687
pool = request.app["pool"]
8788
async with pool.acquire() as connection:

udata_hydra/schemas/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from .check import CheckSchema
2-
from .resource import ResourceSchema
2+
from .resource import ResourceDocumentSchema, ResourceSchema

0 commit comments

Comments
 (0)