Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 59 additions & 21 deletions dialect/sql/schema/atlas.go
Original file line number Diff line number Diff line change
Expand Up @@ -626,35 +626,73 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
if len(plan.Changes) == 0 {
return nil
}
// Open a transaction for backwards compatibility,
// even if the migration is not transactional.
tx, err := a.sqlDialect.Tx(ctx)
if err != nil {
return err
}
a.atDriver, err = a.sqlDialect.atOpen(tx)
if err != nil {
return err
}
// Apply plan (changes).
var applier Applier = ApplyFunc(func(ctx context.Context, tx dialect.ExecQuerier, plan *migrate.Plan) error {
for _, c := range plan.Changes {
if err := tx.Exec(ctx, c.Cmd, c.Args, nil); err != nil {
if c.Comment != "" {
err = fmt.Errorf("%s: %w", c.Comment, err)

// YDB requires DDL operations to be executed outside of transactions.
if a.sqlDialect.Dialect() == dialect.YDB {
applier := ApplyFunc(func(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error {
for _, change := range plan.Changes {
err := conn.Exec(
ctx,
change.Cmd,
change.Args,
nil,
)
if err != nil {
return wrapChangeError(change, err)
}
return err
}
return nil
})
if err := a.applyWithHooks(ctx, a.sqlDialect, plan, applier); err != nil {
return fmt.Errorf("sql/schema: %w", err)
}
return nil
})
} else {
// Open a transaction for backwards compatibility,
// even if the migration is not transactional.
tx, err := a.sqlDialect.Tx(ctx)
if err != nil {
return err
}
a.atDriver, err = a.sqlDialect.atOpen(tx)
if err != nil {
return err
}
applier := ApplyFunc(func(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error {
for _, change := range plan.Changes {
if err := conn.Exec(ctx, change.Cmd, change.Args, nil); err != nil {
return wrapChangeError(change, err)
}
}
return nil
})
if err := a.applyWithHooks(ctx, tx, plan, applier); err != nil {
return errors.Join(fmt.Errorf("sql/schema: %w", err), tx.Rollback())
}
return tx.Commit()
}
}

// applyWithHooks wraps the given applier with the configured apply hooks and executes it.
func (a *Atlas) applyWithHooks(
ctx context.Context,
conn dialect.ExecQuerier,
plan *migrate.Plan,
base Applier,
) error {
applier := base
for i := len(a.applyHook) - 1; i >= 0; i-- {
applier = a.applyHook[i](applier)
}
if err = applier.Apply(ctx, tx, plan); err != nil {
return errors.Join(fmt.Errorf("sql/schema: %w", err), tx.Rollback())
return applier.Apply(ctx, conn, plan)
}

// wrapChangeError wraps an error with the change comment if present.
func wrapChangeError(c *migrate.Change, err error) error {
if c.Comment != "" {
return fmt.Errorf("%s: %w", c.Comment, err)
}
return tx.Commit()
return err
}

// For BC reason, we omit the schema qualifier from the migration plan.
Expand Down
40 changes: 28 additions & 12 deletions dialect/sql/schema/ydb.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@ package schema

import (
"context"
"database/sql"
"errors"
"fmt"
"strings"

"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
entdriver "entgo.io/ent/dialect/ydb"
entdrv "entgo.io/ent/dialect/ydb"
"entgo.io/ent/schema/field"

"ariga.io/atlas/sql/migrate"
Expand All @@ -34,8 +33,8 @@ func (d *YDB) init(ctx context.Context) error {
return nil // already initialized.
}

rows := &sql.Rows{}
if err := d.Driver.Query(ctx, "SELECT version()", nil, rows); err != nil {
rows := &entsql.Rows{}
if err := d.Driver.Query(ctx, "SELECT version()", []any{}, rows); err != nil {
return fmt.Errorf("ydb: failed to query version: %w", err)
}
defer rows.Close()
Expand Down Expand Up @@ -69,9 +68,22 @@ func (d *YDB) tableExist(ctx context.Context, conn dialect.ExecQuerier, name str

// atOpen returns a custom Atlas migrate.Driver for YDB.
func (d *YDB) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) {
ydbDriver, ok := conn.(*entdriver.YDBDriver)
if !ok {
return nil, fmt.Errorf("expected dialect/ydb.YDBDriver, but got %T", conn)
var ydbDriver *entdrv.YDBDriver

switch drv := conn.(type) {
case *entdrv.YDBDriver:
ydbDriver = drv
case *YDB:
if ydb, ok := drv.Driver.(*entdrv.YDBDriver); ok {
ydbDriver = ydb
}
}
if ydbDriver == nil {
if ydb, ok := d.Driver.(*entdrv.YDBDriver); ok {
ydbDriver = ydb
} else {
return nil, fmt.Errorf("expected dialect/ydb.YDBDriver, but got %T", conn)
}
}

return atlas.Open(
Expand Down Expand Up @@ -145,11 +157,11 @@ func (d *YDB) atTypeC(column1 *Column, column2 *schema.Column) error {
case field.TypeString:
typ = &schema.StringType{T: atlas.TypeUtf8}
case field.TypeJSON:
typ = &schema.JSONType{T: atlas.TypeJson}
typ = &schema.JSONType{T: atlas.TypeJSON}
case field.TypeTime:
typ = &schema.TimeType{T: atlas.TypeTimestamp}
case field.TypeUUID:
typ = &schema.UUIDType{T: atlas.TypeUuid}
typ = &schema.UUIDType{T: atlas.TypeUUID}
case field.TypeEnum:
err = errors.New("ydb: Enum can't be used as column data type for tables")
case field.TypeOther:
Expand Down Expand Up @@ -186,7 +198,7 @@ func (d *YDB) atUniqueC(
index := schema.NewUniqueIndex(idxName).AddColumns(column2)

// Add YDB-specific attribute for GLOBAL SYNC index type.
index.AddAttrs(&atlas.YDBIndexAttributes{Global: true, Sync: true})
index.AddAttrs(&atlas.IndexAttributes{Global: true, Sync: true})

table2.AddIndexes(index)
}
Expand All @@ -195,7 +207,11 @@ func (d *YDB) atUniqueC(
// YDB uses Serial types for auto-increment.
func (d *YDB) atIncrementC(table *schema.Table, column *schema.Column) {
if intType, ok := column.Type.Type.(*schema.IntegerType); ok {
column.Type.Type = atlas.SerialFromInt(intType)
serial, err := atlas.SerialFromInt(intType)
if err != nil {
panic(err)
}
column.Type.Type = serial
}
}

Expand All @@ -220,7 +236,7 @@ func (d *YDB) atIndex(

// Set YDB-specific index attributes.
// By default, use GLOBAL SYNC for consistency.
idxType := &atlas.YDBIndexAttributes{Global: true, Sync: true}
idxType := &atlas.IndexAttributes{Global: true, Sync: true}

// Check for annotation overrides.
if index1.Annotation != nil {
Expand Down
Loading
Loading