Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions back/scripts/workflow/data_warehouse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path

import pandas as pd
import polars as pl
from sqlalchemy import text

Expand All @@ -19,6 +20,7 @@ def __init__(self, config: dict):
self.config = config
self.warehouse_folder = Path(self.config["warehouse"]["data_folder"])
self.warehouse_folder.mkdir(exist_ok=True, parents=True)
self.chunksize = 10000

self.send_to_db = {
"collectivites": CommunitiesEnricher.get_output_path(config),
Expand Down Expand Up @@ -50,7 +52,7 @@ def _send_to_postgres(self):
# or keep the same schema.
if_table_exists = "replace" if self.config["workflow"]["replace_tables"] else "append"

with connector.engine.connect() as conn:
with connector.engine.begin() as conn:
for table_name, filename in self.send_to_db.items():
df = pl.read_parquet(filename)

Expand All @@ -59,7 +61,50 @@ def _send_to_postgres(self):
f"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name= '{table_name}')"
)
table_exists = conn.execute(table_exists_query).scalar()

if table_exists:
conn.execute(text(f"TRUNCATE {table_name}"))
self.add_missing_columns_to_sql_table(conn, table_name, df)
df.to_pandas().to_sql(
table_name, conn, if_exists=if_table_exists, chunksize=self.chunksize
)

@staticmethod
def add_missing_columns_to_sql_table(conn, table_name: str, df: pl.DataFrame):
    """Add to the SQL table any column present in the Polars DataFrame but missing from it.

    Compares the DataFrame schema against ``information_schema.columns`` for
    *table_name* and issues a single ``ALTER TABLE ... ADD COLUMN ...`` for all
    missing columns. Polars dtypes with no mapping fall back to TEXT.

    Args:
        conn: An open SQLAlchemy connection. The caller owns the transaction
            (this file drives it through ``engine.begin()``), so no explicit
            commit is issued here.
        table_name: Name of the target table. NOTE(review): the
            information_schema lookup uses the name unquoted while the ALTER
            statement double-quotes it — assumes lowercase, unquoted table
            names; confirm against how tables are created.
        df: DataFrame whose schema defines the expected columns.
    """
    schema = df.schema

    # Bound parameter instead of f-string interpolation: table_name reaches
    # this query, keep it out of the SQL text.
    columns_sql = conn.execute(
        text(
            "SELECT column_name FROM information_schema.columns "
            "WHERE table_name = :table_name"
        ),
        {"table_name": table_name},
    ).fetchall()
    existing_cols = {row[0] for row in columns_sql}

    missing_cols = schema.keys() - existing_cols
    if not missing_cols:
        return

    # Mapping Polars dtype -> PostgreSQL column type.
    # NOTE(review): Historisateur.convert_types_for_sql appears to hold the
    # inverse mapping; consider extracting this dict as a shared constant.
    type_mapping = {
        pl.Int64: "BIGINT",
        pl.Int32: "INTEGER",
        pl.Float64: "DOUBLE PRECISION",
        pl.Float32: "REAL",
        pl.Boolean: "BOOLEAN",
        pl.Utf8: "TEXT",
        pl.Date: "DATE",
        pl.Datetime: "TIMESTAMP",
    }

    # Early return above guarantees missing_cols is non-empty here; build one
    # ALTER TABLE covering all missing columns at once.
    add_columns = [
        f'ADD COLUMN "{col}" {type_mapping.get(schema[col], "TEXT")}'
        for col in missing_cols
    ]
    alter_query = f'ALTER TABLE "{table_name}" {", ".join(add_columns)};'
    conn.execute(text(alter_query))