Merged: 27 commits (showing changes from all commits)
10 changes: 5 additions & 5 deletions .gitignore
@@ -144,11 +144,11 @@ docs/temp/*
 *.swp
 
 # tool outputs
-df.py
-ssg.py
-orm.yaml
-src-stats.yaml
-config.yaml
+/df.py
+/ssg.py
+/orm.yaml
+/src-stats.yaml
+/config.yaml
 *.yaml.gz
 
 *_stories.py
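
Note: the leading "/" anchors each ignore pattern to the repository root, so only the tool outputs at the top level are ignored and same-named files in subdirectories are tracked again.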
2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,5 +1,5 @@
 FROM python:3.13.3-alpine3.22
-RUN apk add bash poetry
+RUN apk add bash poetry build-base
 WORKDIR /app
 ADD . /app
 RUN mkdir /pypoetry
36 changes: 35 additions & 1 deletion datafaker/create.py
@@ -6,8 +6,9 @@

 from sqlalchemy import Connection, insert, inspect
 from sqlalchemy.exc import IntegrityError
+from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.orm import Session
-from sqlalchemy.schema import CreateSchema, MetaData, Table
+from sqlalchemy.schema import CreateColumn, CreateSchema, CreateTable, MetaData, Table
 
 from datafaker.base import FileUploader, TableGenerator
 from datafaker.settings import get_settings
@@ -24,6 +25,39 @@
 RowCounts = Counter[str]


Review thread on the @compiles hook:
  Reviewer: ooh this is fun
  Author (collaborator): Pretty nasty actually. But yes, fun that this hook exists!

+@compiles(CreateColumn, "duckdb")
+def remove_serial(element: CreateColumn, compiler: Any, **kw: Any) -> str:
+    """
+    Intercede in compilation for column creation, removing PostgreSQL's ``SERIAL``.
+
+    DuckDB does not understand ``SERIAL``, and we don't care about
+    autoincrementing in datafaker. Ideally ``duckdb_engine`` would remove
+    this for us, or DuckDB would implement ``SERIAL``.
+
+    :param element: The CreateColumn being executed.
+    :param compiler: Actually a DDLCompiler, but that type is not exported.
+    :param kw: Further arguments.
+    :return: Corrected SQL.
+    """
+    text: str = compiler.visit_create_column(element, **kw)
+    return text.replace(" SERIAL ", " INTEGER ")


+@compiles(CreateTable, "duckdb")
+def remove_on_delete_cascade(element: CreateTable, compiler: Any, **kw: Any) -> str:
+    """
+    Intercede in compilation for table creation, removing ``ON DELETE CASCADE``.
+
+    DuckDB does not understand cascades, and we don't care about
+    that in datafaker. Ideally ``duckdb_engine`` would remove this for us.
+
+    :param element: The CreateTable being executed.
+    :param compiler: Actually a DDLCompiler, but that type is not exported.
+    :param kw: Further arguments.
+    :return: Corrected SQL.
+    """
+    text: str = compiler.visit_create_table(element, **kw)
+    return text.replace(" ON DELETE CASCADE", "")


 def create_db_tables(metadata: MetaData) -> None:
     """Create tables described by the sqlalchemy metadata object."""
     settings = get_settings()
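
To see these hooks fire, here is a minimal sketch (not part of the diff; it assumes duckdb_engine is installed, whose PostgreSQL-derived dialect would otherwise emit SERIAL for an autoincrementing integer primary key):

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine
    from sqlalchemy.schema import CreateTable

    metadata = MetaData()
    t = Table("example", metadata, Column("id", Integer, primary_key=True))

    # Compiling DDL for the duckdb dialect routes through the hooks above,
    # so SERIAL (and any ON DELETE CASCADE) never reaches DuckDB.
    engine = create_engine("duckdb:///:memory:")
    print(CreateTable(t).compile(dialect=engine.dialect))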
206 changes: 177 additions & 29 deletions datafaker/dump.py
@@ -1,38 +1,186 @@
"""Data dumping functions."""
import csv
import io
from typing import TYPE_CHECKING
from abc import ABC, abstractmethod
from pathlib import Path

import pandas as pd
import sqlalchemy
from sqlalchemy.schema import MetaData

from datafaker.utils import create_db_engine, get_sync_engine, logger

if TYPE_CHECKING:
from _csv import Writer


def _make_csv_writer(file: io.TextIOBase) -> "Writer":
"""Make the standard CSV file writer."""
return csv.writer(file, quoting=csv.QUOTE_MINIMAL)


def dump_db_tables(
metadata: MetaData,
dsn: str,
schema: str | None,
table_name: str,
file: io.TextIOBase,
) -> None:
"""Output the table as CSV."""
if table_name not in metadata.tables:
logger.error("%s is not a table described in the ORM file", table_name)
return
table = metadata.tables[table_name]
csv_out = _make_csv_writer(file)
csv_out.writerow(table.columns.keys())
engine = get_sync_engine(create_db_engine(dsn, schema_name=schema))
with engine.connect() as connection:
result = connection.execute(sqlalchemy.select(table))
for row in result:
csv_out.writerow(row)

+class TableWriter(ABC):
+    """Writes a table out to a file."""
+
+    EXTENSION = ".csv"
+
+    def __init__(self, metadata: MetaData, dsn: str, schema: str | None) -> None:
+        """
+        Initialize the TableWriter.
+
+        :param metadata: The metadata for our database.
+        :param dsn: The connection string for our database.
+        :param schema: The schema name for our database, or None for the default.
+        """
+        self._metadata = metadata
+        self._dsn = dsn
+        self._schema = schema
+
+    def connect(self) -> sqlalchemy.engine.Connection:
+        """Connect to the database."""
+        engine = get_sync_engine(create_db_engine(self._dsn, schema_name=self._schema))
+        return engine.connect()
+
+    @abstractmethod
+    def write_file(self, table: sqlalchemy.Table, filepath: Path) -> bool:
+        """
+        Write the given table into the given file.
+
+        :param table: The table to output.
+        :param filepath: The path of the file to write into.
+        :return: ``True`` on success, otherwise ``False``.
+        """
+
+    def write(self, table: sqlalchemy.Table, directory: Path) -> bool:
+        """
+        Write the table into a directory with a filename based on the table's name.
+
+        :param table: The table to write out.
+        :param directory: The directory to write the table into.
+        :return: ``True`` on success, otherwise ``False``.
+        """
+        tn = table.name
+        # DuckDB tables derived from files have confusing suffixes
+        # that we should probably remove.
+        tn = tn.removesuffix(".csv")
+        tn = tn.removesuffix(".parquet")
+        return self.write_file(table, directory / f"{tn}{self.EXTENSION}")


+class ParquetTableWriter(TableWriter):
+    """Writes the table to a Parquet file."""
+
+    EXTENSION = ".parquet"
+
+    def write_file(self, table: sqlalchemy.Table, filepath: Path) -> bool:
+        """
+        Write the given table into the given file.
+
+        :param table: The table to output.
+        :param filepath: The path of the file to write into.
+        :return: ``True`` on success, otherwise ``False``.
+        """
+        with self.connect() as connection:
+            dates = [
+                str(name)
+                for name, col in table.columns.items()
+                if isinstance(
+                    col.type,
+                    (
+                        sqlalchemy.types.DATE,
+                        sqlalchemy.types.DATETIME,
+                        sqlalchemy.types.TIMESTAMP,
+                    ),
+                )
+            ]
+            df = pd.read_sql(
+                sql=f"SELECT * FROM {table.name}",
+                con=connection,
+                columns=[str(col.name) for col in table.columns.values()],
+                parse_dates=dates,
+            )
+            df.to_parquet(filepath)
+        return True


+class DuckDbParquetTableWriter(ParquetTableWriter):
+    """
+    Writes the table to a Parquet file using DuckDB SQL.
+
+    The Pandas method used by ParquetTableWriter currently
+    does not work with DuckDB.
+    """
+
+    def write_file(self, table: sqlalchemy.Table, filepath: Path) -> bool:
+        """
+        Write the given table into the given file.
+
+        :param table: The table to output.
+        :param filepath: The path of the file to write into.
+        :return: ``True`` on success, otherwise ``False``.
+        """
+        with self.connect() as connection:
+            result = connection.execute(
+                sqlalchemy.text(
+                    # We need the double quotes to get DuckDB to read the table, not the file.
+                    f"COPY \"{table.name}\" TO '{filepath}' (FORMAT PARQUET)"
+                )
+            )
+        return result is not None


+def get_parquet_table_writer(
+    metadata: MetaData, dsn: str, schema: str | None
+) -> TableWriter:
+    """
+    Get a ``TableWriter`` that writes parquet files.
+
+    :param metadata: The database metadata containing the tables to be dumped to files.
+    :param dsn: The database connection string.
+    :param schema: The schema name, if required.
+    :return: ``TableWriter`` to write a parquet file.
+    """
+    if dsn.startswith("duckdb:"):
+        return DuckDbParquetTableWriter(metadata, dsn, schema)
+    return ParquetTableWriter(metadata, dsn, schema)
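
A hypothetical usage sketch (the metadata object, DSN, and out/ directory are illustrative assumptions, not part of this diff):

    from pathlib import Path

    writer = get_parquet_table_writer(metadata, "duckdb:///example.db", schema=None)
    for table in metadata.tables.values():
        if not writer.write(table, Path("out")):  # writes out/<table>.parquet
            logger.warning("failed to dump %s", table.name)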


+class TableWriterIO(TableWriter):
+    """Writes the table to an output object."""
+
+    @abstractmethod
+    def write_io(self, table: sqlalchemy.Table, out: io.TextIOBase) -> bool:
+        """
+        Write the given table to the given stream.
+
+        :param table: The table to output.
+        :param out: The text stream to write to.
+        :return: ``True`` on success, otherwise ``False``.
+        """
+
+    def write_file(self, table: sqlalchemy.Table, filepath: Path) -> bool:
+        """
+        Write the given table into the given file.
+
+        :param table: The table to output.
+        :param filepath: The path of the file to write into.
+        :return: ``True`` on success, otherwise ``False``.
+        """
+        with open(filepath, "wt", newline="", encoding="utf-8") as out:
+            return self.write_io(table, out)


+class CsvTableWriter(TableWriterIO):
+    """Writes the table to a CSV file."""
+
+    def write_io(self, table: sqlalchemy.Table, out: io.TextIOBase) -> bool:
+        """
+        Write the given table to the given stream.
+
+        :param table: The table to output.
+        :param out: The text stream to write to.
+        :return: ``True`` on success, otherwise ``False``.
+        """
+        if table.name not in self._metadata.tables:
+            logger.error("%s is not a table described in the ORM file", table.name)
+            return False
+        table = self._metadata.tables[table.name]
+        csv_out = csv.writer(out, quoting=csv.QUOTE_MINIMAL)
+        csv_out.writerow(table.columns.keys())
+        with self.connect() as connection:
+            result = connection.execute(sqlalchemy.select(table))
+            for row in result:
+                csv_out.writerow(row)
+        return True
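
A matching sketch for the stream-based variant, again hypothetical, dumping one table as CSV to stdout:

    import sys

    writer = CsvTableWriter(metadata, "duckdb:///example.db", schema=None)
    writer.write_io(metadata.tables["person"], sys.stdout)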
17 changes: 13 additions & 4 deletions datafaker/generators/base.py
@@ -291,9 +291,18 @@ def __init__(
         )
         self.buckets: Sequence[int] = [0] * 10
         for rb in raw_buckets:
-            if rb.b is not None:
-                bucket = min(9, max(0, int(rb.b) + 1))
-                self.buckets[bucket] += rb.f / count
+            try:
+                x = float(rb.b)
+                if x.is_integer():
+                    bucket = min(9, max(0, int(x) + 1))
+                    self.buckets[bucket] += rb.f / count
+            except TypeError:
+                # We get a type error if there are no rows returned at all
+                # because rb.b is None in this case.
+                # We could just test for None explicitly, but this way
+                # catches errors if SQLAlchemy returns something that
+                # isn't a number for some other unknown reason.
+                pass

Comment on lines 299 to 305:
  Reviewer: Ooh when are we expecting this to happen, and if so do we want to log it?

         self.mean = mean
         self.stddev = stddev
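
As a worked illustration of the bucket clamping above (hypothetical (rb.b, rb.f) rows; count is the total row count):

    count = 8
    raw = [(-3.0, 1), (0.0, 2), (4.0, 4), (12.0, 1)]  # hypothetical (rb.b, rb.f) pairs
    buckets = [0.0] * 10
    for b, f in raw:
        x = float(b)
        if x.is_integer():
            buckets[min(9, max(0, int(x) + 1))] += f / count
    print(buckets)  # [0.125, 0.25, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.125]

Out-of-range indices fold into the first and last slots, so the ten buckets always sum to the fraction of rows whose bucket value is an integer.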

@@ -406,7 +415,7 @@ class ConstantGeneratorFactory(GeneratorFactory):
     """Just the null generator."""
 
     def get_generators(
-        self, columns: list[Column], engine: Engine
+        self, columns: list[Column], _engine: Engine
     ) -> Sequence[Generator]:
         """Get the generators appropriate for these columns."""
         if len(columns) != 1:
32 changes: 20 additions & 12 deletions datafaker/generators/mimesis.py
@@ -292,6 +292,17 @@ class MimesisStringGeneratorFactory(GeneratorFactory):
"text.word",
]

def _get_generators_with(
self, gen_class: Callable, **kwargs: Any
) -> list[Generator]:
gens: list[Generator] = []
for name in self.GENERATOR_NAMES:
try:
gens.append(gen_class(name, **kwargs))
except: # pylint: disable=bare-except
pass
return gens

     def get_generators(
         self, columns: list[Column], engine: Engine
     ) -> Sequence[Generator]:
@@ -317,19 +328,16 @@ def get_generators(
             fitness_fn = None
         length = column_type.length
         if length:
-            return list(
-                map(
-                    lambda gen: MimesisGeneratorTruncated(
-                        gen, length, fitness_fn, buckets
-                    ),
-                    self.GENERATOR_NAMES,
-                )
-            )
-        return list(
-            map(
-                lambda gen: MimesisGenerator(gen, fitness_fn, buckets),
-                self.GENERATOR_NAMES,
-            )
-        )
+            return self._get_generators_with(
+                MimesisGeneratorTruncated,
+                length=length,
+                value_fn=fitness_fn,
+                buckets=buckets,
+            )
+        return self._get_generators_with(
+            MimesisGenerator,
+            value_fn=fitness_fn,
+            buckets=buckets,
+        )
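
The try/except in _get_generators_with silently drops generator names that the class cannot build. A self-contained illustration of that skip-on-failure pattern (FakeGen and the names are hypothetical stand-ins):

    class FakeGen:
        def __init__(self, name: str) -> None:
            if "bad" in name:  # stand-in for a constructor failure
                raise ValueError(name)
            self.name = name

    names = ["person.name", "person.bad_method", "text.word"]
    gens = []
    for name in names:
        try:
            gens.append(FakeGen(name))
        except:  # bare except, mirroring the code above
            pass
    print([g.name for g in gens])  # ['person.name', 'text.word']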
6 changes: 3 additions & 3 deletions datafaker/interactive/base.py
@@ -295,7 +295,7 @@ def _remove_prefix_src_stats(self, prefix: str) -> list[MutableMapping[str, Any]
         self.config["src-stats"] = new_src_stats
         return new_src_stats
 
-    def get_nonnull_columns(self, table_name: str) -> list[str]:
+    def get_nullable_columns(self, table_name: str) -> list[str]:
         """Get the names of the nullable columns in the named table."""
         metadata_table = self.metadata.tables[table_name]
         return [
@@ -327,8 +327,8 @@ def do_counts(self, _arg: str) -> None:
         if len(self._table_entries) <= self.table_index:
             return
         table_name = self.table_name()
-        nonnull_columns = self.get_nonnull_columns(table_name)
-        colcounts = [f', COUNT("{nnc}") AS "{nnc}"' for nnc in nonnull_columns]
+        nullable_columns = self.get_nullable_columns(table_name)
+        colcounts = [f', COUNT("{nnc}") AS "{nnc}"' for nnc in nullable_columns]
         with self.sync_engine.connect() as connection:
             result = (
                 connection.execute(
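
For illustration, here is the query this builds for a hypothetical table person with nullable columns age and email (COUNT(col) skips NULLs, so each aliased count reveals how many rows are non-null):

    nullable_columns = ["age", "email"]  # hypothetical
    colcounts = [f', COUNT("{nnc}") AS "{nnc}"' for nnc in nullable_columns]
    print(f"SELECT COUNT(*){''.join(colcounts)} FROM person")
    # SELECT COUNT(*), COUNT("age") AS "age", COUNT("email") AS "email" FROM person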