From 87d5f1b95075de24ed855d08b38d95f45fabfc36 Mon Sep 17 00:00:00 2001 From: danilov6083 Date: Wed, 4 Feb 2026 18:34:38 +0300 Subject: [PATCH] sql/ydb: tried to use driver unwrap from ydb-go-sdk --- sql/ydb/driver.go | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/sql/ydb/driver.go b/sql/ydb/driver.go index b73e03753dc..8e8aafd84b4 100644 --- a/sql/ydb/driver.go +++ b/sql/ydb/driver.go @@ -17,6 +17,7 @@ import ( "ariga.io/atlas/sql/migrate" "ariga.io/atlas/sql/schema" "ariga.io/atlas/sql/sqlclient" + "github.com/ydb-platform/ydb-go-sdk/v3" ydbSdk "github.com/ydb-platform/ydb-go-sdk/v3" ) @@ -59,6 +60,7 @@ func init() { sqlclient.Register( DriverName, sqlclient.OpenerFunc(opener), + sqlclient.RegisterDriverOpener(Open), sqlclient.RegisterURLParser(parser{}), ) } @@ -81,7 +83,7 @@ func opener(ctx context.Context, dsn *url.URL) (*sqlclient.Client, error) { } sqlDriver := sql.OpenDB(conn) - migrateDriver, err := Open(nativeDriver, sqlDriver) + migrateDriver, err := Open(sqlDriver) if err != nil { if cerr := sqlDriver.Close(); cerr != nil { err = fmt.Errorf("%w: %v", err, cerr) @@ -102,14 +104,32 @@ func opener(ctx context.Context, dsn *url.URL) (*sqlclient.Client, error) { } // Open opens a new YDB driver. -func Open(nativeDriver *ydbSdk.Driver, sqlDriver *sql.DB) (migrate.Driver, error) { +func Open(db schema.ExecQuerier) (migrate.Driver, error) { + var ( + nativeDriver *ydb.Driver + err error + ) + + switch casted := db.(type) { + case *sql.DB: + nativeDriver, err = ydb.Unwrap(casted) + case *sql.Conn: + nativeDriver, err = ydb.Unwrap(casted) + default: + return nil, fmt.Errorf("ydb: expected *sql.DB or *sql.Conn but got %T", db) + } + + if err != nil { + return nil, fmt.Errorf("ydb: failed to unwrap ydb native driver: %v", err) + } + c := &conn{ - ExecQuerier: sqlDriver, + ExecQuerier: db, nativeDriver: nativeDriver, database: nativeDriver.Name(), } - rows, err := sqlDriver.QueryContext(context.Background(), "SELECT version()") + rows, err := db.QueryContext(context.Background(), "SELECT version()") if err != nil { return nil, fmt.Errorf("ydb: failed to query version: %w", err) }