
Commit 5b4f8c4

fix: send correct username to consumer trino (#6629)
* fix: send correct username to consumer trino
* Apply suggestion from @ravenac95
1 parent f6bd65c commit 5b4f8c4

File tree

9 files changed, +33 -12 lines changed


warehouse/scheduler/scheduler/dlt_destination.py

Lines changed: 6 additions & 4 deletions

@@ -14,7 +14,9 @@ class DLTDestinationResource(abc.ABC):
 
     @abc.abstractmethod
     @asynccontextmanager
-    def get_destination(self, *, dataset_schema: str) -> t.AsyncIterator[Destination]:
+    def get_destination(
+        self, *, dataset_schema: str, user: str | None = None
+    ) -> t.AsyncIterator[Destination]:
         """Get a DLT destination configured for the given dataset schema."""
         ...
 
@@ -27,7 +29,7 @@ def __init__(self, *, database_path: str) -> None:
 
     @asynccontextmanager
     async def get_destination(
-        self, *, dataset_schema: str
+        self, *, dataset_schema: str, user: str | None = None
     ) -> t.AsyncIterator[Destination]:
         from dlt.destinations import duckdb
         from dlt.destinations.impl.duckdb.configuration import DuckDbCredentials
@@ -49,7 +51,7 @@ def __init__(self, *, trino: TrinoResource, catalog: str) -> None:
 
     @asynccontextmanager
     async def get_destination(
-        self, *, dataset_schema: str
+        self, *, dataset_schema: str, user: str | None = None
     ) -> t.AsyncIterator[Destination]:
         from dlt.destinations import sqlalchemy
 
@@ -58,7 +60,7 @@ async def get_destination(
         await cursor.execute(
             f'CREATE SCHEMA IF NOT EXISTS "{self._catalog}"."{dataset_schema}"'
         )
-        user = conn.user or "scheduler"
+        user = user or conn.user or "scheduler"
         credentials = f"trinoso://{user}@{conn.host}:{conn.port}/{self._catalog}/{dataset_schema}"
         yield sqlalchemy(
             credentials=credentials,
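For context, the heart of the fix is the new fallback order in the Trino destination: an explicitly passed user takes precedence over conn.user, with "scheduler" as the last resort. A minimal runnable sketch of that chain; FakeConn and build_credentials are invented for illustration, and only the precedence expression and the credentials format come from the diff above:

class FakeConn:
    """Hypothetical stand-in for the Trino connection object; illustration only."""

    def __init__(self, user: str | None, host: str = "trino.local", port: int = 8080) -> None:
        self.user = user
        self.host = host
        self.port = port


def build_credentials(user: str | None, conn: FakeConn, catalog: str, schema: str) -> str:
    # Same precedence as the diff: explicit user, then conn.user, then "scheduler".
    effective = user or conn.user or "scheduler"
    return f"trinoso://{effective}@{conn.host}:{conn.port}/{catalog}/{schema}"


# The caller-supplied user wins; absent that, conn.user; absent both, "scheduler".
assert build_credentials("rw-myorg-ab12cd", FakeConn("svc"), "c", "s").startswith("trinoso://rw-myorg-ab12cd@")
assert build_credentials(None, FakeConn("svc"), "c", "s").startswith("trinoso://svc@")
assert build_credentials(None, FakeConn(None), "c", "s").startswith("trinoso://scheduler@")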

warehouse/scheduler/scheduler/mq/handlers/data_ingestion.py

Lines changed: 5 additions & 0 deletions

@@ -16,6 +16,7 @@
     HandlerResponse,
     RunContext,
 )
+from scheduler.utils import get_warehouse_user
 
 
 class DataIngestionRunRequestHandler(RunHandler[DataIngestionRunRequest]):
@@ -106,6 +107,9 @@ async def handle_run_message(
                 message=f"Unsupported type: {config.factory_type}",
             )
 
+        warehouse_user = get_warehouse_user(
+            user_type="rw", org_id=org_id, org_name=context.organization.name
+        )
         async with context.step_context(
             name="execute_data_ingestion_pipeline",
             display_name="Execute Data Ingestion Pipeline",
@@ -116,6 +120,7 @@
                 config=config.config,
                 dataset_id=dataset_id,
                 org_id=org_id,
+                destination_user=warehouse_user,
                 dlt_destination=dlt_destination,
                 common_settings=common_settings,
             )

warehouse/scheduler/scheduler/mq/handlers/data_model.py

Lines changed: 2 additions & 2 deletions

@@ -25,7 +25,7 @@
     StepContext,
     SuccessResponse,
 )
-from scheduler.utils import OSOClientTableResolver, ctas_query, get_trino_user
+from scheduler.utils import OSOClientTableResolver, ctas_query, get_warehouse_user
 from sqlglot import exp
 from sqlmesh import EngineAdapter
 
@@ -125,7 +125,7 @@ async def handle_run_message(
         dataset_name = dataset.node.name
         org_name = dataset.node.organization.name
 
-        user = get_trino_user("rw", context.organization.id, org_name)
+        user = get_warehouse_user("rw", context.organization.id, org_name)
 
         assert isinstance(
             data_model_def,

warehouse/scheduler/scheduler/mq/handlers/ingestion/archive.py

Lines changed: 3 additions & 1 deletion

@@ -53,6 +53,7 @@ async def execute(
         config: dict[str, object],
         dataset_id: str,
         org_id: str,
+        destination_user: str,
         dlt_destination: DLTDestinationResource,
         common_settings: CommonSettings,
     ) -> HandlerResponse:
@@ -129,7 +130,8 @@ async def execute(
 
         try:
            async with dlt_destination.get_destination(
-                dataset_schema=dataset_schema
+                dataset_schema=dataset_schema,
+                user=destination_user,
            ) as destination:
                dlt_pipeline = pipeline(
                    pipeline_name=pipeline_name,

warehouse/scheduler/scheduler/mq/handlers/ingestion/base.py

Lines changed: 1 addition & 0 deletions

@@ -19,6 +19,7 @@ async def execute(
         config: dict[str, object],
         dataset_id: str,
         org_id: str,
+        destination_user: str,
         dlt_destination: DLTDestinationResource,
         common_settings: CommonSettings,
     ) -> HandlerResponse:
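Because the base execute() signature gained destination_user, every ingestion factory must accept and forward it, as archive.py above and rest.py below do. A skeletal conforming implementation; the class name and body are hypothetical, and only the parameter list and the user= pass-through come from these diffs:

class NoopIngestionExecutor:
    """Hypothetical implementer of the updated execute() interface; illustration only."""

    async def execute(
        self,
        *,
        config: dict[str, object],
        dataset_id: str,
        org_id: str,
        destination_user: str,
        dlt_destination: "DLTDestinationResource",
        common_settings: "CommonSettings",
    ) -> "HandlerResponse":
        # Forward the handler-derived user so the destination credentials are
        # built with the organization's warehouse identity, not the scheduler's.
        async with dlt_destination.get_destination(
            dataset_schema=f"schema_{dataset_id.replace('-', '')}",
            user=destination_user,
        ) as destination:
            ...  # run a dlt pipeline against `destination`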

warehouse/scheduler/scheduler/mq/handlers/ingestion/rest.py

Lines changed: 3 additions & 1 deletion

@@ -33,6 +33,7 @@ async def execute(
         config: dict[str, object],
         dataset_id: str,
         org_id: str,
+        destination_user: str,
         dlt_destination: DLTDestinationResource,
         common_settings: CommonSettings,
     ) -> HandlerResponse:
@@ -83,7 +84,8 @@ async def execute(
         dataset_schema = placeholder_target_table.db
         pipeline_name = f"{org_id}_{dataset_id}".replace("-", "")[:50]
         async with dlt_destination.get_destination(
-            dataset_schema=dataset_schema
+            dataset_schema=dataset_schema,
+            user=destination_user,
         ) as destination:
             p = pipeline(
                 pipeline_name=pipeline_name,

warehouse/scheduler/scheduler/mq/handlers/static_model.py

Lines changed: 8 additions & 1 deletion

@@ -16,7 +16,7 @@
     SuccessResponse,
     TableReference,
 )
-from scheduler.utils import dlt_to_oso_schema
+from scheduler.utils import dlt_to_oso_schema, get_warehouse_user
 
 if t.TYPE_CHECKING:
     from scheduler.config import CommonSettings
@@ -99,8 +99,15 @@ async def handle_run_message(
             },
         )
 
+        warehouse_user = get_warehouse_user(
+            user_type="rw",
+            org_id=org_id,
+            org_name=context.organization.name,
+        )
+
         async with dlt_destination.get_destination(
             dataset_schema=schema_name,
+            user=warehouse_user,
         ) as dlt_destination_instance:
             for model_id in message.model_ids:
                 async with context.step_context(

warehouse/scheduler/scheduler/mq/handlers/sync_connection.py

Lines changed: 2 additions & 2 deletions

@@ -18,7 +18,7 @@
     RunContext,
     SuccessResponse,
 )
-from scheduler.utils import get_trino_user
+from scheduler.utils import get_warehouse_user
 
 EXCLUDED_SCHEMAS = {"information_schema"}
 MAX_CONCURRENT_QUERIES = 5  # Number of concurrent queries to run
@@ -139,7 +139,7 @@ async def handle_run_message(
             extra={"catalog_name": catalog_name, "org_id": organization.id},
         )
 
-        user = get_trino_user("ro", organization.id, organization.name)
+        user = get_warehouse_user("ro", organization.id, organization.name)
 
         # Use a single Trino client for all queries
         async with consumer_trino.async_get_client(user=user) as client:

warehouse/scheduler/scheduler/utils.py

Lines changed: 3 additions & 1 deletion

@@ -18,7 +18,9 @@ def convert_uuid_bytes_to_str(uuid_bytes: bytes) -> str:
     return str(uuid.UUID(bytes=uuid_bytes))
 
 
-def get_trino_user(user_type: Literal["rw", "ro"], org_id: str, org_name: str) -> str:
+def get_warehouse_user(
+    user_type: Literal["rw", "ro"], org_id: str, org_name: str
+) -> str:
     return f"{user_type}-{org_name.strip().lower()}-{org_id.replace('-', '').lower()}"
 
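The rename from get_trino_user is cosmetic: the f-string body is unchanged, so the generated usernames can be checked directly. A quick sanity check; the org values below are invented examples:

from typing import Literal


def get_warehouse_user(
    user_type: Literal["rw", "ro"], org_id: str, org_name: str
) -> str:
    return f"{user_type}-{org_name.strip().lower()}-{org_id.replace('-', '').lower()}"


# Invented example values; the format is <type>-<org name>-<org id without dashes>.
assert get_warehouse_user("rw", "AB-12-CD", " MyOrg ") == "rw-myorg-ab12cd"
assert get_warehouse_user("ro", "ab12cd", "myorg") == "ro-myorg-ab12cd"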