Skip to content

Commit 59cb5e2

Browse files
sql/ydb: tried to use driver unwrap from ydb-go-sdk
1 parent b96ad48 commit 59cb5e2

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

sql/ydb/driver.go

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"ariga.io/atlas/sql/migrate"
1818
"ariga.io/atlas/sql/schema"
1919
"ariga.io/atlas/sql/sqlclient"
20+
"github.com/ydb-platform/ydb-go-sdk/v3"
2021
ydbSdk "github.com/ydb-platform/ydb-go-sdk/v3"
2122
)
2223

@@ -59,6 +60,7 @@ func init() {
5960
sqlclient.Register(
6061
DriverName,
6162
sqlclient.OpenerFunc(opener),
63+
sqlclient.RegisterDriverOpener(Open),
6264
sqlclient.RegisterURLParser(parser{}),
6365
)
6466
}
@@ -81,7 +83,7 @@ func opener(ctx context.Context, dsn *url.URL) (*sqlclient.Client, error) {
8183
}
8284

8385
sqlDriver := sql.OpenDB(conn)
84-
migrateDriver, err := Open(nativeDriver, sqlDriver)
86+
migrateDriver, err := Open(sqlDriver)
8587
if err != nil {
8688
if cerr := sqlDriver.Close(); cerr != nil {
8789
err = fmt.Errorf("%w: %v", err, cerr)
@@ -102,14 +104,35 @@ func opener(ctx context.Context, dsn *url.URL) (*sqlclient.Client, error) {
102104
}
103105

104106
// Open opens a new YDB driver.
105-
func Open(nativeDriver *ydbSdk.Driver, sqlDriver *sql.DB) (migrate.Driver, error) {
107+
func Open(db schema.ExecQuerier) (migrate.Driver, error) {
108+
var (
109+
execQuerier schema.ExecQuerier
110+
nativeDriver *ydb.Driver
111+
err error
112+
)
113+
114+
switch casted := db.(type) {
115+
case *sql.DB:
116+
nativeDriver, err = ydb.Unwrap(casted)
117+
execQuerier = casted
118+
case *sql.Conn:
119+
nativeDriver, err = ydb.Unwrap(casted)
120+
execQuerier = casted
121+
default:
122+
return nil, fmt.Errorf("ydb: expected *sql.DB or *sql.Conn but got %T", db)
123+
}
124+
125+
if err != nil {
126+
return nil, fmt.Errorf("ydb: failed to unwrap ydb native driver: %v", err)
127+
}
128+
106129
c := &conn{
107-
ExecQuerier: sqlDriver,
130+
ExecQuerier: execQuerier,
108131
nativeDriver: nativeDriver,
109132
database: nativeDriver.Name(),
110133
}
111134

112-
rows, err := sqlDriver.QueryContext(context.Background(), "SELECT version()")
135+
rows, err := execQuerier.QueryContext(context.Background(), "SELECT version()")
113136
if err != nil {
114137
return nil, fmt.Errorf("ydb: failed to query version: %w", err)
115138
}

0 commit comments

Comments
 (0)