Skip to content

Commit 38e1813

Browse files
sql/ydb: tried to use driver unwrap from ydb-go-sdk (ydb-platform#4)
1 parent 84bf504 commit 38e1813

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

sql/ydb/driver.go

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"ariga.io/atlas/sql/migrate"
1818
"ariga.io/atlas/sql/schema"
1919
"ariga.io/atlas/sql/sqlclient"
20+
"github.com/ydb-platform/ydb-go-sdk/v3"
2021
ydbSdk "github.com/ydb-platform/ydb-go-sdk/v3"
2122
)
2223

@@ -59,6 +60,7 @@ func init() {
5960
sqlclient.Register(
6061
DriverName,
6162
sqlclient.OpenerFunc(opener),
63+
sqlclient.RegisterDriverOpener(Open),
6264
sqlclient.RegisterURLParser(parser{}),
6365
)
6466
}
@@ -81,7 +83,7 @@ func opener(ctx context.Context, dsn *url.URL) (*sqlclient.Client, error) {
8183
}
8284

8385
sqlDriver := sql.OpenDB(conn)
84-
migrateDriver, err := Open(nativeDriver, sqlDriver)
86+
migrateDriver, err := Open(sqlDriver)
8587
if err != nil {
8688
if cerr := sqlDriver.Close(); cerr != nil {
8789
err = fmt.Errorf("%w: %v", err, cerr)
@@ -102,14 +104,32 @@ func opener(ctx context.Context, dsn *url.URL) (*sqlclient.Client, error) {
102104
}
103105

104106
// Open opens a new YDB driver.
105-
func Open(nativeDriver *ydbSdk.Driver, sqlDriver *sql.DB) (migrate.Driver, error) {
107+
func Open(db schema.ExecQuerier) (migrate.Driver, error) {
108+
var (
109+
nativeDriver *ydb.Driver
110+
err error
111+
)
112+
113+
switch casted := db.(type) {
114+
case *sql.DB:
115+
nativeDriver, err = ydb.Unwrap(casted)
116+
case *sql.Conn:
117+
nativeDriver, err = ydb.Unwrap(casted)
118+
default:
119+
return nil, fmt.Errorf("ydb: expected *sql.DB or *sql.Conn but got %T", db)
120+
}
121+
122+
if err != nil {
123+
return nil, fmt.Errorf("ydb: failed to unwrap ydb native driver: %v", err)
124+
}
125+
106126
c := &conn{
107-
ExecQuerier: sqlDriver,
127+
ExecQuerier: db,
108128
nativeDriver: nativeDriver,
109129
database: nativeDriver.Name(),
110130
}
111131

112-
rows, err := sqlDriver.QueryContext(context.Background(), "SELECT version()")
132+
rows, err := db.QueryContext(context.Background(), "SELECT version()")
113133
if err != nil {
114134
return nil, fmt.Errorf("ydb: failed to query version: %w", err)
115135
}

0 commit comments

Comments
 (0)