Skip to content

Commit 05ca5f3

Browse files
authored
Yev/text to sql v1.1 (#653)
minor update to simplify code and use xml tags --------- Signed-off-by: Yev Meyer <ymeyer@nvidia.com>
1 parent 24c10a0 commit 05ca5f3

File tree

4 files changed

+15
-34
lines changed

4 files changed

+15
-34
lines changed

resources_servers/text_to_sql/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ Each data sample should include:
4141
},
4242
{
4343
"role": "user",
44-
"content": "DIALECT: postgresql\n\nDATABASE CONTEXT:\nCREATE TABLE users (id SERIAL PRIMARY KEY, name VARCHAR(100));\nINSERT INTO users VALUES (1, 'Alice'), (2, 'Bob');\n\nQUESTION:\nList all user names ordered alphabetically"
44+
"content": "<DIALECT>postgresql</DIALECT>\n\n<DATABASE_CONTEXT>\nCREATE TABLE users (id SERIAL PRIMARY KEY, name VARCHAR(100));\nINSERT INTO users VALUES (1, 'Alice'), (2, 'Bob');\n</DATABASE_CONTEXT>\n\n<QUESTION>\nList all user names ordered alphabetically\n</QUESTION>"
4545
}
4646
]
4747
},

resources_servers/text_to_sql/app.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -97,25 +97,6 @@ def extract_sql_from_response(text: str) -> Optional[str]:
9797
return None
9898

9999

100-
def _extract_question_text(params: NeMoGymResponseCreateParamsNonStreaming) -> str:
101-
"""Extract the question text from the last user message."""
102-
last_text: Optional[str] = None
103-
for m in params.input or []:
104-
if getattr(m, "role", None) == "user":
105-
c = getattr(m, "content", None)
106-
if isinstance(c, str):
107-
last_text = c
108-
return (last_text or "").strip()
109-
110-
111-
def _extract_dialect_from_prompt(text: str) -> Optional[str]:
112-
"""Extract SQL dialect from a structured prompt."""
113-
if not text:
114-
return None
115-
match = re.search(r"^\s*DIALECT:\s*([a-zA-Z0-9_+-]+)\s*$", text, re.MULTILINE)
116-
return match.group(1) if match else None
117-
118-
119100
def _normalize_dialect(dialect: Optional[str]) -> Optional[str]:
120101
if not dialect:
121102
return None
@@ -171,7 +152,7 @@ class TextToSqlRunRequest(BaseRunRequest):
171152
sql: str # Ground truth SQL query (required)
172153
sql_dialect: str # SQL dialect: mysql, postgresql, sqlite (required)
173154
sql_context: str = "" # Database schema (CREATE/INSERT statements)
174-
sql_prompt: Optional[str] = None # Natural language question (optional, extracted from input if not provided)
155+
sql_prompt: str # Natural language question (required)
175156
metadata: Optional[dict[str, Any]] = None
176157

177158

@@ -196,7 +177,7 @@ class TextToSqlVerifyResponse(BaseVerifyResponse):
196177
extracted_sql: Optional[str] = None
197178
sql_dialect: str # SQL dialect used
198179
sql_context: str # Database schema provided
199-
sql_prompt: Optional[str] = None # May be extracted from input
180+
sql_prompt: str # Natural language question
200181
judge_passed: bool = False
201182
failure_reason: Optional[FailureCode] = None
202183
judge_evaluations: list[JudgeEvaluation] = []
@@ -235,8 +216,8 @@ async def verify(self, body: TextToSqlVerifyRequest) -> TextToSqlVerifyResponse:
235216
if sql_dialect not in SUPPORTED_DIALECTS:
236217
raise ValueError(f"Unsupported SQL dialect '{sql_dialect}'. Supported: {sorted(SUPPORTED_DIALECTS)}")
237218

238-
# Extract question from request field or from user message
239-
sql_prompt = body.sql_prompt or _extract_question_text(body.responses_create_params)
219+
# sql_prompt is a required field, validated by Pydantic
220+
sql_prompt = body.sql_prompt
240221

241222
# Get model output text directly from response
242223
generated = body.response.output_text or ""

0 commit comments

Comments
 (0)