diff --git a/back/scripts/workflow/data_warehouse.py b/back/scripts/workflow/data_warehouse.py
index 6aff0d159..04f0a6762 100644
--- a/back/scripts/workflow/data_warehouse.py
+++ b/back/scripts/workflow/data_warehouse.py
@@ -19,6 +19,7 @@ def __init__(self, config: dict):
         self.config = config
         self.warehouse_folder = Path(self.config["warehouse"]["data_folder"])
         self.warehouse_folder.mkdir(exist_ok=True, parents=True)
+        self.chunksize = 10000
 
         self.send_to_db = {
             "collectivites": CommunitiesEnricher.get_output_path(config),
@@ -50,7 +51,7 @@ def _send_to_postgres(self):
         # or keep the same schema.
         if_table_exists = "replace" if self.config["workflow"]["replace_tables"] else "append"
 
-        with connector.engine.connect() as conn:
+        with connector.engine.begin() as conn:
             for table_name, filename in self.send_to_db.items():
                 df = pl.read_parquet(filename)
 
@@ -59,7 +60,57 @@ def _send_to_postgres(self):
                     f"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name= '{table_name}')"
                 )
                 table_exists = conn.execute(table_exists_query).scalar()
+
                 if table_exists:
                     conn.execute(text(f"TRUNCATE {table_name}"))
+                    self.add_missing_columns_to_sql_table(conn, table_name, df)
+                # index=False keeps parity with the old write_database call,
+                # which never wrote the frame index as a column.
+                df.to_pandas().to_sql(
+                    table_name,
+                    conn,
+                    if_exists=if_table_exists,
+                    index=False,
+                    chunksize=self.chunksize,
+                )
-                df.write_database(table_name, conn, if_table_exists=if_table_exists)
+
+    @staticmethod
+    def add_missing_columns_to_sql_table(conn, table_name: str, df: pl.DataFrame) -> None:
+        """Add columns present in ``df`` but missing from the SQL table.
+
+        Reads existing columns from information_schema and issues a single
+        ALTER TABLE; Polars dtypes without a mapping fall back to TEXT.
+        """
+        columns_sql = conn.execute(
+            text(
+                "SELECT column_name FROM information_schema.columns"
+                " WHERE table_name = :table_name"
+            ),
+            {"table_name": table_name},
+        ).fetchall()
+        existing_cols = {row[0] for row in columns_sql}
+
+        missing_cols = df.schema.keys() - existing_cols
+        if not missing_cols:
+            return
+
+        # Polars dtype -> PostgreSQL column type.
+        type_mapping = {
+            pl.Int64: "BIGINT",
+            pl.Int32: "INTEGER",
+            pl.Float64: "DOUBLE PRECISION",
+            pl.Float32: "REAL",
+            pl.Boolean: "BOOLEAN",
+            pl.Utf8: "TEXT",
+            pl.Date: "DATE",
+            pl.Datetime: "TIMESTAMP",
+        }
+
+        add_columns = ", ".join(
+            f'ADD COLUMN "{col}" {type_mapping.get(df.schema[col], "TEXT")}'
+            for col in missing_cols
+        )
+        conn.execute(text(f'ALTER TABLE "{table_name}" {add_columns};'))
+        # No explicit commit here: the caller's engine.begin() block
+        # commits (or rolls back) the whole transaction on exit.
 