@@ -17,6 +17,7 @@ import (
1717 "ariga.io/atlas/sql/migrate"
1818 "ariga.io/atlas/sql/schema"
1919 "ariga.io/atlas/sql/sqlclient"
20+ "github.com/ydb-platform/ydb-go-sdk/v3"
2021 ydbSdk "github.com/ydb-platform/ydb-go-sdk/v3"
2122)
2223
@@ -59,6 +60,7 @@ func init() {
5960 sqlclient .Register (
6061 DriverName ,
6162 sqlclient .OpenerFunc (opener ),
63+ sqlclient .RegisterDriverOpener (Open ),
6264 sqlclient .RegisterURLParser (parser {}),
6365 )
6466}
@@ -81,7 +83,7 @@ func opener(ctx context.Context, dsn *url.URL) (*sqlclient.Client, error) {
8183 }
8284
8385 sqlDriver := sql .OpenDB (conn )
84- migrateDriver , err := Open (nativeDriver , sqlDriver )
86+ migrateDriver , err := Open (sqlDriver )
8587 if err != nil {
8688 if cerr := sqlDriver .Close (); cerr != nil {
8789 err = fmt .Errorf ("%w: %v" , err , cerr )
@@ -102,14 +104,32 @@ func opener(ctx context.Context, dsn *url.URL) (*sqlclient.Client, error) {
102104}
103105
104106// Open opens a new YDB driver.
105- func Open (nativeDriver * ydbSdk.Driver , sqlDriver * sql.DB ) (migrate.Driver , error ) {
107+ func Open (db schema.ExecQuerier ) (migrate.Driver , error ) {
108+ var (
109+ nativeDriver * ydb.Driver
110+ err error
111+ )
112+
113+ switch casted := db .(type ) {
114+ case * sql.DB :
115+ nativeDriver , err = ydb .Unwrap (casted )
116+ case * sql.Conn :
117+ nativeDriver , err = ydb .Unwrap (casted )
118+ default :
119+ return nil , fmt .Errorf ("ydb: expected *sql.DB or *sql.Conn but got %T" , db )
120+ }
121+
122+ if err != nil {
123+ return nil , fmt .Errorf ("ydb: failed to unwrap ydb native driver: %v" , err )
124+ }
125+
106126 c := & conn {
107- ExecQuerier : sqlDriver ,
127+ ExecQuerier : db ,
108128 nativeDriver : nativeDriver ,
109129 database : nativeDriver .Name (),
110130 }
111131
112- rows , err := sqlDriver .QueryContext (context .Background (), "SELECT version()" )
132+ rows , err := db .QueryContext (context .Background (), "SELECT version()" )
113133 if err != nil {
114134 return nil , fmt .Errorf ("ydb: failed to query version: %w" , err )
115135 }
0 commit comments