Skip to content

Commit a5131b5

Browse files
committed
WIP
1 parent 7609ebc commit a5131b5

File tree

420 files changed

+13950
-13701
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

420 files changed

+13950
-13701
lines changed

internal/analyzer/analyzer.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ import (
1515
"github.com/sqlc-dev/sqlc/internal/cache"
1616
"github.com/sqlc-dev/sqlc/internal/config"
1717
"github.com/sqlc-dev/sqlc/internal/info"
18-
"github.com/sqlc-dev/sqlc/internal/sql/ast"
1918
"github.com/sqlc-dev/sqlc/internal/sql/named"
19+
"github.com/sqlc-dev/sqlc/pkg/ast"
2020
)
2121

2222
type CachedAnalyzer struct {
@@ -34,6 +34,14 @@ func Cached(a Analyzer, c config.Config, db config.Database) *CachedAnalyzer {
3434
}
3535
}
3636

37+
// Expand delegates to the underlying analyzer if it supports expansion.
38+
func (c *CachedAnalyzer) Expand(ctx context.Context, query string) (string, error) {
39+
if analyzerExpander, ok := c.a.(AnalyzerExpander); ok {
40+
return analyzerExpander.Expand(ctx, query)
41+
}
42+
return "", fmt.Errorf("analyzer does not support query expansion")
43+
}
44+
3745
// Create a new error here
3846

3947
func (c *CachedAnalyzer) Analyze(ctx context.Context, n ast.Node, q string, schema []string, np *named.ParamSet) (*analysis.Analysis, error) {
@@ -128,3 +136,13 @@ type Analyzer interface {
128136
// This is used for star expansion in database-only mode.
129137
GetColumnNames(ctx context.Context, query string) ([]string, error)
130138
}
139+
140+
// AnalyzerExpander is an optional interface for analyzers that support query expansion.
141+
// The parser and dialect are stored in the analyzer when it's created.
142+
type AnalyzerExpander interface {
143+
Analyzer
144+
145+
// Expand expands a SQL query by replacing * with explicit column names.
146+
// Each analyzer knows how to implement expansion using its own parser and dialect.
147+
Expand(ctx context.Context, query string) (string, error)
148+
}

internal/cmd/parse.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
"github.com/sqlc-dev/sqlc/internal/engine/dolphin"
1212
"github.com/sqlc-dev/sqlc/internal/engine/postgresql"
1313
"github.com/sqlc-dev/sqlc/internal/engine/sqlite"
14-
"github.com/sqlc-dev/sqlc/internal/sql/ast"
14+
"github.com/sqlc-dev/sqlc/pkg/ast"
1515
)
1616

1717
var parseCmd = &cobra.Command{

internal/compiler/analyze.go

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
package compiler
22

33
import (
4+
"fmt"
45
"sort"
56

67
analyzer "github.com/sqlc-dev/sqlc/internal/analysis"
78
"github.com/sqlc-dev/sqlc/internal/config"
89
"github.com/sqlc-dev/sqlc/internal/source"
9-
"github.com/sqlc-dev/sqlc/internal/sql/ast"
10+
"github.com/sqlc-dev/sqlc/pkg/ast"
1011
"github.com/sqlc-dev/sqlc/internal/sql/named"
1112
"github.com/sqlc-dev/sqlc/internal/sql/rewrite"
1213
"github.com/sqlc-dev/sqlc/internal/sql/validate"
@@ -134,31 +135,36 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
134135
return nil
135136
}
136137

137-
numbers, dollar, err := validate.ParamRef(raw)
138+
_, dollar, err := validate.ParamRef(raw.Stmt)
138139
if err := check(err); err != nil {
139140
return nil, err
140141
}
141142

142-
raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, numbers, dollar)
143+
// TODO: fix rewrite.NamedParameters - function not found
144+
var namedParams map[string]int
145+
var edits []source.Edit
143146

144147
var table *ast.TableName
145-
switch n := raw.Stmt.(type) {
146-
case *ast.InsertStmt:
147-
if err := check(validate.InsertStmt(n)); err != nil {
148-
return nil, err
149-
}
150-
var err error
151-
table, err = ParseTableName(n.Relation)
152-
if err := check(err); err != nil {
153-
return nil, err
148+
if raw.Stmt != nil && raw.Stmt.Node != nil {
149+
switch n := raw.Stmt.Node.(type) {
150+
case *ast.Node_InsertStmt:
151+
if err := check(validate.InsertStmt(n.InsertStmt)); err != nil {
152+
return nil, err
153+
}
154+
var err error
155+
relNode := &ast.Node{Node: &ast.Node_RangeVar{RangeVar: n.InsertStmt.Relation}}
156+
table, err = ParseTableName(*relNode)
157+
if err := check(err); err != nil {
158+
return nil, err
159+
}
154160
}
155161
}
156162

157-
if err := check(validate.FuncCall(c.catalog, c.combo, raw)); err != nil {
163+
if err := check(validate.FuncCall(c.catalog, c.combo, raw.Stmt)); err != nil {
158164
return nil, err
159165
}
160166

161-
if err := check(validate.In(c.catalog, raw)); err != nil {
167+
if err := check(validate.In(c.catalog, raw.Stmt)); err != nil {
162168
return nil, err
163169
}
164170
rvs := rangeVars(raw.Stmt)
@@ -176,16 +182,26 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
176182
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
177183
}
178184
raw, embeds := rewrite.Embeds(raw)
179-
qc, err := c.buildQueryCatalog(c.catalog, raw.Stmt, embeds)
185+
if raw.Stmt == nil {
186+
return nil, fmt.Errorf("raw.Stmt is nil")
187+
}
188+
qc, err := c.buildQueryCatalog(c.catalog, *raw.Stmt, embeds)
180189
if err := check(err); err != nil {
181190
return nil, err
182191
}
183192

184-
params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds)
193+
var paramSet *named.ParamSet
194+
if namedParams != nil {
195+
paramSet = named.NewParamSet(nil, true)
196+
for k := range namedParams {
197+
paramSet.Add(named.NewParam(k))
198+
}
199+
}
200+
params, err := c.resolveCatalogRefs(qc, rvs, refs, paramSet, embeds)
185201
if err := check(err); err != nil {
186202
return nil, err
187203
}
188-
cols, err := c.outputColumns(qc, raw.Stmt)
204+
cols, err := c.outputColumns(qc, *raw.Stmt)
189205
if err := check(err); err != nil {
190206
return nil, err
191207
}
@@ -194,7 +210,9 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
194210
if check(err); err != nil {
195211
return nil, err
196212
}
197-
edits = append(edits, expandEdits...)
213+
if expandEdits != nil {
214+
edits = append(edits, expandEdits...)
215+
}
198216
expanded, err := source.Mutate(query, edits)
199217
if err != nil {
200218
return nil, err
@@ -205,11 +223,18 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
205223
rerr = errors[0]
206224
}
207225

226+
var namedParamSet *named.ParamSet
227+
if namedParams != nil {
228+
namedParamSet = named.NewParamSet(nil, true)
229+
for k := range namedParams {
230+
namedParamSet.Add(named.NewParam(k))
231+
}
232+
}
208233
return &analysis{
209234
Table: table,
210235
Columns: cols,
211236
Parameters: params,
212237
Query: expanded,
213-
Named: namedParams,
238+
Named: namedParamSet,
214239
}, rerr
215240
}

internal/compiler/compat.go

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,18 @@ import (
44
"fmt"
55
"strings"
66

7-
"github.com/sqlc-dev/sqlc/internal/sql/ast"
7+
"github.com/sqlc-dev/sqlc/pkg/ast"
88
"github.com/sqlc-dev/sqlc/internal/sql/astutils"
99
)
1010

1111
// This is mainly copy-pasted from internal/postgresql/parse.go
1212
func stringSlice(list *ast.List) []string {
1313
items := []string{}
1414
for _, item := range list.Items {
15-
if n, ok := item.(*ast.String); ok {
16-
items = append(items, n.Str)
15+
if item != nil && item.Node != nil {
16+
if strNode, ok := item.Node.(*ast.Node_String_); ok {
17+
items = append(items, strNode.String_.Str)
18+
}
1719
}
1820
}
1921
return items
@@ -26,18 +28,18 @@ type Relation struct {
2628
}
2729

2830
func parseRelation(node ast.Node) (*Relation, error) {
29-
switch n := node.(type) {
30-
case *ast.Boolean:
31-
if n == nil {
32-
return nil, fmt.Errorf("unexpected nil in %T node", n)
33-
}
31+
if node.Node == nil {
32+
return nil, fmt.Errorf("unexpected nil node")
33+
}
34+
switch n := node.Node.(type) {
35+
case *ast.Node_Boolean:
3436
return &Relation{Name: "bool"}, nil
3537

36-
case *ast.List:
37-
if n == nil {
38-
return nil, fmt.Errorf("unexpected nil in %T node", n)
38+
case *ast.Node_List:
39+
if n.List == nil {
40+
return nil, fmt.Errorf("unexpected nil in List node")
3941
}
40-
parts := stringSlice(n)
42+
parts := stringSlice(n.List)
4143
switch len(parts) {
4244
case 1:
4345
return &Relation{
@@ -55,37 +57,40 @@ func parseRelation(node ast.Node) (*Relation, error) {
5557
Name: parts[2],
5658
}, nil
5759
default:
58-
return nil, fmt.Errorf("invalid name: %s", astutils.Join(n, "."))
60+
return nil, fmt.Errorf("invalid name: %s", astutils.Join(n.List, "."))
5961
}
6062

61-
case *ast.RangeVar:
62-
if n == nil {
63-
return nil, fmt.Errorf("unexpected nil in %T node", n)
63+
case *ast.Node_RangeVar:
64+
if n.RangeVar == nil {
65+
return nil, fmt.Errorf("unexpected nil in RangeVar node")
6466
}
67+
rv := n.RangeVar
6568
name := Relation{}
66-
if n.Catalogname != nil {
67-
name.Catalog = *n.Catalogname
69+
if rv.Catalogname != "" {
70+
name.Catalog = rv.Catalogname
6871
}
69-
if n.Schemaname != nil {
70-
name.Schema = *n.Schemaname
72+
if rv.Schemaname != "" {
73+
name.Schema = rv.Schemaname
7174
}
72-
if n.Relname != nil {
73-
name.Name = *n.Relname
75+
if rv.Relname != "" {
76+
name.Name = rv.Relname
7477
}
7578
return &name, nil
7679

77-
case *ast.TypeName:
78-
if n == nil {
79-
return nil, fmt.Errorf("unexpected nil in %T node", n)
80+
case *ast.Node_TypeName:
81+
if n.TypeName == nil {
82+
return nil, fmt.Errorf("unexpected nil in TypeName node")
8083
}
81-
if n.Names != nil {
82-
return parseRelation(n.Names)
84+
tn := n.TypeName
85+
if tn.Names != nil {
86+
namesNode := ast.Node{Node: &ast.Node_List{List: tn.Names}}
87+
return parseRelation(namesNode)
8388
} else {
84-
return &Relation{Name: n.Name}, nil
89+
return &Relation{Name: tn.Name}, nil
8590
}
8691

8792
default:
88-
return nil, fmt.Errorf("unexpected node type: %T", node)
93+
return nil, fmt.Errorf("unexpected node type: %T", n)
8994
}
9095
}
9196

internal/compiler/compile.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414
"github.com/sqlc-dev/sqlc/internal/opts"
1515
"github.com/sqlc-dev/sqlc/internal/rpc"
1616
"github.com/sqlc-dev/sqlc/internal/source"
17-
"github.com/sqlc-dev/sqlc/internal/sql/ast"
17+
"github.com/sqlc-dev/sqlc/pkg/ast"
1818
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
1919
"github.com/sqlc-dev/sqlc/internal/sql/sqlpath"
2020
)
@@ -56,7 +56,11 @@ func (c *Compiler) parseCatalog(schemas []string) error {
5656

5757
for i := range stmts {
5858
if err := c.catalog.Update(stmts[i], c); err != nil {
59-
merr.Add(filename, contents, stmts[i].Pos(), err)
59+
loc := int32(0)
60+
if stmts[i].Raw != nil {
61+
loc = stmts[i].Raw.StmtLocation
62+
}
63+
merr.Add(filename, contents, int(loc), err)
6064
continue
6165
}
6266
}
@@ -97,10 +101,14 @@ func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) {
97101
continue
98102
}
99103
for _, stmt := range stmts {
100-
query, err := c.parseQuery(stmt.Raw, src, o)
104+
if stmt.Raw == nil || stmt.Raw.Stmt == nil {
105+
merr.Add(filename, src, 0, fmt.Errorf("stmt.Raw or stmt.Raw.Stmt is nil"))
106+
continue
107+
}
108+
query, err := c.parseQuery(*stmt.Raw.Stmt, src, o)
101109
if err != nil {
102110
var e *sqlerr.Error
103-
loc := stmt.Raw.Pos()
111+
loc := int(stmt.Raw.StmtLocation)
104112
if errors.As(err, &e) && e.Location != 0 {
105113
loc = e.Location
106114
}
@@ -118,7 +126,11 @@ func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) {
118126
queryName := query.Metadata.Name
119127
if queryName != "" {
120128
if _, exists := set[queryName]; exists {
121-
merr.Add(filename, src, stmt.Raw.Pos(), fmt.Errorf("duplicate query name: %s", queryName))
129+
loc := 0
130+
if stmt.Raw != nil {
131+
loc = int(stmt.Raw.StmtLocation)
132+
}
133+
merr.Add(filename, src, loc, fmt.Errorf("duplicate query name: %s", queryName))
122134
continue
123135
}
124136
set[queryName] = struct{}{}

0 commit comments

Comments
 (0)