@@ -17,6 +17,7 @@ import (
1717 "ariga.io/atlas/sql/migrate"
1818 "ariga.io/atlas/sql/schema"
1919 "ariga.io/atlas/sql/sqlclient"
20+ "github.com/ydb-platform/ydb-go-sdk/v3"
2021 ydbSdk "github.com/ydb-platform/ydb-go-sdk/v3"
2122)
2223
@@ -59,6 +60,7 @@ func init() {
5960 sqlclient .Register (
6061 DriverName ,
6162 sqlclient .OpenerFunc (opener ),
63+ sqlclient .RegisterDriverOpener (Open ),
6264 sqlclient .RegisterURLParser (parser {}),
6365 )
6466}
@@ -81,7 +83,7 @@ func opener(ctx context.Context, dsn *url.URL) (*sqlclient.Client, error) {
8183 }
8284
8385 sqlDriver := sql .OpenDB (conn )
84- migrateDriver , err := Open (nativeDriver , sqlDriver )
86+ migrateDriver , err := Open (sqlDriver )
8587 if err != nil {
8688 if cerr := sqlDriver .Close (); cerr != nil {
8789 err = fmt .Errorf ("%w: %v" , err , cerr )
@@ -102,14 +104,35 @@ func opener(ctx context.Context, dsn *url.URL) (*sqlclient.Client, error) {
102104}
103105
104106// Open opens a new YDB driver.
105- func Open (nativeDriver * ydbSdk.Driver , sqlDriver * sql.DB ) (migrate.Driver , error ) {
107+ func Open (db schema.ExecQuerier ) (migrate.Driver , error ) {
108+ var (
109+ execQuerier schema.ExecQuerier
110+ nativeDriver * ydb.Driver
111+ err error
112+ )
113+
114+ switch casted := db .(type ) {
115+ case * sql.DB :
116+ nativeDriver , err = ydb .Unwrap (casted )
117+ execQuerier = casted
118+ case * sql.Conn :
119+ nativeDriver , err = ydb .Unwrap (casted )
120+ execQuerier = casted
121+ default :
122+ return nil , fmt .Errorf ("ydb: expected *sql.DB or *sql.Conn but got %T" , db )
123+ }
124+
125+ if err != nil {
126+ return nil , fmt .Errorf ("ydb: failed to unwrap ydb native driver: %v" , err )
127+ }
128+
106129 c := & conn {
107- ExecQuerier : sqlDriver ,
130+ ExecQuerier : execQuerier ,
108131 nativeDriver : nativeDriver ,
109132 database : nativeDriver .Name (),
110133 }
111134
112- rows , err := sqlDriver .QueryContext (context .Background (), "SELECT version()" )
135+ rows , err := execQuerier .QueryContext (context .Background (), "SELECT version()" )
113136 if err != nil {
114137 return nil , fmt .Errorf ("ydb: failed to query version: %w" , err )
115138 }
0 commit comments