Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 31 additions & 19 deletions gtfsdb/debugging.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,30 +75,42 @@ func (c *Client) TableCounts() (map[string]int, error) {

counts := make(map[string]int)

// Whitelist allowed table names to prevent SQL injection
allowedTables := map[string]bool{
"agencies": true,
"routes": true,
"stops": true,
"trips": true,
"stop_times": true,
"calendar": true,
"calendar_dates": true,
"shapes": true,
"transfers": true,
"feed_info": true,
"block_trip_index": true,
"block_trip_entry": true,
"import_metadata": true,
}

for _, table := range tables {
if !allowedTables[table] {
var query string

// This prevents SQL injection by ensuring the query string is always a constant.
switch table {
case "agencies":
query = "SELECT COUNT(*) FROM agencies"
case "routes":
query = "SELECT COUNT(*) FROM routes"
case "stops":
query = "SELECT COUNT(*) FROM stops"
case "trips":
query = "SELECT COUNT(*) FROM trips"
case "stop_times":
query = "SELECT COUNT(*) FROM stop_times"
case "calendar":
query = "SELECT COUNT(*) FROM calendar"
case "calendar_dates":
query = "SELECT COUNT(*) FROM calendar_dates"
case "shapes":
query = "SELECT COUNT(*) FROM shapes"
case "transfers":
query = "SELECT COUNT(*) FROM transfers"
case "feed_info":
query = "SELECT COUNT(*) FROM feed_info"
case "block_trip_index":
query = "SELECT COUNT(*) FROM block_trip_index"
case "block_trip_entry":
query = "SELECT COUNT(*) FROM block_trip_entry"
case "import_metadata":
query = "SELECT COUNT(*) FROM import_metadata"
default:
continue
}

var count int
query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table)
err := c.DB.QueryRow(query).Scan(&count)
if err != nil {
return nil, err
Expand Down
39 changes: 39 additions & 0 deletions gtfsdb/debugging_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package gtfsdb

import (
"database/sql"
"testing"

_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestTableCounts(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
require.NoError(t, err)
defer db.Close()

client := &Client{DB: db}

_, err = db.Exec(`
CREATE TABLE agencies (id TEXT);
INSERT INTO agencies VALUES ('1');

CREATE TABLE stops (id TEXT);
INSERT INTO stops VALUES ('s1'), ('s2');

-- Create a table NOT in the whitelist to ensure it's ignored
CREATE TABLE secret_table (id TEXT);
`)
require.NoError(t, err)

counts, err := client.TableCounts()
require.NoError(t, err)

assert.Equal(t, 1, counts["agencies"], "Should count agencies correctly")
assert.Equal(t, 2, counts["stops"], "Should count stops correctly")

_, exists := counts["secret_table"]
assert.False(t, exists, "Should not include tables outside the whitelist")
}
4 changes: 2 additions & 2 deletions internal/restapi/test_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ func TestCollectAllNestedIdsFromObjectsFailures(t *testing.T) {
mockFatalf := &mockTestingFatalf{}

var running sync.WaitGroup
running.Add(1)
go func() {
defer running.Done()
collectAllNestedIdsFromObjects(mockFatalf, tt.data, "routes")
}()
running.Add(1)
running.Wait()

assert.True(t, mockFatalf.failed)
Expand Down Expand Up @@ -130,11 +130,11 @@ func TestCollectAllIdsFromObjectsFailures(t *testing.T) {
mockFatalf := &mockTestingFatalf{}

var running sync.WaitGroup
running.Add(1)
go func() {
defer running.Done()
collectAllIdsFromObjects(mockFatalf, tt.data, "id")
}()
running.Add(1)
running.Wait()

assert.True(t, mockFatalf.failed)
Expand Down
Loading