Skip to content

Commit 9171e71

Browse files
sql/ydb: added support for modify table migrations (#2)
1 parent faf951f commit 9171e71

File tree

7 files changed

+793
-35
lines changed

7 files changed

+793
-35
lines changed

sql/ydb/attributes.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// Copyright 2021-present The Atlas Authors. All rights reserved.
2+
// This source code is licensed under the Apache 2.0 license found
3+
// in the LICENSE file in the root directory of this source tree.
4+
5+
//go:build !ent
6+
7+
package ydb
8+
9+
import "ariga.io/atlas/sql/schema"
10+
11+
//[IndexAttributes] represents YDB-specific index attributes.
12+
type IndexAttributes struct {
13+
schema.Attr
14+
Global bool // GLOBAL, LOCAL
15+
Sync bool // SYNC, ASYNC
16+
}

sql/ydb/convert.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ func FormatType(typ schema.Type) (string, error) {
4747
formatted = TypeUUID
4848
case *schema.TimeType:
4949
formatted, err = formatTimeType(t)
50+
case *schema.EnumType:
51+
err = errors.New("ydb: Enum can't be used as column data types for tables")
5052
case *schema.UnsupportedType:
5153
err = fmt.Errorf("ydb: unsupported type: %q", t.T)
5254
default:
@@ -145,7 +147,7 @@ func formatTimeType(t *schema.TimeType) (string, error) {
145147
// ParseType returns the schema.Type value represented by the given raw type.
146148
// The raw value is expected to follow the format of input for the CREATE TABLE statement.
147149
func ParseType(typ string) (schema.Type, error) {
148-
colDesc, err := parseColumn(typ)
150+
colDesc, err := parseColumn(strings.ToLower(typ))
149151
if err != nil {
150152
return nil, err
151153
}
@@ -203,9 +205,9 @@ func parseColumn(typ string) (*columnDecscriptor, error) {
203205
func parseOptionalType(typ string) (*columnDecscriptor, string) {
204206
colDesc := &columnDecscriptor{}
205207

206-
if strings.HasPrefix(typ, "Optional<") {
208+
if strings.HasPrefix(typ, "optional<") {
207209
colDesc.nullable = true
208-
typ = strings.TrimPrefix(typ, "Optional<")
210+
typ = strings.TrimPrefix(typ, "optional<")
209211
typ = strings.TrimSuffix(typ, ">")
210212
}
211213

@@ -249,7 +251,7 @@ func columnType(colDesc *columnDecscriptor) (schema.Type, error) {
249251
}
250252

251253
return &OptionalType{
252-
T: fmt.Sprintf("Optional<%s>", innerTypeStr),
254+
T: fmt.Sprintf("optional<%s>", innerTypeStr),
253255
InnerType: innerType,
254256
}, nil
255257
}

sql/ydb/convert_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,9 @@ func TestConvert_ParseType(t *testing.T) {
188188
{name: "tztimestamp64", input: TypeTzTimestamp64, expected: &schema.TimeType{T: TypeTzTimestamp64}},
189189

190190
// Optional types
191-
{name: "optional_int32", input: "Optional<int32>", expected: &OptionalType{T: "Optional<int32>", InnerType: &schema.IntegerType{T: TypeInt32, Unsigned: false}}},
192-
{name: "optional_utf8", input: "Optional<utf8>", expected: &OptionalType{T: "Optional<utf8>", InnerType: &schema.StringType{T: TypeUtf8}}},
193-
{name: "optional_bool", input: "Optional<bool>", expected: &OptionalType{T: "Optional<bool>", InnerType: &schema.BoolType{T: TypeBool}}},
191+
{name: "optional_int32", input: "Optional<int32>", expected: &OptionalType{T: "optional<int32>", InnerType: &schema.IntegerType{T: TypeInt32, Unsigned: false}}},
192+
{name: "optional_utf8", input: "Optional<utf8>", expected: &OptionalType{T: "optional<utf8>", InnerType: &schema.StringType{T: TypeUtf8}}},
193+
{name: "optional_bool", input: "Optional<bool>", expected: &OptionalType{T: "optional<bool>", InnerType: &schema.BoolType{T: TypeBool}}},
194194

195195
// Unsupported/unknown types
196196
{name: "unknown_type", input: "unknown", expected: &schema.UnsupportedType{T: "unknown"}},

sql/ydb/driver.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func opener(ctx context.Context, dsn *url.URL) (*sqlclient.Client, error) {
8181
}
8282

8383
sqlDriver := sql.OpenDB(conn)
84-
drv, err := open(nativeDriver, sqlDriver)
84+
drv, err := Open(nativeDriver, sqlDriver)
8585
if err != nil {
8686
if cerr := sqlDriver.Close(); cerr != nil {
8787
err = fmt.Errorf("%w: %v", err, cerr)
@@ -102,7 +102,7 @@ func opener(ctx context.Context, dsn *url.URL) (*sqlclient.Client, error) {
102102
}
103103

104104
// Open opens a new YDB driver.
105-
func open(nativeDriver *ydbSdk.Driver, sqlDriver *sql.DB) (migrate.Driver, error) {
105+
func Open(nativeDriver *ydbSdk.Driver, sqlDriver *sql.DB) (migrate.Driver, error) {
106106
c := &conn{
107107
ExecQuerier: sqlDriver,
108108
nativeDriver: nativeDriver,

sql/ydb/migrate.go

Lines changed: 217 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ func (s *state) plan(changes []schema.Change) error {
8282
if err := s.dropTable(change); err != nil {
8383
return err
8484
}
85+
case *schema.ModifyTable:
86+
if err := s.modifyTable(change); err != nil {
87+
return err
88+
}
89+
case *schema.RenameTable:
90+
s.renameTable(change)
8591
default:
8692
return fmt.Errorf("ydb: unsupported change type: %T", change)
8793
}
@@ -92,10 +98,10 @@ func (s *state) plan(changes []schema.Change) error {
9298
// addTable builds and executes the query for creating a table in a schema.
9399
func (s *state) addTable(addTable *schema.AddTable) error {
94100
var errs []string
95-
b := s.Build("CREATE TABLE")
101+
builder := s.Build("CREATE TABLE")
96102

97-
b.Table(addTable.T)
98-
b.WrapIndent(func(b *sqlx.Builder) {
103+
builder.Table(addTable.T)
104+
builder.WrapIndent(func(b *sqlx.Builder) {
99105
b.MapIndent(addTable.T.Columns, func(i int, b *sqlx.Builder) {
100106
if err := s.column(b, addTable.T.Columns[i]); err != nil {
101107
errs = append(errs, err.Error())
@@ -123,20 +129,14 @@ func (s *state) addTable(addTable *schema.AddTable) error {
123129
String()
124130

125131
s.append(&migrate.Change{
126-
Cmd: b.String(),
132+
Cmd: builder.String(),
127133
Source: addTable,
128134
Comment: fmt.Sprintf("create %q table", addTable.T.Name),
129135
Reverse: reverse,
130136
})
131137
return nil
132138
}
133139

134-
// indexDef writes an inline index definition for CREATE TABLE.
135-
func (s *state) indexDef(b *sqlx.Builder, idx *schema.Index) {
136-
b.P("INDEX").Ident(idx.Name).P("GLOBAL ON")
137-
s.indexParts(b, idx.Parts)
138-
}
139-
140140
// dropTable builds and executes the query for dropping a table from a schema.
141141
func (s *state) dropTable(drop *schema.DropTable) error {
142142
reverseState := &state{
@@ -148,11 +148,11 @@ func (s *state) dropTable(drop *schema.DropTable) error {
148148
return fmt.Errorf("calculate reverse for drop table %q: %w", drop.T.Name, err)
149149
}
150150

151-
b := s.Build("DROP TABLE")
151+
builder := s.Build("DROP TABLE")
152152
if sqlx.Has(drop.Extra, &schema.IfExists{}) {
153-
b.P("IF EXISTS")
153+
builder.P("IF EXISTS")
154154
}
155-
b.Table(drop.T)
155+
builder.Table(drop.T)
156156

157157
// The reverse of 'DROP TABLE' might be a multi-statement operation
158158
reverse := func() any {
@@ -167,14 +167,211 @@ func (s *state) dropTable(drop *schema.DropTable) error {
167167
}()
168168

169169
s.append(&migrate.Change{
170-
Cmd: b.String(),
170+
Cmd: builder.String(),
171171
Source: drop,
172172
Comment: fmt.Sprintf("drop %q table", drop.T.Name),
173173
Reverse: reverse,
174174
})
175175
return nil
176176
}
177177

178+
// modifyTable builds the statements that bring the table into its modified state.
179+
func (s *state) modifyTable(modify *schema.ModifyTable) error {
180+
var (
181+
alterOps []schema.Change
182+
addIndexOps []*schema.AddIndex
183+
dropIndexOps []*schema.DropIndex
184+
)
185+
186+
for _, change := range modify.Changes {
187+
switch change := change.(type) {
188+
case *schema.AddColumn:
189+
alterOps = append(alterOps, change)
190+
191+
case *schema.DropColumn:
192+
alterOps = append(alterOps, change)
193+
194+
case *schema.AddIndex:
195+
addIndexOps = append(addIndexOps, change)
196+
197+
case *schema.DropIndex:
198+
dropIndexOps = append(dropIndexOps, change)
199+
200+
case *schema.ModifyIndex:
201+
// Index modification requires rebuilding the index.
202+
dropIndexOps = append(dropIndexOps, &schema.DropIndex{I: change.From})
203+
addIndexOps = append(addIndexOps, &schema.AddIndex{I: change.To})
204+
205+
case *schema.RenameIndex:
206+
s.renameIndex(modify, change)
207+
208+
default:
209+
return fmt.Errorf("ydb: unsupported table change: %T", change)
210+
}
211+
}
212+
213+
// Drop indexes first, then alter table, then add indexes
214+
if err := s.dropIndexes(modify, modify.T, dropIndexOps...); err != nil {
215+
return err
216+
}
217+
218+
if len(alterOps) > 0 {
219+
if err := s.alterTable(modify.T, alterOps); err != nil {
220+
return err
221+
}
222+
}
223+
224+
if err := s.addIndexes(modify, modify.T, addIndexOps...); err != nil {
225+
return err
226+
}
227+
228+
return nil
229+
}
230+
231+
// alterTable modifies the given table by executing on it a list of changes in one SQL statement.
232+
func (s *state) alterTable(t *schema.Table, changes []schema.Change) error {
233+
var reverse []schema.Change
234+
235+
buildFunc := func(changes []schema.Change) (string, error) {
236+
b := s.Build("ALTER TABLE").Table(t)
237+
238+
err := b.MapCommaErr(changes, func(i int, builder *sqlx.Builder) error {
239+
switch change := changes[i].(type) {
240+
case *schema.AddColumn:
241+
builder.P("ADD COLUMN")
242+
if err := s.column(builder, change.C); err != nil {
243+
return err
244+
}
245+
reverse = append(reverse, &schema.DropColumn{C: change.C})
246+
247+
case *schema.DropColumn:
248+
builder.P("DROP COLUMN").Ident(change.C.Name)
249+
reverse = append(reverse, &schema.AddColumn{C: change.C})
250+
}
251+
252+
return nil
253+
})
254+
if err != nil {
255+
return "", err
256+
}
257+
258+
return b.String(), nil
259+
}
260+
261+
stmt, err := buildFunc(changes)
262+
if err != nil {
263+
return fmt.Errorf("alter table %q: %v", t.Name, err)
264+
}
265+
266+
cmd := &migrate.Change{
267+
Cmd: stmt,
268+
Source: &schema.ModifyTable{
269+
T: t,
270+
Changes: changes,
271+
},
272+
Comment: fmt.Sprintf("modify %q table", t.Name),
273+
}
274+
275+
// Changes should be reverted in a reversed order they were created.
276+
sqlx.ReverseChanges(reverse)
277+
if cmd.Reverse, err = buildFunc(reverse); err != nil {
278+
return fmt.Errorf("reverse alter table %q: %v", t.Name, err)
279+
}
280+
281+
s.append(cmd)
282+
return nil
283+
}
284+
285+
func (s *state) addIndexes(src schema.Change, t *schema.Table, indexes ...*schema.AddIndex) error {
286+
for _, add := range indexes {
287+
index := add.I
288+
indexAttrs := IndexAttributes{}
289+
hasAttrs := sqlx.Has(index.Attrs, &indexAttrs)
290+
291+
b := s.Build("ALTER TABLE").
292+
Table(t).
293+
P("ADD INDEX").
294+
Ident(index.Name)
295+
296+
if hasAttrs && !indexAttrs.Global {
297+
b.P("LOCAL")
298+
} else {
299+
b.P("GLOBAL")
300+
}
301+
302+
if index.Unique {
303+
b.P("UNIQUE")
304+
}
305+
306+
if hasAttrs && !indexAttrs.Sync {
307+
b.P("ASYNC")
308+
} else {
309+
b.P("SYNC")
310+
}
311+
312+
b.P("ON")
313+
314+
s.indexParts(b, index.Parts)
315+
316+
reverseOp := s.Build("ALTER TABLE").
317+
Table(t).
318+
P("DROP INDEX").
319+
Ident(index.Name).
320+
String()
321+
322+
s.append(&migrate.Change{
323+
Cmd: b.String(),
324+
Source: src,
325+
Comment: fmt.Sprintf("create index %q to table: %q", index.Name, t.Name),
326+
Reverse: reverseOp,
327+
})
328+
}
329+
return nil
330+
}
331+
332+
func (s *state) dropIndexes(src schema.Change, t *schema.Table, drops ...*schema.DropIndex) error {
333+
adds := make([]*schema.AddIndex, len(drops))
334+
for i, d := range drops {
335+
adds[i] = &schema.AddIndex{I: d.I, Extra: d.Extra}
336+
}
337+
338+
reverseState := &state{conn: s.conn, PlanOptions: s.PlanOptions}
339+
if err := reverseState.addIndexes(src, t, adds...); err != nil {
340+
return err
341+
}
342+
343+
for i, add := range adds {
344+
s.append(&migrate.Change{
345+
Cmd: reverseState.Changes[i].Reverse.(string),
346+
Source: src,
347+
Comment: fmt.Sprintf("drop index %q from table: %q", add.I.Name, t.Name),
348+
Reverse: reverseState.Changes[i].Cmd,
349+
})
350+
}
351+
352+
return nil
353+
}
354+
355+
// renameTable builds and appends the statement for renaming a table.
356+
func (s *state) renameTable(c *schema.RenameTable) {
357+
s.append(&migrate.Change{
358+
Source: c,
359+
Comment: fmt.Sprintf("rename a table from %q to %q", c.From.Name, c.To.Name),
360+
Cmd: s.Build("ALTER TABLE").Table(c.From).P("RENAME TO").Table(c.To).String(),
361+
Reverse: s.Build("ALTER TABLE").Table(c.To).P("RENAME TO").Table(c.From).String(),
362+
})
363+
}
364+
365+
// renameIndex builds and appends the statement for renaming an index.
366+
func (s *state) renameIndex(modify *schema.ModifyTable, c *schema.RenameIndex) {
367+
s.append(&migrate.Change{
368+
Source: c,
369+
Comment: fmt.Sprintf("rename an index from %q to %q", c.From.Name, c.To.Name),
370+
Cmd: s.Build("ALTER TABLE").Table(modify.T).P("RENAME INDEX").Ident(c.From.Name).P("TO").Ident(c.To.Name).String(),
371+
Reverse: s.Build("ALTER TABLE").Table(modify.T).P("RENAME INDEX").Ident(c.To.Name).P("TO").Ident(c.From.Name).String(),
372+
})
373+
}
374+
178375
// column writes the column definition to the builder.
179376
func (s *state) column(b *sqlx.Builder, c *schema.Column) error {
180377
t, err := FormatType(c.Type.Type)
@@ -190,6 +387,12 @@ func (s *state) column(b *sqlx.Builder, c *schema.Column) error {
190387
return nil
191388
}
192389

390+
// indexDef writes an inline index definition for CREATE TABLE.
391+
func (s *state) indexDef(b *sqlx.Builder, idx *schema.Index) {
392+
b.P("INDEX").Ident(idx.Name).P("GLOBAL ON")
393+
s.indexParts(b, idx.Parts)
394+
}
395+
193396
// indexParts writes the index parts (columns) to the builder.
194397
func (s *state) indexParts(b *sqlx.Builder, parts []*schema.IndexPart) {
195398
b.Wrap(func(b *sqlx.Builder) {

0 commit comments

Comments
 (0)