Skip to content

Commit 4105743

Browse files
Support multi-statement SQL queries for db_execute_query tool (#123)
1 parent 8448e5b commit 4105743

File tree

2 files changed

+125
-26
lines changed

2 files changed

+125
-26
lines changed

internal/tiger/mcp/db_tools.go

Lines changed: 118 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -63,22 +63,22 @@ type DBExecuteQueryColumn struct {
6363
// DBExecuteQueryOutput represents output for db_execute_query
6464
type DBExecuteQueryOutput struct {
6565
Columns []DBExecuteQueryColumn `json:"columns,omitempty"`
66-
Rows [][]any `json:"rows,omitempty"`
66+
Rows *[][]any `json:"rows,omitempty"`
6767
RowsAffected int64 `json:"rows_affected"`
6868
ExecutionTime string `json:"execution_time"`
6969
}
7070

7171
func (DBExecuteQueryOutput) Schema() *jsonschema.Schema {
7272
schema := util.Must(jsonschema.For[DBExecuteQueryOutput](nil))
7373

74-
schema.Properties["columns"].Description = "Column metadata from the query result including name and PostgreSQL type"
74+
schema.Properties["columns"].Description = "Column metadata from the query result including name and PostgreSQL type. Omitted for commands that don't return rows (INSERT, UPDATE, DELETE, etc.)"
7575
schema.Properties["columns"].Examples = []any{[]DBExecuteQueryColumn{
7676
{Name: "id", Type: "int4"},
7777
{Name: "name", Type: "text"},
7878
{Name: "created_at", Type: "timestamptz"},
7979
}}
8080

81-
schema.Properties["rows"].Description = "Result rows as arrays of values. Empty for commands that don't return rows (INSERT, UPDATE, DELETE, etc.)"
81+
schema.Properties["rows"].Description = "Result rows as arrays of values. Omitted for commands that don't return rows (INSERT, UPDATE, DELETE, etc.)"
8282
schema.Properties["rows"].Examples = []any{[][]any{{1, "alice", "2024-01-01"}, {2, "bob", "2024-01-02"}}}
8383

8484
schema.Properties["rows_affected"].Description = "Number of rows affected by the query. For SELECT, this is the number of rows returned. For INSERT/UPDATE/DELETE, this is the number of rows modified. Returns 0 for statements that don't return or modify rows (e.g. CREATE TABLE)."
@@ -95,9 +95,11 @@ func (s *Server) registerDatabaseTools() {
9595
mcp.AddTool(s.mcpServer, &mcp.Tool{
9696
Name: "db_execute_query",
9797
Title: "Execute SQL Query",
98-
Description: `Execute a single SQL query against a service database.
98+
Description: `Execute SQL queries against a service database.
9999
100-
This tool connects to a PostgreSQL database service in Tiger Cloud and executes the provided SQL query, returning the results with column names, row data, and execution metadata. Multi-statement queries are not supported.
100+
This tool connects to a PostgreSQL database service in Tiger Cloud and executes the provided SQL query, returning the results with column names, row data, and execution metadata.
101+
102+
Multi-statement queries are supported when no parameters are provided. When executing multiple statements separated by semicolons, all statements are executed in a single transaction, and only the results from the final statement are returned. Multi-statement queries with parameters are not supported and will return an error.
101103
102104
WARNING: Use with caution - this tool can execute any SQL statement including INSERT, UPDATE, DELETE, and DDL commands. Always review queries before execution.`,
103105
InputSchema: DBExecuteQueryInput{}.Schema(),
@@ -157,59 +159,149 @@ func (s *Server) handleDBExecuteQuery(ctx context.Context, req *mcp.CallToolRequ
157159
queryCtx, cancel := context.WithTimeout(ctx, timeout)
158160
defer cancel()
159161

162+
// Parse connection string into config
163+
connConfig, err := pgx.ParseConfig(details.String())
164+
if err != nil {
165+
return nil, DBExecuteQueryOutput{}, fmt.Errorf("failed to parse connection string: %w", err)
166+
}
167+
168+
// Choose query execution mode based on whether parameters are present.
169+
// Simple protocol supports multi-statement queries but interpolates
170+
// parameters client-side (which we don't want to do, for security's sake).
171+
// Extended protocol sends parameters separately but doesn't support
172+
// multi-statement queries. This means we don't support multi-statement
173+
// queries with parameters (pgx will return an error for them when using
174+
// QueryExecModeDescribeExec). See [pgx.QueryExecMode] for details.
175+
if len(input.Parameters) > 0 {
176+
// Use extended protocol to send parameters separately (more secure,
177+
// but doesn't support multi-statement queries).
178+
connConfig.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec
179+
} else {
180+
// Use simple protocol to support multi-statement queries when no
181+
// parameters are given.
182+
connConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
183+
}
184+
160185
// Connect to database
161-
conn, err := pgx.Connect(queryCtx, details.String())
186+
conn, err := pgx.ConnectConfig(queryCtx, connConfig)
162187
if err != nil {
163-
return nil, DBExecuteQueryOutput{}, fmt.Errorf("failed to connect to database: %w", err)
188+
return nil, DBExecuteQueryOutput{}, err
164189
}
165190
defer conn.Close(context.Background())
166191

167192
// Execute query and measure time
168193
startTime := time.Now()
169-
rows, err := conn.Query(queryCtx, input.Query, util.ConvertSliceToAny(input.Parameters)...)
170-
if err != nil {
171-
return nil, DBExecuteQueryOutput{}, fmt.Errorf("query execution failed: %w", err)
194+
195+
// Queue the query. When using QueryExecModeSimpleProtocol (no parameters),
196+
// it's valid to queue a single multi-statement SQL query as the batch.
197+
// See the [pgx.Batch.Queue] documentation for details. When using
198+
// QueryExecModeDescribeExec (with parameters), queueing a multi-statement
199+
// query here will result in an error when executing it below.
200+
batch := &pgx.Batch{}
201+
batch.Queue(input.Query, util.ConvertSliceToAny(input.Parameters)...)
202+
203+
br := conn.SendBatch(queryCtx, batch)
204+
defer br.Close()
205+
206+
// Process all result sets, keeping only the final one
207+
var finalResult resultSet
208+
for {
209+
rows, err := br.Query()
210+
if err != nil {
211+
// Check if we've reached the final result set and stop iteration.
212+
// NOTE: It would be nice if there was a real sentinel error type
213+
// we could check here instead of comparing error strings, but pgx
214+
// doesn't expose one. We will just need to verify that the error
215+
// message doesn't change when we update the pgx dependency.
216+
if err.Error() == "no more results in batch" {
217+
break
218+
}
219+
return nil, DBExecuteQueryOutput{}, err
220+
}
221+
222+
// Process this result set
223+
result, err := processResultSet(conn, rows)
224+
if err != nil {
225+
return nil, DBExecuteQueryOutput{}, err
226+
}
227+
228+
// Save this result set as the current "final" one
229+
finalResult = result
230+
}
231+
232+
if err := br.Close(); err != nil {
233+
return nil, DBExecuteQueryOutput{}, err
172234
}
235+
236+
// Build output from the final result set
237+
output := DBExecuteQueryOutput{
238+
Columns: finalResult.columns,
239+
Rows: finalResult.rows,
240+
RowsAffected: finalResult.rowsAffected,
241+
ExecutionTime: time.Since(startTime).String(),
242+
}
243+
244+
return nil, output, nil
245+
}
246+
247+
// resultSet holds the columns, rows, and metadata from a single query result set
248+
type resultSet struct {
249+
columns []DBExecuteQueryColumn
250+
rows *[][]any
251+
rowsAffected int64
252+
}
253+
254+
// processResultSet reads all data from a pgx.Rows result set
255+
func processResultSet(conn *pgx.Conn, rows pgx.Rows) (resultSet, error) {
173256
defer rows.Close()
174257

175258
// Get column metadata from field descriptions
176259
fieldDescriptions := rows.FieldDescriptions()
177-
var columns []DBExecuteQueryColumn
178-
for _, fd := range fieldDescriptions {
260+
columns := make([]DBExecuteQueryColumn, len(fieldDescriptions))
261+
for i, fd := range fieldDescriptions {
179262
// Get the type name from the connection's type map
180-
dataType, ok := conn.TypeMap().TypeForOID(fd.DataTypeOID)
181263
typeName := "unknown"
264+
dataType, ok := conn.TypeMap().TypeForOID(fd.DataTypeOID)
182265
if ok && dataType != nil {
183266
typeName = dataType.Name
184267
}
185-
columns = append(columns, DBExecuteQueryColumn{
268+
columns[i] = DBExecuteQueryColumn{
186269
Name: fd.Name,
187270
Type: typeName,
188-
})
271+
}
189272
}
190273

191-
// Collect all rows
274+
// Collect all rows from this result set
192275
var resultRows [][]any
276+
if len(columns) > 0 {
277+
// If any columns were returned, initialize resultRows to an empty
278+
// slice to ensure we always return a JSON array in the results, even
279+
// if empty (we want to be completely clear when a SELECT query returns
280+
// no rows). On the other hand, if no columns were returned, it's not a
281+
// result returning query (e.g. it's DDL or an INSERT/UPDATE/DELETE/etc.),
282+
// so we leave resultRows nil so it gets omitted from the JSON result.
283+
resultRows = make([][]any, 0)
284+
}
193285
for rows.Next() {
194286
// Scan values into generic interface slice
195287
values, err := rows.Values()
196288
if err != nil {
197-
return nil, DBExecuteQueryOutput{}, fmt.Errorf("failed to scan row: %w", err)
289+
return resultSet{}, err
198290
}
199291
resultRows = append(resultRows, values)
200292
}
201293

202294
// Check for errors during iteration
203-
if rows.Err() != nil {
204-
return nil, DBExecuteQueryOutput{}, fmt.Errorf("error during row iteration: %w", rows.Err())
295+
if err := rows.Err(); err != nil {
296+
return resultSet{}, err
205297
}
206298

207-
output := DBExecuteQueryOutput{
208-
Columns: columns,
209-
Rows: resultRows,
210-
RowsAffected: rows.CommandTag().RowsAffected(),
211-
ExecutionTime: time.Since(startTime).String(),
212-
}
299+
// Get rows affected
300+
rowsAffected := rows.CommandTag().RowsAffected()
213301

214-
return nil, output, nil
302+
return resultSet{
303+
columns: columns,
304+
rows: util.PtrIfNonNil(resultRows),
305+
rowsAffected: rowsAffected,
306+
}, nil
215307
}

internal/tiger/util/convert.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@ func Ptr[T any](val T) *T {
1111
return &val
1212
}
1313

14+
func PtrIfNonNil[T ~[]E, E any](val T) *T {
15+
if val == nil {
16+
return nil
17+
}
18+
return &val
19+
}
20+
1421
func Deref[T any](val *T) T {
1522
if val == nil {
1623
var res T

0 commit comments

Comments
 (0)