Skip to content

Commit b1af6b1

Browse files
all: added support for ydb in codegen (#10)
1 parent c442e86 commit b1af6b1

File tree

179 files changed

+13143
-579
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

179 files changed

+13143
-579
lines changed

dialect/sql/schema/atlas.go

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -626,35 +626,73 @@ func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) {
626626
if len(plan.Changes) == 0 {
627627
return nil
628628
}
629-
// Open a transaction for backwards compatibility,
630-
// even if the migration is not transactional.
631-
tx, err := a.sqlDialect.Tx(ctx)
632-
if err != nil {
633-
return err
634-
}
635-
a.atDriver, err = a.sqlDialect.atOpen(tx)
636-
if err != nil {
637-
return err
638-
}
639-
// Apply plan (changes).
640-
var applier Applier = ApplyFunc(func(ctx context.Context, tx dialect.ExecQuerier, plan *migrate.Plan) error {
641-
for _, c := range plan.Changes {
642-
if err := tx.Exec(ctx, c.Cmd, c.Args, nil); err != nil {
643-
if c.Comment != "" {
644-
err = fmt.Errorf("%s: %w", c.Comment, err)
629+
630+
// YDB requires DDL operations to be executed outside of transactions.
631+
if a.sqlDialect.Dialect() == dialect.YDB {
632+
applier := ApplyFunc(func(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error {
633+
for _, change := range plan.Changes {
634+
err := conn.Exec(
635+
ctx,
636+
change.Cmd,
637+
change.Args,
638+
nil,
639+
)
640+
if err != nil {
641+
return wrapChangeError(change, err)
645642
}
646-
return err
647643
}
644+
return nil
645+
})
646+
if err := a.applyWithHooks(ctx, a.sqlDialect, plan, applier); err != nil {
647+
return fmt.Errorf("sql/schema: %w", err)
648648
}
649649
return nil
650-
})
650+
} else {
651+
// Open a transaction for backwards compatibility,
652+
// even if the migration is not transactional.
653+
tx, err := a.sqlDialect.Tx(ctx)
654+
if err != nil {
655+
return err
656+
}
657+
a.atDriver, err = a.sqlDialect.atOpen(tx)
658+
if err != nil {
659+
return err
660+
}
661+
applier := ApplyFunc(func(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error {
662+
for _, change := range plan.Changes {
663+
if err := conn.Exec(ctx, change.Cmd, change.Args, nil); err != nil {
664+
return wrapChangeError(change, err)
665+
}
666+
}
667+
return nil
668+
})
669+
if err := a.applyWithHooks(ctx, tx, plan, applier); err != nil {
670+
return errors.Join(fmt.Errorf("sql/schema: %w", err), tx.Rollback())
671+
}
672+
return tx.Commit()
673+
}
674+
}
675+
676+
// applyWithHooks wraps the given applier with the configured apply hooks and executes it.
677+
func (a *Atlas) applyWithHooks(
678+
ctx context.Context,
679+
conn dialect.ExecQuerier,
680+
plan *migrate.Plan,
681+
base Applier,
682+
) error {
683+
applier := base
651684
for i := len(a.applyHook) - 1; i >= 0; i-- {
652685
applier = a.applyHook[i](applier)
653686
}
654-
if err = applier.Apply(ctx, tx, plan); err != nil {
655-
return errors.Join(fmt.Errorf("sql/schema: %w", err), tx.Rollback())
687+
return applier.Apply(ctx, conn, plan)
688+
}
689+
690+
// wrapChangeError wraps an error with the change comment if present.
691+
func wrapChangeError(c *migrate.Change, err error) error {
692+
if c.Comment != "" {
693+
return fmt.Errorf("%s: %w", c.Comment, err)
656694
}
657-
return tx.Commit()
695+
return err
658696
}
659697

660698
// For BC reason, we omit the schema qualifier from the migration plan.

dialect/sql/schema/ydb.go

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,13 @@ package schema
66

77
import (
88
"context"
9-
"database/sql"
109
"errors"
1110
"fmt"
1211
"strings"
1312

1413
"entgo.io/ent/dialect"
1514
entsql "entgo.io/ent/dialect/sql"
16-
entdriver "entgo.io/ent/dialect/ydb"
15+
entdrv "entgo.io/ent/dialect/ydb"
1716
"entgo.io/ent/schema/field"
1817

1918
"ariga.io/atlas/sql/migrate"
@@ -34,8 +33,8 @@ func (d *YDB) init(ctx context.Context) error {
3433
return nil // already initialized.
3534
}
3635

37-
rows := &sql.Rows{}
38-
if err := d.Driver.Query(ctx, "SELECT version()", nil, rows); err != nil {
36+
rows := &entsql.Rows{}
37+
if err := d.Driver.Query(ctx, "SELECT version()", []any{}, rows); err != nil {
3938
return fmt.Errorf("ydb: failed to query version: %w", err)
4039
}
4140
defer rows.Close()
@@ -69,9 +68,22 @@ func (d *YDB) tableExist(ctx context.Context, conn dialect.ExecQuerier, name str
6968

7069
// atOpen returns a custom Atlas migrate.Driver for YDB.
7170
func (d *YDB) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) {
72-
ydbDriver, ok := conn.(*entdriver.YDBDriver)
73-
if !ok {
74-
return nil, fmt.Errorf("expected dialect/ydb.YDBDriver, but got %T", conn)
71+
var ydbDriver *entdrv.YDBDriver
72+
73+
switch drv := conn.(type) {
74+
case *entdrv.YDBDriver:
75+
ydbDriver = drv
76+
case *YDB:
77+
if ydb, ok := drv.Driver.(*entdrv.YDBDriver); ok {
78+
ydbDriver = ydb
79+
}
80+
}
81+
if ydbDriver == nil {
82+
if ydb, ok := d.Driver.(*entdrv.YDBDriver); ok {
83+
ydbDriver = ydb
84+
} else {
85+
return nil, fmt.Errorf("expected dialect/ydb.YDBDriver, but got %T", conn)
86+
}
7587
}
7688

7789
return atlas.Open(
@@ -145,11 +157,11 @@ func (d *YDB) atTypeC(column1 *Column, column2 *schema.Column) error {
145157
case field.TypeString:
146158
typ = &schema.StringType{T: atlas.TypeUtf8}
147159
case field.TypeJSON:
148-
typ = &schema.JSONType{T: atlas.TypeJson}
160+
typ = &schema.JSONType{T: atlas.TypeJSON}
149161
case field.TypeTime:
150162
typ = &schema.TimeType{T: atlas.TypeTimestamp}
151163
case field.TypeUUID:
152-
typ = &schema.UUIDType{T: atlas.TypeUuid}
164+
typ = &schema.UUIDType{T: atlas.TypeUUID}
153165
case field.TypeEnum:
154166
err = errors.New("ydb: Enum can't be used as column data type for tables")
155167
case field.TypeOther:
@@ -186,7 +198,7 @@ func (d *YDB) atUniqueC(
186198
index := schema.NewUniqueIndex(idxName).AddColumns(column2)
187199

188200
// Add YDB-specific attribute for GLOBAL SYNC index type.
189-
index.AddAttrs(&atlas.YDBIndexAttributes{Global: true, Sync: true})
201+
index.AddAttrs(&atlas.IndexAttributes{Global: true, Sync: true})
190202

191203
table2.AddIndexes(index)
192204
}
@@ -195,7 +207,11 @@ func (d *YDB) atUniqueC(
195207
// YDB uses Serial types for auto-increment.
196208
func (d *YDB) atIncrementC(table *schema.Table, column *schema.Column) {
197209
if intType, ok := column.Type.Type.(*schema.IntegerType); ok {
198-
column.Type.Type = atlas.SerialFromInt(intType)
210+
serial, err := atlas.SerialFromInt(intType)
211+
if err != nil {
212+
panic(err)
213+
}
214+
column.Type.Type = serial
199215
}
200216
}
201217

@@ -220,7 +236,7 @@ func (d *YDB) atIndex(
220236

221237
// Set YDB-specific index attributes.
222238
// By default, use GLOBAL SYNC for consistency.
223-
idxType := &atlas.YDBIndexAttributes{Global: true, Sync: true}
239+
idxType := &atlas.IndexAttributes{Global: true, Sync: true}
224240

225241
// Check for annotation overrides.
226242
if index1.Annotation != nil {

0 commit comments

Comments
 (0)