|
| 1 | +// Copyright 2019-present Facebook Inc. All rights reserved. |
| 2 | +// This source code is licensed under the Apache 2.0 license found |
| 3 | +// in the LICENSE file in the root directory of this source tree. |
| 4 | + |
| 5 | +package schema |
| 6 | + |
| 7 | +import ( |
| 8 | + "context" |
| 9 | + "database/sql" |
| 10 | + "errors" |
| 11 | + "fmt" |
| 12 | + "strings" |
| 13 | + |
| 14 | + "entgo.io/ent/dialect" |
| 15 | + entsql "entgo.io/ent/dialect/sql" |
| 16 | + entdriver "entgo.io/ent/dialect/ydb" |
| 17 | + "entgo.io/ent/schema/field" |
| 18 | + |
| 19 | + "ariga.io/atlas/sql/migrate" |
| 20 | + "ariga.io/atlas/sql/schema" |
| 21 | + atlas "ariga.io/atlas/sql/ydb" |
| 22 | +) |
| 23 | + |
| 24 | +// YDB adapter for Atlas migration engine. |
| 25 | +type YDB struct { |
| 26 | + dialect.Driver |
| 27 | + |
| 28 | + version string |
| 29 | +} |
| 30 | + |
| 31 | +// init loads the YDB version from the database for later use in the migration process. |
| 32 | +func (d *YDB) init(ctx context.Context) error { |
| 33 | + if d.version != "" { |
| 34 | + return nil // already initialized. |
| 35 | + } |
| 36 | + |
| 37 | + rows := &sql.Rows{} |
| 38 | + if err := d.Driver.Query(ctx, "SELECT version()", nil, rows); err != nil { |
| 39 | + return fmt.Errorf("ydb: failed to query version: %w", err) |
| 40 | + } |
| 41 | + defer rows.Close() |
| 42 | + |
| 43 | + if !rows.Next() { |
| 44 | + if err := rows.Err(); err != nil { |
| 45 | + return err |
| 46 | + } |
| 47 | + return fmt.Errorf("ydb: version was not found") |
| 48 | + } |
| 49 | + |
| 50 | + var version string |
| 51 | + if err := rows.Scan(&version); err != nil { |
| 52 | + return fmt.Errorf("ydb: failed to scan version: %w", err) |
| 53 | + } |
| 54 | + |
| 55 | + d.version = version |
| 56 | + return nil |
| 57 | +} |
| 58 | + |
| 59 | +// tableExist checks if a table exists in the database by querying the .sys/tables system table. |
| 60 | +func (d *YDB) tableExist(ctx context.Context, conn dialect.ExecQuerier, name string) (bool, error) { |
| 61 | + query, args := entsql.Dialect(dialect.YDB). |
| 62 | + Select(entsql.Count("*")). |
| 63 | + From(entsql.Table(".sys/tables")). |
| 64 | + Where(entsql.EQ("table_name", name)). |
| 65 | + Query() |
| 66 | + |
| 67 | + return exist(ctx, conn, query, args...) |
| 68 | +} |
| 69 | + |
| 70 | +// atOpen returns a custom Atlas migrate.Driver for YDB. |
| 71 | +func (d *YDB) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) { |
| 72 | + ydbDriver, ok := conn.(*entdriver.YDBDriver) |
| 73 | + if !ok { |
| 74 | + return nil, fmt.Errorf("expected dialect/ydb.YDBDriver, but got %T", conn) |
| 75 | + } |
| 76 | + |
| 77 | + return atlas.Open( |
| 78 | + ydbDriver.NativeDriver(), |
| 79 | + ydbDriver.DB(), |
| 80 | + ) |
| 81 | +} |
| 82 | + |
| 83 | +func (d *YDB) atTable(table1 *Table, table2 *schema.Table) { |
| 84 | + if table1.Annotation != nil { |
| 85 | + setAtChecks(table1, table2) |
| 86 | + } |
| 87 | +} |
| 88 | + |
| 89 | +// supportsDefault returns whether YDB supports DEFAULT values for the given column type. |
| 90 | +func (d *YDB) supportsDefault(column *Column) bool { |
| 91 | + switch column.Default.(type) { |
| 92 | + case Expr, map[string]Expr: |
| 93 | + // Expression defaults are not well supported in YDB |
| 94 | + return false |
| 95 | + default: |
| 96 | + // Simple literal defaults should work for basic types |
| 97 | + return column.supportDefault() |
| 98 | + } |
| 99 | +} |
| 100 | + |
| 101 | +// atTypeC converts an Ent column type to a YDB Atlas schema type. |
| 102 | +func (d *YDB) atTypeC(column1 *Column, column2 *schema.Column) error { |
| 103 | + // Check for custom schema type override. |
| 104 | + if column1.SchemaType != nil && column1.SchemaType[dialect.YDB] != "" { |
| 105 | + typ, err := atlas.ParseType( |
| 106 | + column1.SchemaType[dialect.YDB], |
| 107 | + ) |
| 108 | + if err != nil { |
| 109 | + return err |
| 110 | + } |
| 111 | + column2.Type.Type = typ |
| 112 | + return nil |
| 113 | + } |
| 114 | + |
| 115 | + var ( |
| 116 | + typ schema.Type |
| 117 | + err error |
| 118 | + ) |
| 119 | + |
| 120 | + switch column1.Type { |
| 121 | + case field.TypeBool: |
| 122 | + typ = &schema.BoolType{T: atlas.TypeBool} |
| 123 | + case field.TypeInt8: |
| 124 | + typ = &schema.IntegerType{T: atlas.TypeInt8} |
| 125 | + case field.TypeInt16: |
| 126 | + typ = &schema.IntegerType{T: atlas.TypeInt16} |
| 127 | + case field.TypeInt32: |
| 128 | + typ = &schema.IntegerType{T: atlas.TypeInt32} |
| 129 | + case field.TypeInt, field.TypeInt64: |
| 130 | + typ = &schema.IntegerType{T: atlas.TypeInt64} |
| 131 | + case field.TypeUint8: |
| 132 | + typ = &schema.IntegerType{T: atlas.TypeUint8, Unsigned: true} |
| 133 | + case field.TypeUint16: |
| 134 | + typ = &schema.IntegerType{T: atlas.TypeUint16, Unsigned: true} |
| 135 | + case field.TypeUint32: |
| 136 | + typ = &schema.IntegerType{T: atlas.TypeUint32, Unsigned: true} |
| 137 | + case field.TypeUint, field.TypeUint64: |
| 138 | + typ = &schema.IntegerType{T: atlas.TypeUint64, Unsigned: true} |
| 139 | + case field.TypeFloat32: |
| 140 | + typ = &schema.FloatType{T: atlas.TypeFloat} |
| 141 | + case field.TypeFloat64: |
| 142 | + typ = &schema.FloatType{T: atlas.TypeDouble} |
| 143 | + case field.TypeBytes: |
| 144 | + typ = &schema.BinaryType{T: atlas.TypeString} |
| 145 | + case field.TypeString: |
| 146 | + typ = &schema.StringType{T: atlas.TypeUtf8} |
| 147 | + case field.TypeJSON: |
| 148 | + typ = &schema.JSONType{T: atlas.TypeJson} |
| 149 | + case field.TypeTime: |
| 150 | + typ = &schema.TimeType{T: atlas.TypeTimestamp} |
| 151 | + case field.TypeUUID: |
| 152 | + typ = &schema.UUIDType{T: atlas.TypeUuid} |
| 153 | + case field.TypeEnum: |
| 154 | + err = errors.New("ydb: Enum can't be used as column data type for tables") |
| 155 | + case field.TypeOther: |
| 156 | + typ = &schema.UnsupportedType{T: column1.typ} |
| 157 | + default: |
| 158 | + typ, err = atlas.ParseType(column1.typ) |
| 159 | + } |
| 160 | + |
| 161 | + if err != nil { |
| 162 | + return err |
| 163 | + } |
| 164 | + |
| 165 | + column2.Type.Type = typ |
| 166 | + return nil |
| 167 | +} |
| 168 | + |
| 169 | +// atUniqueC adds a unique constraint for a column. |
| 170 | +// In YDB, unique constraints are implemented as GLOBAL UNIQUE SYNC indexes. |
| 171 | +func (d *YDB) atUniqueC( |
| 172 | + table1 *Table, |
| 173 | + column1 *Column, |
| 174 | + table2 *schema.Table, |
| 175 | + column2 *schema.Column, |
| 176 | +) { |
| 177 | + // Check if there's already an explicit unique index defined for this column. |
| 178 | + for _, idx := range table1.Indexes { |
| 179 | + if idx.Unique && len(idx.Columns) == 1 && idx.Columns[0].Name == column1.Name { |
| 180 | + // Index already defined explicitly, will be added in atIndexes. |
| 181 | + return |
| 182 | + } |
| 183 | + } |
| 184 | + // Create a unique index for this column. |
| 185 | + idxName := fmt.Sprintf("%s_%s_index", table1.Name, column1.Name) |
| 186 | + index := schema.NewUniqueIndex(idxName).AddColumns(column2) |
| 187 | + |
| 188 | + // Add YDB-specific attribute for GLOBAL SYNC index type. |
| 189 | + index.AddAttrs(&atlas.YDBIndexAttributes{Global: true, Sync: true}) |
| 190 | + |
| 191 | + table2.AddIndexes(index) |
| 192 | +} |
| 193 | + |
| 194 | +// atIncrementC configures auto-increment for a column. |
| 195 | +// YDB uses Serial types for auto-increment. |
| 196 | +func (d *YDB) atIncrementC(table *schema.Table, column *schema.Column) { |
| 197 | + if intType, ok := column.Type.Type.(*schema.IntegerType); ok { |
| 198 | + column.Type.Type = atlas.SerialFromInt(intType) |
| 199 | + } |
| 200 | +} |
| 201 | + |
| 202 | +// atIncrementT sets the table-level auto-increment starting value. |
| 203 | +func (d *YDB) atIncrementT(table *schema.Table, v int64) { |
| 204 | + // not implemented |
| 205 | +} |
| 206 | + |
| 207 | +// atIndex configures an index for ydb. |
| 208 | +func (d *YDB) atIndex( |
| 209 | + index1 *Index, |
| 210 | + table2 *schema.Table, |
| 211 | + index2 *schema.Index, |
| 212 | +) error { |
| 213 | + for _, column1 := range index1.Columns { |
| 214 | + column2, ok := table2.Column(column1.Name) |
| 215 | + if !ok { |
| 216 | + return fmt.Errorf("unexpected index %q column: %q", index1.Name, column1.Name) |
| 217 | + } |
| 218 | + index2.AddParts(&schema.IndexPart{C: column2}) |
| 219 | + } |
| 220 | + |
| 221 | + // Set YDB-specific index attributes. |
| 222 | + // By default, use GLOBAL SYNC for consistency. |
| 223 | + idxType := &atlas.YDBIndexAttributes{Global: true, Sync: true} |
| 224 | + |
| 225 | + // Check for annotation overrides. |
| 226 | + if index1.Annotation != nil { |
| 227 | + if indexType, ok := indexType(index1, dialect.YDB); ok { |
| 228 | + // Parse YDB-specific index type from annotation. |
| 229 | + switch strings.ToUpper(indexType) { |
| 230 | + case "GLOBAL ASYNC", "ASYNC": |
| 231 | + idxType.Sync = false |
| 232 | + case "LOCAL": |
| 233 | + idxType.Global = false |
| 234 | + case "LOCAL ASYNC": |
| 235 | + idxType.Global = false |
| 236 | + idxType.Sync = false |
| 237 | + } |
| 238 | + } |
| 239 | + } |
| 240 | + index2.AddAttrs(idxType) |
| 241 | + return nil |
| 242 | +} |
| 243 | + |
| 244 | +// atTypeRangeSQL returns the SQL statement to insert type ranges for global unique IDs. |
| 245 | +func (*YDB) atTypeRangeSQL(ts ...string) string { |
| 246 | + values := make([]string, len(ts)) |
| 247 | + for i, t := range ts { |
| 248 | + values[i] = fmt.Sprintf("('%s')", t) |
| 249 | + } |
| 250 | + return fmt.Sprintf( |
| 251 | + "UPSERT INTO `%s` (`type`) VALUES %s", |
| 252 | + TypeTable, |
| 253 | + strings.Join(values, ", "), |
| 254 | + ) |
| 255 | +} |
0 commit comments