forked from ydb-platform/ent
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdriver.go
More file actions
353 lines (317 loc) · 8.6 KB
/
driver.go
File metadata and controls
353 lines (317 loc) · 8.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package sql
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"strconv"
"strings"
"entgo.io/ent/dialect"
"github.com/ydb-platform/ydb-go-sdk/v3"
)
// Driver is a dialect.Driver implementation for SQL based databases.
//
// It embeds Conn, so Exec and Query run directly against the wrapped
// ExecQuerier (a *sql.DB when the driver is built by Open or OpenDB).
type Driver struct {
	Conn
	// dialect is the name the driver was created with; Dialect() maps names
	// that start with a known dialect back to that base dialect.
	dialect string
	// retryExecutor is returned unchanged by RetryExecutor().
	retryExecutor RetryExecutor
}
// NewDriver creates a new Driver with the given Conn and dialect.
//
// The retryExecutor is stored as-is and handed back by RetryExecutor.
func NewDriver(
	dialect string,
	c Conn,
	retryExecutor RetryExecutor,
) *Driver {
	d := &Driver{dialect: dialect}
	d.Conn = c
	d.retryExecutor = retryExecutor
	return d
}
// Open wraps the database/sql.Open method and returns a dialect.Driver that implements the ent/dialect.Driver interface.
//
// For dialect.YDB, the dsn is opened with the native ydb-go-sdk driver and
// exposed through database/sql via ydb.Connector (auto-declare enabled, table
// path prefix taken from nativeDriver.Name(), query service enabled). Every
// other dialect is opened with sql.Open(sqlDialect, dsn).
func Open(sqlDialect, dsn string) (*Driver, error) {
	var (
		db  *sql.DB
		err error
	)
	if sqlDialect == dialect.YDB {
		ctx := context.Background()
		nativeDriver, openErr := ydb.Open(ctx, dsn)
		if openErr != nil {
			return nil, openErr
		}
		conn, connErr := ydb.Connector(
			nativeDriver,
			ydb.WithAutoDeclare(),
			ydb.WithTablePathPrefix(nativeDriver.Name()),
			ydb.WithQueryService(true),
		)
		if connErr != nil {
			// Nothing else references nativeDriver at this point, so close it
			// here instead of leaking its underlying connections.
			return nil, errors.Join(connErr, nativeDriver.Close(ctx))
		}
		// NOTE(review): on success, releasing nativeDriver is assumed to happen
		// when db is closed via the connector — confirm against ydb-go-sdk.
		db = sql.OpenDB(conn)
	} else {
		db, err = sql.Open(sqlDialect, dsn)
	}
	if err != nil {
		return nil, err
	}
	return NewDriver(
		sqlDialect,
		Conn{db, sqlDialect},
		NewRetryExecutor(sqlDialect, db),
	), nil
}
// OpenDB wraps the given database/sql.DB method with a Driver.
//
// The returned Driver runs statements on db and uses a retry executor built
// by NewRetryExecutor for the same dialect and db.
func OpenDB(sqlDialect string, db *sql.DB) *Driver {
	conn := Conn{ExecQuerier: db, dialect: sqlDialect}
	retry := NewRetryExecutor(sqlDialect, db)
	return NewDriver(sqlDialect, conn, retry)
}
// DB returns the underlying *sql.DB instance.
//
// It panics if the embedded Conn does not wrap a *sql.DB.
func (d Driver) DB() *sql.DB {
	db := d.Conn.ExecQuerier.(*sql.DB)
	return db
}
// Dialect implements the dialect.Dialect method.
//
// A driver registered under a name that starts with MySQL, SQLite or
// Postgres (e.g. when wrapped by a telemetry driver) reports that base
// dialect; any other name is returned unchanged.
func (d Driver) Dialect() string {
	known := [...]string{dialect.MySQL, dialect.SQLite, dialect.Postgres}
	for i := range known {
		if strings.HasPrefix(d.dialect, known[i]) {
			return known[i]
		}
	}
	return d.dialect
}
// Tx starts and returns a transaction.
func (d *Driver) Tx(ctx context.Context) (dialect.Tx, error) {
return d.BeginTx(ctx, nil)
}
// BeginTx starts a transaction with options.
//
// The returned Tx executes statements inside the *sql.Tx and keeps the
// driver's dialect so session-variable handling still applies.
func (d *Driver) BeginTx(ctx context.Context, opts *TxOptions) (dialect.Tx, error) {
	sqlTx, err := d.DB().BeginTx(ctx, opts)
	if err != nil {
		return nil, err
	}
	wrapped := &Tx{Tx: sqlTx}
	wrapped.Conn = Conn{ExecQuerier: sqlTx, dialect: d.dialect}
	return wrapped, nil
}
// RetryExecutor returns the RetryExecutor the driver was created with
// (NewRetryExecutor(dialect, db) when the driver comes from Open or OpenDB).
func (d *Driver) RetryExecutor() RetryExecutor {
	return d.retryExecutor
}
// Close closes the underlying *sql.DB returned by DB.
func (d *Driver) Close() error {
	db := d.DB()
	return db.Close()
}
// Tx implements dialect.Tx interface.
//
// Conn runs Exec/Query inside the transaction, while the embedded driver.Tx
// supplies Commit and Rollback.
type Tx struct {
	Conn
	driver.Tx
}
// ctxVarsKey is the key used for attaching and reading the context variables.
// It is an unexported empty struct, so no other package can collide with it.
type ctxVarsKey struct{}
// sessionVars is the value stored under ctxVarsKey{}: the session or
// transaction variables to SET before every statement, in the order they
// were attached (k is the variable name, v its value).
type sessionVars struct {
	vars []struct {
		k string
		v string
	}
}
// WithVar returns a new context that holds the session variable to be executed before every query.
//
// The returned context carries every variable already attached to ctx,
// followed by name=value. ctx and any other context derived from it are left
// untouched.
func WithVar(ctx context.Context, name, value string) context.Context {
	sv, _ := ctx.Value(ctxVarsKey{}).(sessionVars)
	// Limit capacity to length so append always allocates a new backing array.
	// Otherwise, two contexts derived from the same parent could share the
	// parent's spare capacity and overwrite each other's last variable.
	n := len(sv.vars)
	sv.vars = append(sv.vars[:n:n], struct {
		k, v string
	}{
		k: name,
		v: value,
	})
	return context.WithValue(ctx, ctxVarsKey{}, sv)
}
// VarFromContext returns the session variable value from the context.
//
// The boolean reports whether name was attached to ctx. If name was attached
// more than once, the most recently attached value is returned. That is the
// value in effect on the database, because maySetVars executes the SET
// statements in attachment order.
func VarFromContext(ctx context.Context, name string) (string, bool) {
	sv, _ := ctx.Value(ctxVarsKey{}).(sessionVars)
	for i := len(sv.vars) - 1; i >= 0; i-- {
		if sv.vars[i].k == name {
			return sv.vars[i].v, true
		}
	}
	return "", false
}
// WithIntVar is WithVar for integer values: value is formatted in base 10
// before being attached to the context.
func WithIntVar(ctx context.Context, name string, value int) context.Context {
	formatted := strconv.FormatInt(int64(value), 10)
	return WithVar(ctx, name, formatted)
}
// ExecQuerier wraps the standard Exec and Query methods.
//
// *sql.DB, *sql.Tx and *sql.Conn all satisfy it.
type ExecQuerier interface {
	// ExecContext runs a statement that returns no rows.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	// QueryContext runs a statement that returns rows.
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
// Conn implements dialect.ExecQuerier given ExecQuerier.
//
// The embedded ExecQuerier is typically a *sql.DB or *sql.Tx. dialect holds
// the dialect name; it selects the statements used to reset session
// variables.
type Conn struct {
	ExecQuerier
	dialect string
}
// Exec implements the dialect.Exec method.
//
// args must be a []any holding the statement arguments. v must be nil (the
// result is discarded) or a *sql.Result that receives the result. Session
// variables attached to ctx are applied first via maySetVars, and its cleanup
// function (if any) runs before Exec returns, with its error joined into the
// returned error.
func (c Conn) Exec(ctx context.Context, query string, args, v any) (rerr error) {
	argv, ok := args.([]any)
	if !ok {
		// Report the type of args: that is the value that failed the check.
		return fmt.Errorf("dialect/sql: invalid type %T. expect []any for args", args)
	}
	// Validate the destination before touching the database, so an invalid
	// call neither acquires a connection nor runs SET statements.
	var dst *sql.Result
	switch v := v.(type) {
	case nil:
	case *sql.Result:
		dst = v
	default:
		return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Result", v)
	}
	ex, cf, err := c.maySetVars(ctx)
	if err != nil {
		return err
	}
	if cf != nil {
		defer func() { rerr = errors.Join(rerr, cf()) }()
	}
	res, err := ex.ExecContext(ctx, query, argv...)
	if err != nil {
		return err
	}
	if dst != nil {
		*dst = res
	}
	return nil
}
// Query implements the dialect.Query method.
//
// v must be a *Rows that receives the result set, and args a []any with the
// query arguments. When maySetVars returns a cleanup function, it is attached
// to the rows and runs when they are closed. If the query itself fails, the
// cleanup runs immediately and its error is joined into the returned one.
func (c Conn) Query(ctx context.Context, query string, args, v any) error {
	out, isRows := v.(*Rows)
	if !isRows {
		return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Rows", v)
	}
	argv, isSlice := args.([]any)
	if !isSlice {
		return fmt.Errorf("dialect/sql: invalid type %T. expect []any for args", args)
	}
	ex, closeFn, err := c.maySetVars(ctx)
	if err != nil {
		return err
	}
	rows, err := ex.QueryContext(ctx, query, argv...)
	switch {
	case err != nil && closeFn != nil:
		return errors.Join(err, closeFn())
	case err != nil:
		return err
	case closeFn != nil:
		*out = Rows{rowsWithCloser{rows, closeFn}}
	default:
		*out = Rows{rows}
	}
	return nil
}
// maySetVars sets the session variables before executing a query.
//
// With no variables attached to ctx, it returns c itself and a nil close
// function. Otherwise it runs `SET <name> = '<value>'` for every attached
// variable, in attachment order, on:
//   - the transaction itself, when c wraps a *sql.Tx (the close function is nil);
//   - a dedicated connection taken from the pool, when c wraps a *sql.DB. The
//     returned close function first resets the variables (Postgres: RESET,
//     MySQL: SET ... = NULL) and then closes the connection, returning it to
//     the pool.
//
// Any other ExecQuerier is rejected with an error, because there is no single
// connection to pin the variables to.
//
// NOTE(review): values are interpolated into the SET statement unescaped;
// callers must not pass untrusted input to WithVar/WithIntVar.
func (c Conn) maySetVars(ctx context.Context) (ExecQuerier, func() error, error) {
	sv, _ := ctx.Value(ctxVarsKey{}).(sessionVars)
	if len(sv.vars) == 0 {
		return c, nil, nil
	}
	var (
		ex    ExecQuerier  // Underlying ExecQuerier.
		cf    func() error // Close function.
		reset []string     // Reset variables.
		seen  = make(map[string]struct{}, len(sv.vars))
	)
	switch e := c.ExecQuerier.(type) {
	case *sql.Tx:
		ex = e
	case *sql.DB:
		conn, err := e.Conn(ctx)
		if err != nil {
			return nil, nil, err
		}
		ex, cf = conn, conn.Close
	default:
		// Without this case ex would stay nil and the first SET below would
		// panic with a nil dereference.
		return nil, nil, fmt.Errorf("dialect/sql: cannot set session variables on %T", c.ExecQuerier)
	}
	for _, s := range sv.vars {
		if _, ok := seen[s.k]; !ok {
			switch c.dialect {
			case dialect.Postgres:
				reset = append(reset, fmt.Sprintf("RESET %s", s.k))
			case dialect.MySQL:
				reset = append(reset, fmt.Sprintf("SET %s = NULL", s.k))
			}
			seen[s.k] = struct{}{}
		}
		if _, err := ex.ExecContext(ctx, fmt.Sprintf("SET %s = '%s'", s.k, s.v)); err != nil {
			if cf != nil {
				err = errors.Join(err, cf())
			}
			return nil, nil, err
		}
	}
	// If there are variables to reset, and we need to return the
	// connection to the pool, we need to clean up the variables.
	if cls := cf; cf != nil && len(reset) > 0 {
		cf = func() error {
			for _, q := range reset {
				if _, err := ex.ExecContext(ctx, q); err != nil {
					return errors.Join(err, cls())
				}
			}
			return cls()
		}
	}
	return ex, cf, nil
}
// Compile-time check that *Driver satisfies the dialect.Driver interface.
var _ dialect.Driver = (*Driver)(nil)
type (
	// Rows wraps the sql.Rows to avoid locks copy.
	// It holds the rows behind the ColumnScanner interface, so a Rows value can
	// be copied without copying the lock inside sql.Rows.
	Rows struct{ ColumnScanner }

	// Result is an alias to sql.Result.
	Result = sql.Result

	// TxOptions holds the transaction options to be used in DB.BeginTx.
	TxOptions = sql.TxOptions

	// NullBool is an alias to sql.NullBool.
	NullBool = sql.NullBool
	// NullInt64 is an alias to sql.NullInt64.
	NullInt64 = sql.NullInt64
	// NullFloat64 is an alias to sql.NullFloat64.
	NullFloat64 = sql.NullFloat64
	// NullString is an alias to sql.NullString.
	NullString = sql.NullString
	// NullTime represents a time.Time that may be null (alias to sql.NullTime).
	NullTime = sql.NullTime
)
// NullScanner implements the sql.Scanner interface such that it
// can be used as a scan destination, similar to the types above.
// Non-NULL values are forwarded to S; NULL values leave S untouched.
type NullScanner struct {
	S     sql.Scanner
	Valid bool // Valid is true if the Scan value is not NULL.
}

// Scan implements the Scanner interface.
// It records whether value is non-NULL in Valid and delegates non-NULL
// values to S.Scan, returning its error.
func (n *NullScanner) Scan(value any) error {
	if value == nil {
		n.Valid = false
		return nil
	}
	n.Valid = true
	return n.S.Scan(value)
}
// ColumnScanner is the interface that wraps the standard
// sql.Rows methods used for scanning database rows.
// *sql.Rows satisfies it.
type ColumnScanner interface {
	// Iteration.
	Next() bool
	NextResultSet() bool
	Scan(dest ...any) error
	// Metadata.
	Columns() ([]string, error)
	ColumnTypes() ([]*sql.ColumnType, error)
	// Lifecycle.
	Err() error
	Close() error
}
// rowsWithCloser decorates a ColumnScanner so that closing it also runs an
// extra cleanup hook.
type rowsWithCloser struct {
	ColumnScanner
	closer func() error
}

// Close closes the wrapped ColumnScanner, then always runs the closer hook
// (even if the first Close failed), and returns both errors joined.
func (r rowsWithCloser) Close() error {
	scanErr := r.ColumnScanner.Close()
	hookErr := r.closer()
	return errors.Join(scanErr, hookErr)
}