Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions y/_db/utils/stringify.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime, timezone
from decimal import Decimal
from typing import Any, Final, Iterable
from typing import Any, Callable, Final, Iterable


UTC: Final = timezone.utc
Expand All @@ -9,7 +9,7 @@
isoformat: Final = datetime.isoformat


def stringify_column_value(value: Any, provider: str) -> str:
def stringify_column_value(value: Any, format_bytes: Callable[[bytes], str]) -> str:
"""
Convert a Python value to a string representation suitable for SQL insertion.

Expand Down Expand Up @@ -48,12 +48,7 @@ def stringify_column_value(value: Any, provider: str) -> str:
if value is None:
return "null"
elif isinstance(value, bytes):
if provider == "postgres":
# return f"E'\\x{value.hex()}"
return f"'{value.decode()}'::bytea"
elif provider == "sqlite":
return f"X'{value.hex()}'"
raise NotImplementedError(provider)
return format_bytes(value)
elif isinstance(value, str):
return f"'{value}'"
elif isinstance(value, (int, Decimal)):
Expand All @@ -64,7 +59,16 @@ def stringify_column_value(value: Any, provider: str) -> str:
raise NotImplementedError(type(value), value)


def build_row(row: Iterable[Any], provider: str) -> str:
def _format_bytes_sqlite(b: bytes) -> str:
return f"X'{value.hex()}'"


def _format_bytes_postgres(b: bytes) -> str:
# old, but save for later: return f"E'\\x{value.hex()}"
return f"'{value.decode()}'::bytea"


def build_row(row: Iterable[Any], format_bytes: Callable[[bytes], str]) -> str:
"""
Build a SQL row string from an iterable of values.

Expand All @@ -88,13 +92,20 @@ def build_row(row: Iterable[Any], provider: str) -> str:
See Also:
- :func:`stringify_column_value`
"""
return f"({','.join(stringify_column_value(col, provider) for col in row)})"
return f"({','.join(stringify_column_value(col, format_bytes) for col in row)})"


def build_query(
provider_name: str, entity_name: str, columns: Iterable[str], items: Iterable[Any]
) -> str:
data = ",".join(build_row(i, provider_name) for i in items)
if provider_name == "postgres":
format_bytes = _format_bytes_postgres
elif provider_name == "sqlite":
format_bytes = _format_bytes_sqlite
else:
raise NotImplementedError(provider)

data = ",".join(build_row(i, format_bytes) for i in items)
if provider_name == "sqlite":
return f'insert or ignore into {entity_name} ({",".join(columns)}) values {data}'
elif provider_name == "postgres":
Expand Down