From f5f822931cca80abd1495acfa8cedbf20835a9bc Mon Sep 17 00:00:00 2001 From: Adel Mohamed Date: Mon, 9 Feb 2026 22:28:43 +0200 Subject: [PATCH] fix(security): prevent potential SQL injection in TableCounts --- gtfsdb/debugging.go | 50 +++++++++++++++++----------- gtfsdb/debugging_test.go | 39 ++++++++++++++++++++++ internal/restapi/test_helper_test.go | 4 +-- 3 files changed, 72 insertions(+), 21 deletions(-) create mode 100644 gtfsdb/debugging_test.go diff --git a/gtfsdb/debugging.go b/gtfsdb/debugging.go index 16e2420b..3e577b8f 100644 --- a/gtfsdb/debugging.go +++ b/gtfsdb/debugging.go @@ -75,30 +75,42 @@ func (c *Client) TableCounts() (map[string]int, error) { counts := make(map[string]int) - // Whitelist allowed table names to prevent SQL injection - allowedTables := map[string]bool{ - "agencies": true, - "routes": true, - "stops": true, - "trips": true, - "stop_times": true, - "calendar": true, - "calendar_dates": true, - "shapes": true, - "transfers": true, - "feed_info": true, - "block_trip_index": true, - "block_trip_entry": true, - "import_metadata": true, - } - for _, table := range tables { - if !allowedTables[table] { + var query string + + // This prevents SQL injection by ensuring the query string is always a constant. + switch table { + case "agencies": + query = "SELECT COUNT(*) FROM agencies" + case "routes": + query = "SELECT COUNT(*) FROM routes" + case "stops": + query = "SELECT COUNT(*) FROM stops" + case "trips": + query = "SELECT COUNT(*) FROM trips" + case "stop_times": + query = "SELECT COUNT(*) FROM stop_times" + case "calendar": + query = "SELECT COUNT(*) FROM calendar" + case "calendar_dates": + query = "SELECT COUNT(*) FROM calendar_dates" + case "shapes": + query = "SELECT COUNT(*) FROM shapes" + case "transfers": + query = "SELECT COUNT(*) FROM transfers" + case "feed_info": + query = "SELECT COUNT(*) FROM feed_info" + case "block_trip_index": + query = "SELECT COUNT(*) FROM block_trip_index" + case "block_trip_entry": + query = "SELECT COUNT(*) FROM block_trip_entry" + case "import_metadata": + query = "SELECT COUNT(*) FROM import_metadata" + default: continue } var count int - query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table) err := c.DB.QueryRow(query).Scan(&count) if err != nil { return nil, err diff --git a/gtfsdb/debugging_test.go b/gtfsdb/debugging_test.go new file mode 100644 index 00000000..d35550e4 --- /dev/null +++ b/gtfsdb/debugging_test.go @@ -0,0 +1,39 @@ +package gtfsdb + +import ( + "database/sql" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTableCounts(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + defer db.Close() + + client := &Client{DB: db} + + _, err = db.Exec(` + CREATE TABLE agencies (id TEXT); + INSERT INTO agencies VALUES ('1'); + + CREATE TABLE stops (id TEXT); + INSERT INTO stops VALUES ('s1'), ('s2'); + + -- Create a table NOT in the whitelist to ensure it's ignored + CREATE TABLE secret_table (id TEXT); + `) + require.NoError(t, err) + + counts, err := client.TableCounts() + require.NoError(t, err) + + assert.Equal(t, 1, counts["agencies"], "Should count agencies correctly") + assert.Equal(t, 2, counts["stops"], "Should count stops correctly") + + _, exists := counts["secret_table"] + assert.False(t, exists, "Should not include tables outside the whitelist") +} diff --git a/internal/restapi/test_helper_test.go b/internal/restapi/test_helper_test.go index c9b6b24b..b36ad39b 100644 --- a/internal/restapi/test_helper_test.go +++ b/internal/restapi/test_helper_test.go @@ -72,11 +72,11 @@ func TestCollectAllNestedIdsFromObjectsFailures(t *testing.T) { mockFatalf := &mockTestingFatalf{} var running sync.WaitGroup + running.Add(1) go func() { defer running.Done() collectAllNestedIdsFromObjects(mockFatalf, tt.data, "routes") }() - running.Add(1) running.Wait() assert.True(t, mockFatalf.failed) @@ -130,11 +130,11 @@ func TestCollectAllIdsFromObjectsFailures(t *testing.T) { mockFatalf := &mockTestingFatalf{} var running sync.WaitGroup + running.Add(1) go func() { defer running.Done() collectAllIdsFromObjects(mockFatalf, tt.data, "id") }() - running.Add(1) running.Wait() assert.True(t, mockFatalf.failed)