Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions sql/ydb/attributes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright 2021-present The Atlas Authors. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

//go:build !ent

package ydb

import "ariga.io/atlas/sql/schema"

//[IndexAttributes] represents YDB-specific index attributes.
type IndexAttributes struct {
schema.Attr
Global bool // GLOBAL, LOCAL
Sync bool // SYNC, ASYNC
}
10 changes: 6 additions & 4 deletions sql/ydb/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ func FormatType(typ schema.Type) (string, error) {
formatted = TypeUUID
case *schema.TimeType:
formatted, err = formatTimeType(t)
case *schema.EnumType:
err = errors.New("ydb: Enum can't be used as column data types for tables")
case *schema.UnsupportedType:
err = fmt.Errorf("ydb: unsupported type: %q", t.T)
default:
Expand Down Expand Up @@ -145,7 +147,7 @@ func formatTimeType(t *schema.TimeType) (string, error) {
// ParseType returns the schema.Type value represented by the given raw type.
// The raw value is expected to follow the format of input for the CREATE TABLE statement.
func ParseType(typ string) (schema.Type, error) {
colDesc, err := parseColumn(typ)
colDesc, err := parseColumn(strings.ToLower(typ))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -203,9 +205,9 @@ func parseColumn(typ string) (*columnDecscriptor, error) {
func parseOptionalType(typ string) (*columnDecscriptor, string) {
colDesc := &columnDecscriptor{}

if strings.HasPrefix(typ, "Optional<") {
if strings.HasPrefix(typ, "optional<") {
colDesc.nullable = true
typ = strings.TrimPrefix(typ, "Optional<")
typ = strings.TrimPrefix(typ, "optional<")
typ = strings.TrimSuffix(typ, ">")
}

Expand Down Expand Up @@ -249,7 +251,7 @@ func columnType(colDesc *columnDecscriptor) (schema.Type, error) {
}

return &OptionalType{
T: fmt.Sprintf("Optional<%s>", innerTypeStr),
T: fmt.Sprintf("optional<%s>", innerTypeStr),
InnerType: innerType,
}, nil
}
Expand Down
6 changes: 3 additions & 3 deletions sql/ydb/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,9 @@ func TestConvert_ParseType(t *testing.T) {
{name: "tztimestamp64", input: TypeTzTimestamp64, expected: &schema.TimeType{T: TypeTzTimestamp64}},

// Optional types
{name: "optional_int32", input: "Optional<int32>", expected: &OptionalType{T: "Optional<int32>", InnerType: &schema.IntegerType{T: TypeInt32, Unsigned: false}}},
{name: "optional_utf8", input: "Optional<utf8>", expected: &OptionalType{T: "Optional<utf8>", InnerType: &schema.StringType{T: TypeUtf8}}},
{name: "optional_bool", input: "Optional<bool>", expected: &OptionalType{T: "Optional<bool>", InnerType: &schema.BoolType{T: TypeBool}}},
{name: "optional_int32", input: "Optional<int32>", expected: &OptionalType{T: "optional<int32>", InnerType: &schema.IntegerType{T: TypeInt32, Unsigned: false}}},
{name: "optional_utf8", input: "Optional<utf8>", expected: &OptionalType{T: "optional<utf8>", InnerType: &schema.StringType{T: TypeUtf8}}},
{name: "optional_bool", input: "Optional<bool>", expected: &OptionalType{T: "optional<bool>", InnerType: &schema.BoolType{T: TypeBool}}},

// Unsupported/unknown types
{name: "unknown_type", input: "unknown", expected: &schema.UnsupportedType{T: "unknown"}},
Expand Down
4 changes: 2 additions & 2 deletions sql/ydb/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func opener(ctx context.Context, dsn *url.URL) (*sqlclient.Client, error) {
}

sqlDriver := sql.OpenDB(conn)
drv, err := open(nativeDriver, sqlDriver)
drv, err := Open(nativeDriver, sqlDriver)
if err != nil {
if cerr := sqlDriver.Close(); cerr != nil {
err = fmt.Errorf("%w: %v", err, cerr)
Expand All @@ -102,7 +102,7 @@ func opener(ctx context.Context, dsn *url.URL) (*sqlclient.Client, error) {
}

// Open opens a new YDB driver.
func open(nativeDriver *ydbSdk.Driver, sqlDriver *sql.DB) (migrate.Driver, error) {
func Open(nativeDriver *ydbSdk.Driver, sqlDriver *sql.DB) (migrate.Driver, error) {
c := &conn{
ExecQuerier: sqlDriver,
nativeDriver: nativeDriver,
Expand Down
231 changes: 217 additions & 14 deletions sql/ydb/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ func (s *state) plan(changes []schema.Change) error {
if err := s.dropTable(change); err != nil {
return err
}
case *schema.ModifyTable:
if err := s.modifyTable(change); err != nil {
return err
}
case *schema.RenameTable:
s.renameTable(change)
default:
return fmt.Errorf("ydb: unsupported change type: %T", change)
}
Expand All @@ -92,10 +98,10 @@ func (s *state) plan(changes []schema.Change) error {
// addTable builds and executes the query for creating a table in a schema.
func (s *state) addTable(addTable *schema.AddTable) error {
var errs []string
b := s.Build("CREATE TABLE")
builder := s.Build("CREATE TABLE")

b.Table(addTable.T)
b.WrapIndent(func(b *sqlx.Builder) {
builder.Table(addTable.T)
builder.WrapIndent(func(b *sqlx.Builder) {
b.MapIndent(addTable.T.Columns, func(i int, b *sqlx.Builder) {
if err := s.column(b, addTable.T.Columns[i]); err != nil {
errs = append(errs, err.Error())
Expand Down Expand Up @@ -123,20 +129,14 @@ func (s *state) addTable(addTable *schema.AddTable) error {
String()

s.append(&migrate.Change{
Cmd: b.String(),
Cmd: builder.String(),
Source: addTable,
Comment: fmt.Sprintf("create %q table", addTable.T.Name),
Reverse: reverse,
})
return nil
}

// indexDef writes an inline index definition for CREATE TABLE.
func (s *state) indexDef(b *sqlx.Builder, idx *schema.Index) {
b.P("INDEX").Ident(idx.Name).P("GLOBAL ON")
s.indexParts(b, idx.Parts)
}

// dropTable builds and executes the query for dropping a table from a schema.
func (s *state) dropTable(drop *schema.DropTable) error {
reverseState := &state{
Expand All @@ -148,11 +148,11 @@ func (s *state) dropTable(drop *schema.DropTable) error {
return fmt.Errorf("calculate reverse for drop table %q: %w", drop.T.Name, err)
}

b := s.Build("DROP TABLE")
builder := s.Build("DROP TABLE")
if sqlx.Has(drop.Extra, &schema.IfExists{}) {
b.P("IF EXISTS")
builder.P("IF EXISTS")
}
b.Table(drop.T)
builder.Table(drop.T)

// The reverse of 'DROP TABLE' might be a multi-statement operation
reverse := func() any {
Expand All @@ -167,14 +167,211 @@ func (s *state) dropTable(drop *schema.DropTable) error {
}()

s.append(&migrate.Change{
Cmd: b.String(),
Cmd: builder.String(),
Source: drop,
Comment: fmt.Sprintf("drop %q table", drop.T.Name),
Reverse: reverse,
})
return nil
}

// modifyTable builds the statements that bring the table into its modified state.
func (s *state) modifyTable(modify *schema.ModifyTable) error {
var (
alterOps []schema.Change
addIndexOps []*schema.AddIndex
dropIndexOps []*schema.DropIndex
)

for _, change := range modify.Changes {
switch change := change.(type) {
case *schema.AddColumn:
alterOps = append(alterOps, change)

case *schema.DropColumn:
alterOps = append(alterOps, change)

case *schema.AddIndex:
addIndexOps = append(addIndexOps, change)

case *schema.DropIndex:
dropIndexOps = append(dropIndexOps, change)

case *schema.ModifyIndex:
// Index modification requires rebuilding the index.
dropIndexOps = append(dropIndexOps, &schema.DropIndex{I: change.From})
addIndexOps = append(addIndexOps, &schema.AddIndex{I: change.To})

case *schema.RenameIndex:
s.renameIndex(modify, change)

default:
return fmt.Errorf("ydb: unsupported table change: %T", change)
}
}

// Drop indexes first, then alter table, then add indexes
if err := s.dropIndexes(modify, modify.T, dropIndexOps...); err != nil {
return err
}

if len(alterOps) > 0 {
if err := s.alterTable(modify.T, alterOps); err != nil {
return err
}
}

if err := s.addIndexes(modify, modify.T, addIndexOps...); err != nil {
return err
}

return nil
}

// alterTable modifies the given table by executing on it a list of changes in one SQL statement.
func (s *state) alterTable(t *schema.Table, changes []schema.Change) error {
var reverse []schema.Change

buildFunc := func(changes []schema.Change) (string, error) {
b := s.Build("ALTER TABLE").Table(t)

err := b.MapCommaErr(changes, func(i int, builder *sqlx.Builder) error {
switch change := changes[i].(type) {
case *schema.AddColumn:
builder.P("ADD COLUMN")
if err := s.column(builder, change.C); err != nil {
return err
}
reverse = append(reverse, &schema.DropColumn{C: change.C})

case *schema.DropColumn:
builder.P("DROP COLUMN").Ident(change.C.Name)
reverse = append(reverse, &schema.AddColumn{C: change.C})
}

return nil
})
if err != nil {
return "", err
}

return b.String(), nil
}

stmt, err := buildFunc(changes)
if err != nil {
return fmt.Errorf("alter table %q: %v", t.Name, err)
}

cmd := &migrate.Change{
Cmd: stmt,
Source: &schema.ModifyTable{
T: t,
Changes: changes,
},
Comment: fmt.Sprintf("modify %q table", t.Name),
}

// Changes should be reverted in a reversed order they were created.
sqlx.ReverseChanges(reverse)
if cmd.Reverse, err = buildFunc(reverse); err != nil {
return fmt.Errorf("reverse alter table %q: %v", t.Name, err)
}

s.append(cmd)
return nil
}

func (s *state) addIndexes(src schema.Change, t *schema.Table, indexes ...*schema.AddIndex) error {
for _, add := range indexes {
index := add.I
indexAttrs := IndexAttributes{}
hasAttrs := sqlx.Has(index.Attrs, &indexAttrs)

b := s.Build("ALTER TABLE").
Table(t).
P("ADD INDEX").
Ident(index.Name)

if hasAttrs && !indexAttrs.Global {
b.P("LOCAL")
} else {
b.P("GLOBAL")
}

if index.Unique {
b.P("UNIQUE")
}

if hasAttrs && !indexAttrs.Sync {
b.P("ASYNC")
} else {
b.P("SYNC")
}

b.P("ON")

s.indexParts(b, index.Parts)

reverseOp := s.Build("ALTER TABLE").
Table(t).
P("DROP INDEX").
Ident(index.Name).
String()

s.append(&migrate.Change{
Cmd: b.String(),
Source: src,
Comment: fmt.Sprintf("create index %q to table: %q", index.Name, t.Name),
Reverse: reverseOp,
})
}
return nil
}

func (s *state) dropIndexes(src schema.Change, t *schema.Table, drops ...*schema.DropIndex) error {
adds := make([]*schema.AddIndex, len(drops))
for i, d := range drops {
adds[i] = &schema.AddIndex{I: d.I, Extra: d.Extra}
}

reverseState := &state{conn: s.conn, PlanOptions: s.PlanOptions}
if err := reverseState.addIndexes(src, t, adds...); err != nil {
return err
}

for i, add := range adds {
s.append(&migrate.Change{
Cmd: reverseState.Changes[i].Reverse.(string),
Source: src,
Comment: fmt.Sprintf("drop index %q from table: %q", add.I.Name, t.Name),
Reverse: reverseState.Changes[i].Cmd,
})
}

return nil
}

// renameTable builds and appends the statement for renaming a table.
func (s *state) renameTable(c *schema.RenameTable) {
s.append(&migrate.Change{
Source: c,
Comment: fmt.Sprintf("rename a table from %q to %q", c.From.Name, c.To.Name),
Cmd: s.Build("ALTER TABLE").Table(c.From).P("RENAME TO").Table(c.To).String(),
Reverse: s.Build("ALTER TABLE").Table(c.To).P("RENAME TO").Table(c.From).String(),
})
}

// renameIndex builds and appends the statement for renaming an index.
func (s *state) renameIndex(modify *schema.ModifyTable, c *schema.RenameIndex) {
s.append(&migrate.Change{
Source: c,
Comment: fmt.Sprintf("rename an index from %q to %q", c.From.Name, c.To.Name),
Cmd: s.Build("ALTER TABLE").Table(modify.T).P("RENAME INDEX").Ident(c.From.Name).P("TO").Ident(c.To.Name).String(),
Reverse: s.Build("ALTER TABLE").Table(modify.T).P("RENAME INDEX").Ident(c.To.Name).P("TO").Ident(c.From.Name).String(),
})
}

// column writes the column definition to the builder.
func (s *state) column(b *sqlx.Builder, c *schema.Column) error {
t, err := FormatType(c.Type.Type)
Expand All @@ -190,6 +387,12 @@ func (s *state) column(b *sqlx.Builder, c *schema.Column) error {
return nil
}

// indexDef writes an inline index definition for CREATE TABLE.
func (s *state) indexDef(b *sqlx.Builder, idx *schema.Index) {
b.P("INDEX").Ident(idx.Name).P("GLOBAL ON")
s.indexParts(b, idx.Parts)
}

// indexParts writes the index parts (columns) to the builder.
func (s *state) indexParts(b *sqlx.Builder, parts []*schema.IndexPart) {
b.Wrap(func(b *sqlx.Builder) {
Expand Down
Loading
Loading