diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5c09c7285b..06ca3317b7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -274,6 +274,20 @@ jobs: --health-interval 10s --health-timeout 5s --health-retries 5 + ydb: + image: ydbplatform/local-ydb:trunk + env: + GRPC_TLS_PORT: 2135 + GRPC_PORT: 2136 + MON_PORT: 8765 + ports: + - 2136:2136 + - 8765:8765 + options: >- + --hostname localhost + --platform linux/amd64 + --health-cmd "true" + --health-start-period 30s steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 @@ -476,6 +490,20 @@ jobs: --health-interval 10s --health-timeout 5s --health-retries 5 + ydb: + image: ydbplatform/local-ydb:trunk + env: + GRPC_TLS_PORT: 2135 + GRPC_PORT: 2136 + MON_PORT: 8765 + ports: + - 2136:2136 + - 8765:8765 + options: >- + --hostname localhost + --platform linux/amd64 + --health-cmd "true" + --health-start-period 30s steps: - uses: actions/checkout@v4 with: diff --git a/dialect/entsql/annotation.go b/dialect/entsql/annotation.go index aae8a420ee..b0db7fbaf0 100644 --- a/dialect/entsql/annotation.go +++ b/dialect/entsql/annotation.go @@ -550,15 +550,19 @@ type IndexAnnotation struct { DescColumns map[string]bool // IncludeColumns defines the INCLUDE clause for the index. - // Works only in Postgres and its definition is as follows: + // Works only in Postgres and YDB. Its definition is as follows: // // index.Fields("c1"). // Annotation( // entsql.IncludeColumns("c2"), // ) // + // Postgres: // CREATE INDEX "table_column" ON "table"("c1") INCLUDE ("c2") // + // YDB: + // ALTER TABLE `table` ADD INDEX `table_column` GLOBAL SYNC ON (`c1`) COVER (`c2`) + // IncludeColumns []string // Type defines the type of the index. diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 72f5d17aba..0a3264e468 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -16,10 +16,13 @@ import ( "database/sql/driver" "errors" "fmt" + "reflect" "strconv" "strings" "entgo.io/ent/dialect" + + "github.com/google/uuid" ) // Querier wraps the basic Query method that is implemented @@ -854,8 +857,7 @@ func Delete(table string) *DeleteBuilder { return &DeleteBuilder{table: table} } // // Note: BATCH DELETE is only supported in YDB dialect. // -// BatchDelete("/local/my_table"). -// Where(GT("Key1", 1)) +// BatchDelete("/local/my_table").Where(GT("Key1", 1)) func BatchDelete(table string) *DeleteBuilder { return &DeleteBuilder{table: table, isBatch: true} } @@ -1377,38 +1379,64 @@ func (p *Predicate) Like(col, pattern string) *Predicate { }) } -// escape escapes w with the default escape character ('/'), +// escape escapes word with the default escape character ('\'), // to be used by the pattern matching functions below. // The second return value indicates if w was escaped or not. -func escape(w string) (string, bool) { +func escape(word string) (string, bool) { + return escapeWith(word, '\\') +} + +// escapeYDB escapes w with '#' for YDB, since YDB doesn't support '\' in ESCAPE clause. +func escapeYDB(word string) (string, bool) { + return escapeWith(word, '#') +} + +func escapeWith(word string, escChar byte) (string, bool) { var n int - for i := range w { - if c := w[i]; c == '%' || c == '_' || c == '\\' { + for i := 0; i < len(word); i++ { + if ch := word[i]; ch == '%' || ch == '_' || ch == escChar { n++ } } // No characters to escape. if n == 0 { - return w, false + return word, false } - var b strings.Builder - b.Grow(len(w) + n) - for _, c := range w { - if c == '%' || c == '_' || c == '\\' { - b.WriteByte('\\') + + var builder strings.Builder + builder.Grow(len(word) + n) + + for i := 0; i < len(word); i++ { + if ch := word[i]; ch == '%' || ch == '_' || ch == escChar { + builder.WriteByte(escChar) } - b.WriteRune(c) + builder.WriteByte(word[i]) } - return b.String(), true + return builder.String(), true } func (p *Predicate) escapedLike(col, left, right, word string) *Predicate { - return p.Append(func(b *Builder) { - w, escaped := escape(word) - b.Ident(col).WriteOp(OpLike) - b.Arg(left + w + right) - if p.dialect == dialect.SQLite && escaped { - p.WriteString(" ESCAPE ").Arg("\\") + return p.Append(func(builder *Builder) { + var escapedWord string + var escaped bool + + if p.dialect == dialect.YDB { + escapedWord, escaped = escapeYDB(word) + } else { + escapedWord, escaped = escape(word) + } + + builder.Ident(col).WriteOp(OpLike) + builder.Arg(left + escapedWord + right) + + // SQLite and YDB require explicit ESCAPE clause. + if escaped { + switch p.dialect { + case dialect.SQLite: + p.WriteString(" ESCAPE ").Arg("\\") + case dialect.YDB: + p.WriteString(" ESCAPE '#'") + } } }) } @@ -1416,17 +1444,26 @@ func (p *Predicate) escapedLike(col, left, right, word string) *Predicate { // ContainsFold is a helper predicate that applies the LIKE predicate with case-folding. func (p *Predicate) escapedLikeFold(col, left, substr, right string) *Predicate { return p.Append(func(b *Builder) { - w, escaped := escape(substr) switch b.dialect { case dialect.MySQL: + w, _ := escape(substr) // We assume the CHARACTER SET is configured to utf8mb4, // because this how it is defined in dialect/sql/schema. b.Ident(col).WriteString(" COLLATE utf8mb4_general_ci LIKE ") b.Arg(left + strings.ToLower(w) + right) - case dialect.Postgres, dialect.YDB: + case dialect.Postgres: + w, _ := escape(substr) b.Ident(col).WriteString(" ILIKE ") b.Arg(left + strings.ToLower(w) + right) + case dialect.YDB: + w, escaped := escapeYDB(substr) + b.Ident(col).WriteString(" ILIKE ") + b.Arg(left + strings.ToLower(w) + right) + if escaped { + p.WriteString(" ESCAPE '#'") + } default: // SQLite. + w, escaped := escape(substr) var f Func f.SetDialect(b.dialect) f.Lower(col) @@ -1479,6 +1516,12 @@ func (p *Predicate) ColumnsHasPrefix(col, prefixC string) *Predicate { if p.dialect == dialect.SQLite { p.WriteString(" ESCAPE ").Arg("\\") } + case dialect.YDB: + b.S("StartsWith("). + Ident(col). + S(", "). + Ident(prefixC). + S(")") default: b.AddError(fmt.Errorf("ColumnsHasPrefix is not supported by %q", p.dialect)) } @@ -1660,7 +1703,54 @@ func Lower(ident string) string { // Lower wraps the given ident with the LOWER function. func (f *Func) Lower(ident string) { - f.byName("LOWER", ident) + f.Append(func(b *Builder) { + if f.dialect == dialect.YDB { + f.WriteString("Unicode::ToLower(") + b.Ident(ident) + f.WriteString(")") + } else { + f.WriteString("LOWER") + f.Wrap(func(b *Builder) { + b.Ident(ident) + }) + } + }) +} + +// Upper wraps the given column with the UPPER function. +// +// P().EQ(sql.Upper("name"), "A8M") +func Upper(ident string) string { + f := &Func{} + f.Upper(ident) + return f.String() +} + +// Upper wraps the given ident with the UPPER function. +func (f *Func) Upper(ident string) { + f.Append(func(b *Builder) { + if f.dialect == dialect.YDB { + f.WriteString("Unicode::ToUpper(") + b.Ident(ident) + f.WriteString(")") + } else { + f.WriteString("UPPER") + f.Wrap(func(b *Builder) { + b.Ident(ident) + }) + } + }) +} + +// UpperExpr returns a dialect-aware UPPER expression. +func UpperExpr(ident string) Querier { + return ExprFunc(func(b *Builder) { + if b.Dialect() == dialect.YDB { + b.WriteString("Unicode::ToUpper(").Ident(ident).WriteString(")") + } else { + b.WriteString("UPPER(").Ident(ident).WriteString(")") + } + }) } // Count wraps the ident with the COUNT aggregation function. @@ -1795,7 +1885,10 @@ type SelectTable struct { name string schema string quote bool - index string // YDB-specific: secondary index name for VIEW clause + + // YDB-specific: + index string // secondary index name for VIEW clause + isCte bool // YDB-specific: marks this as a CTE reference } // Table returns a new table selector. @@ -1870,6 +1963,23 @@ func (s *SelectTable) ref() string { } b := &Builder{dialect: s.dialect} b.writeSchema(s.schema) + + // YDB-specific: CTE references require $ prefix + // and should have alias for easy handling + if s.isCte && b.ydb() { + b.WriteString("$") + b.Ident(s.name) + + b.WriteString(" AS ") + if s.as != "" { + b.Ident(s.as) + } else { + b.Ident(s.name) + } + + return b.String() + } + b.Ident(s.name) // YDB-specific: VIEW clause for secondary indexes @@ -1911,7 +2021,7 @@ type Selector struct { // generated code such as alternate table schemas. ctx context.Context as string - selection []selection + selection []*selection from []TableView joins []join collected [][]*Predicate @@ -1977,17 +2087,17 @@ func SelectExpr(exprs ...Querier) *Selector { // selection represents a column or an expression selection. type selection struct { - x Querier - c string - as string + expr Querier + column string + alias string } // Select changes the columns selection of the SELECT statement. // Empty selection means all columns *. func (s *Selector) Select(columns ...string) *Selector { - s.selection = make([]selection, len(columns)) + s.selection = make([]*selection, len(columns)) for i := range columns { - s.selection[i] = selection{c: columns[i]} + s.selection[i] = &selection{column: columns[i]} } return s } @@ -2000,23 +2110,23 @@ func (s *Selector) SelectDistinct(columns ...string) *Selector { // AppendSelect appends additional columns to the SELECT statement. func (s *Selector) AppendSelect(columns ...string) *Selector { for i := range columns { - s.selection = append(s.selection, selection{c: columns[i]}) + s.selection = append(s.selection, &selection{column: columns[i]}) } return s } // AppendSelectAs appends additional column to the SELECT statement with the given alias. func (s *Selector) AppendSelectAs(column, as string) *Selector { - s.selection = append(s.selection, selection{c: column, as: as}) + s.selection = append(s.selection, &selection{column: column, alias: as}) return s } // SelectExpr changes the columns selection of the SELECT statement // with custom list of expressions. func (s *Selector) SelectExpr(exprs ...Querier) *Selector { - s.selection = make([]selection, len(exprs)) + s.selection = make([]*selection, len(exprs)) for i := range exprs { - s.selection[i] = selection{x: exprs[i]} + s.selection[i] = &selection{expr: exprs[i]} } return s } @@ -2024,7 +2134,7 @@ func (s *Selector) SelectExpr(exprs ...Querier) *Selector { // AppendSelectExpr appends additional expressions to the SELECT statement. func (s *Selector) AppendSelectExpr(exprs ...Querier) *Selector { for i := range exprs { - s.selection = append(s.selection, selection{x: exprs[i]}) + s.selection = append(s.selection, &selection{expr: exprs[i]}) } return s } @@ -2037,9 +2147,9 @@ func (s *Selector) AppendSelectExprAs(expr Querier, as string) *Selector { b.S("(").Join(expr).S(")") }) } - s.selection = append(s.selection, selection{ - x: x, - as: as, + s.selection = append(s.selection, &selection{ + expr: x, + alias: as, }) return s } @@ -2067,16 +2177,16 @@ func (s *Selector) FindSelection(name string) (matches []string) { for _, c := range s.selection { switch { // Match aliases. - case c.as != "": - if ident := s.isIdent(c.as); !ident && c.as == name || ident && s.unquote(c.as) == name { - matches = append(matches, c.as) + case c.alias != "": + if ident := s.isIdent(c.alias); !ident && c.alias == name || ident && s.unquote(c.alias) == name { + matches = append(matches, c.alias) } // Match qualified columns. - case c.c != "" && s.isQualified(c.c) && matchC(c.c): - matches = append(matches, c.c) + case c.column != "" && s.isQualified(c.column) && matchC(c.column): + matches = append(matches, c.column) // Match unqualified columns. - case c.c != "" && (c.c == name || s.isIdent(c.c) && s.unquote(c.c) == name): - matches = append(matches, c.c) + case c.column != "" && (c.column == name || s.isIdent(c.column) && s.unquote(c.column) == name): + matches = append(matches, c.column) } } return matches @@ -2086,7 +2196,7 @@ func (s *Selector) FindSelection(name string) (matches []string) { func (s *Selector) SelectedColumns() []string { columns := make([]string, 0, len(s.selection)) for i := range s.selection { - if c := s.selection[i].c; c != "" { + if c := s.selection[i].column; c != "" { columns = append(columns, c) } } @@ -2098,7 +2208,7 @@ func (s *Selector) SelectedColumns() []string { func (s *Selector) UnqualifiedColumns() []string { columns := make([]string, 0, len(s.selection)) for i := range s.selection { - c := s.selection[i].c + c := s.selection[i].column if c == "" { continue } @@ -2689,7 +2799,7 @@ func (s *Selector) Clone() *Selector { group: append([]string{}, s.group...), order: append([]any{}, s.order...), assumeOrder: append([]string{}, s.assumeOrder...), - selection: append([]selection{}, s.selection...), + selection: append([]*selection{}, s.selection...), } } @@ -2799,6 +2909,10 @@ func (s *Selector) Having(p *Predicate) *Selector { // Query returns query representation of a `SELECT` statement. func (s *Selector) Query() (string, []any) { b := s.Builder.clone() + // For YDB, mark tables that reference CTEs + if b.ydb() { + s.markCteReferences() + } s.joinPrefix(&b) b.WriteString("SELECT ") if s.distinct { @@ -2873,6 +2987,9 @@ func (s *Selector) Query() (string, []any) { if len(s.setOps) > 0 { s.joinSetOps(&b) } + if b.ydb() { + s.applyAliasesToOrder() + } joinOrder(s.order, &b) s.joinAssumeOrder(&b) if s.limit != nil { @@ -2896,6 +3013,45 @@ func (s *Selector) joinPrefix(b *Builder) { } } +// markCteReferences marks SelectTable entries in from/joins as CTE references +// if they match CTE names from the prefix. Used for YDB which requires $ prefix +// when referencing named expressions (CTEs). +func (s *Selector) markCteReferences() { + cteNames := s.collectCteNames() + if len(cteNames) == 0 { + return + } + // Mark FROM tables + for _, from := range s.from { + if table, ok := from.(*SelectTable); ok { + if _, isCte := cteNames[table.name]; isCte { + table.isCte = true + } + } + } + // Mark JOIN tables + for _, join := range s.joins { + if table, ok := join.table.(*SelectTable); ok { + if _, isCte := cteNames[table.name]; isCte { + table.isCte = true + } + } + } +} + +// collectCteNames returns a set of CTE names from the selector's prefix. +func (s *Selector) collectCteNames() map[string]any { + names := make(map[string]any) + for _, prefix := range s.prefix { + if with, ok := prefix.(*WithBuilder); ok { + for _, cte := range with.ctes { + names[cte.name] = struct{}{} + } + } + } + return names +} + func (s *Selector) joinLock(b *Builder) { if s.lock == nil { return @@ -2976,21 +3132,145 @@ func joinReturning(columns []string, b *Builder) { } func (s *Selector) joinSelect(b *Builder) { - for i, sc := range s.selection { + for i, selection := range s.selection { if i > 0 { b.Comma() } + + // YDB returns column names with table prefix (e.g., "users.name" instead of "name"), + // so we add aliases to ensure the scanner can match columns correctly. + if selection.alias == "" && b.ydb() { + // If the column already has an alias, extract it + upperColumnName := strings.ToUpper(selection.column) + + if idx := strings.LastIndex(upperColumnName, " AS "); idx != -1 { + originalColumn := selection.column + + selection.column = strings.TrimSpace(originalColumn[:idx]) + selection.alias = strings.Trim(originalColumn[idx+4:], " `\"") + } else if selection.column != "" && !strings.ContainsAny(selection.column, "()") { + // Qualified column name like "users.name" -> alias "name" + if idx := strings.LastIndexByte(selection.column, '.'); idx != -1 { + selection.alias = selection.column[idx+1:] + } + } else if selection.column != "" { + // Expression passed as column string like "COUNT(*)" or "SUM(users.age)" + selection.alias = exprAlias(selection.column) + } + } + switch { - case sc.c != "": - b.Ident(sc.c) - case sc.x != nil: - b.Join(sc.x) + case selection.column != "": + // YDB requires qualified asterisk (table.*) + // when mixing * with other projection items. + if b.ydb() && selection.column == "*" && len(s.selection) > 1 { + if tableName := s.firstTableName(); tableName != "" { + b.Ident(tableName).WriteByte('.').WriteString("*") + } else { + b.Ident(selection.column) + } + } else { + b.Ident(selection.column) + } + case selection.expr != nil: + b.Join(selection.expr) } - if sc.as != "" { + + if selection.alias != "" { b.WriteString(" AS ") - b.Ident(sc.as) + b.Ident(selection.alias) + } + } +} + +// firstTableName returns the name or alias of the first table in the FROM clause. +func (s *Selector) firstTableName() string { + if len(s.from) == 0 { + return "" + } + switch t := s.from[0].(type) { + case *SelectTable: + if t.as != "" { + return t.as + } + return t.name + case *WithBuilder: + return t.Name() + case *Selector: + return t.as + } + return "" +} + +// applyAliasesToOrder processes ORDER BY columns for YDB compatibility. +// - When there's a GROUP BY, use aliases +// - When there are subquery joins, use aliases +// - When there are simple table joins, keep qualified columns to avoid ambiguity +// - Otherwise, replace qualified columns with their aliases +func (s *Selector) applyAliasesToOrder() { + aliasMap := make(map[string]string) + for _, selection := range s.selection { + if selection.column == "" { + continue + } + if selection.alias != "" { + aliasMap[selection.column] = selection.alias + } + } + + if len(aliasMap) == 0 { + return + } + + hasGroupBy := len(s.group) > 0 + hasSubqueryJoin := s.hasSubqueryJoin() + hasSimpleTableJoin := len(s.joins) > 0 && !hasSubqueryJoin + + // Process ORDER BY columns + result := make([]any, len(s.order)) + for i, order := range s.order { + str, ok := order.(string) + if !ok { + result[i] = order + continue + } + // Handle "column DESC" or "column ASC" suffixes. + column, suffix := str, "" + if idx := strings.LastIndex(str, " "); idx != -1 { + upper := strings.ToUpper(str[idx+1:]) + if upper == "ASC" || upper == "DESC" { + column = str[:idx] + suffix = str[idx:] + } + } + + if alias, ok := aliasMap[column]; (hasGroupBy || hasSubqueryJoin || !hasSimpleTableJoin) && ok { + result[i] = alias + suffix + } else { + result[i] = order } } + s.order = result +} + +// hasSubqueryJoin returns true if any join involves a subquery (Selector). +func (s *Selector) hasSubqueryJoin() bool { + for _, join := range s.joins { + if _, ok := join.table.(*Selector); ok { + return true + } + } + return false +} + +// exprAlias extracts an alias from an aggregate expression for YDB. +// E.g., "COUNT(*)" -> "count", "SUM(users.age)" -> "sum" +func exprAlias(expr string) string { + expr = strings.ToLower(expr) + if idx := strings.IndexByte(expr, '('); idx != -1 { + return expr[:idx] + } + return "" } // implement the table view interface. @@ -3065,6 +3345,10 @@ func (w *WithBuilder) C(column string) string { // Query returns query representation of a `WITH` clause. func (w *WithBuilder) Query() (string, []any) { + if w.ydb() { + return w.queryYDB() + } + w.WriteString("WITH ") if w.recursive { w.WriteString("RECURSIVE ") @@ -3087,6 +3371,55 @@ func (w *WithBuilder) Query() (string, []any) { return w.String(), w.args } +// YDB uses named expressions ($name = SELECT ...) instead of CTEs +func (w *WithBuilder) queryYDB() (string, []any) { + // Collect all CTE names for marking references + cteNames := make(map[string]struct{}) + for _, cte := range w.ctes { + cteNames[cte.name] = struct{}{} + } + + for i, cte := range w.ctes { + if i > 0 { + w.WriteString(" ") + } + + // Mark CTE references in the selector's FROM and JOIN tables + if cte.s != nil { + w.markCteReferencesInSelector(cte.s, cteNames) + } + + w.WriteString("$") + w.Ident(cte.name) + w.WriteString(" = ") + + w.Wrap(func(b *Builder) { + b.Join(cte.s) + }) + + w.WriteString(";") + } + return w.String(), w.args +} + +// markCteReferencesInSelector marks tables in a selector that reference CTEs. +func (w *WithBuilder) markCteReferencesInSelector(s *Selector, cteNames map[string]struct{}) { + for _, from := range s.from { + if table, ok := from.(*SelectTable); ok { + if _, isCte := cteNames[table.name]; isCte { + table.isCte = true + } + } + } + for _, join := range s.joins { + if table, ok := join.table.(*SelectTable); ok { + if _, isCte := cteNames[table.name]; isCte { + table.isCte = true + } + } + } +} + // implement the table view interface. func (*WithBuilder) view() {} @@ -3526,8 +3859,8 @@ func (b *Builder) Args(a ...any) *Builder { // // FormatArg("JSON(?)", b). // FormatArg("ST_GeomFromText(?)", geom) -func (b *Builder) Argf(format string, a any) *Builder { - switch a := a.(type) { +func (b *Builder) Argf(format string, arg any) *Builder { + switch a := arg.(type) { case nil: b.WriteString("NULL") return b @@ -3540,19 +3873,78 @@ func (b *Builder) Argf(format string, a any) *Builder { } b.total++ - // YDB requires named parameters + // YDB requires named parameters with $paramName syntax. if b.ydb() { - // Extract parameter name from format (e.g., "$p0" -> "p0") paramName := strings.TrimPrefix(format, "$") - b.args = append(b.args, driver.NamedValue{Name: paramName, Value: a}) + b.args = append(b.args, driver.NamedValue{ + Name: paramName, + Value: b.convertValueYdb(arg), + }) } else { - b.args = append(b.args, a) + b.args = append(b.args, arg) } b.WriteString(format) return b } +// YDB has strong typing system +// and YDB driver can't convert type aliases to underlying Go type +// Therefore, we have to manually handle these edge cases +func (b *Builder) convertValueYdb(arg any) any { + finalValue := arg + + switch casted := arg.(type) { + case uuid.UUID: + finalValue = casted + case driver.Valuer: + if v, err := casted.Value(); err == nil { + finalValue = v + } + default: + // YDB requires exact numeric types. + // Convert named types to their base primitive type + // while preserving the exact numeric size. + typ := reflect.TypeOf(arg) + value := reflect.ValueOf(arg) + + switch typ.Kind() { + case reflect.Int: + finalValue = int(value.Int()) + case reflect.Int8: + finalValue = int8(value.Int()) // #nosec G115 + case reflect.Int16: + finalValue = int16(value.Int()) // #nosec G115 + case reflect.Int32: + finalValue = int32(value.Int()) // #nosec G115 + case reflect.Int64: + finalValue = value.Int() + case reflect.Uint: + finalValue = uint(value.Uint()) + case reflect.Uint8: + finalValue = uint8(value.Uint()) // #nosec G115 + case reflect.Uint16: + finalValue = uint16(value.Uint()) // #nosec G115 + case reflect.Uint32: + finalValue = uint32(value.Uint()) // #nosec G115 + case reflect.Uint64: + finalValue = value.Uint() + case reflect.Float32: + finalValue = float32(value.Float()) + case reflect.Float64: + finalValue = value.Float() + default: + // Convert other custom types (e.g., http.Dir -> string) + converted, err := driver.DefaultParameterConverter.ConvertValue(arg) + if err == nil { + finalValue = converted + } + } + } + + return finalValue +} + // Comma adds a comma to the query. func (b *Builder) Comma() *Builder { return b.WriteString(", ") @@ -3726,6 +4118,11 @@ func Dialect(name string) *DialectBuilder { return &DialectBuilder{name} } +// Dialect returns the dialect name of this builder. +func (d *DialectBuilder) Dialect() string { + return d.dialect +} + // String builds a dialect-aware expression string from the given callback. func (d *DialectBuilder) String(f func(*Builder)) string { b := &Builder{} diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index 028b6d10ee..a3ff15fc98 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -746,7 +746,7 @@ func TestBuilder(t *testing.T) { Join(t2). On(t1.C("id"), t2.C("user_id")) }(), - wantQuery: "SELECT `u`.`id`, `g`.`name` FROM `users` AS `u` JOIN `groups` AS `g` ON `u`.`id` = `g`.`user_id`", + wantQuery: "SELECT `u`.`id` AS `id`, `g`.`name` AS `name` FROM `users` AS `u` JOIN `groups` AS `g` ON `u`.`id` = `g`.`user_id`", }, { input: func() Querier { @@ -1131,7 +1131,7 @@ func TestBuilder(t *testing.T) { { input: Dialect(dialect.YDB). Select().Count().From(Table("users")), - wantQuery: "SELECT COUNT(*) FROM `users`", + wantQuery: "SELECT COUNT(*) AS `count` FROM `users`", }, { input: Select().Count(Distinct("id")).From(Table("users")), @@ -1205,7 +1205,7 @@ func TestBuilder(t *testing.T) { Select("name", Count("*")). From(Table("users")). GroupBy("name"), - wantQuery: "SELECT `name`, COUNT(*) FROM `users` GROUP BY `name`", + wantQuery: "SELECT `name`, COUNT(*) AS `count` FROM `users` GROUP BY `name`", }, { input: Select("name", Count("*")). @@ -1831,7 +1831,7 @@ AND "users"."id1" < "users"."id2") AND "users"."id1" <= "users"."id2"`, "\n", "" Join(t2). On(t1.C("user_id"), t2.C("id")) }(), - wantQuery: "SELECT `orders`.`id`, `u`.`name` FROM `orders` JOIN `users` VIEW `idx_email` AS `u` ON `orders`.`user_id` = `u`.`id`", + wantQuery: "SELECT `orders`.`id` AS `id`, `u`.`name` AS `name` FROM `orders` JOIN `users` VIEW `idx_email` AS `u` ON `orders`.`user_id` = `u`.`id`", }, { input: func() Querier { @@ -1843,7 +1843,7 @@ AND "users"."id1" < "users"."id2") AND "users"."id1" <= "users"."id2"`, "\n", "" Join(t2). On(t1.C("ref"), t2.C("ref")) }(), - wantQuery: "SELECT `a`.`value`, `b`.`value` FROM `a_table` AS `a` JOIN `b_table` VIEW `b_index_ref` AS `b` ON `a`.`ref` = `b`.`ref`", + wantQuery: "SELECT `a`.`value` AS `value`, `b`.`value` AS `value` FROM `a_table` AS `a` JOIN `b_table` VIEW `b_index_ref` AS `b` ON `a`.`ref` = `b`.`ref`", }, { input: func() Querier { @@ -1869,7 +1869,7 @@ AND "users"."id1" < "users"."id2") AND "users"."id1" <= "users"."id2"`, "\n", "" LeftSemiJoin(t2). On(t1.C("id"), t2.C("user_id")) }(), - wantQuery: "SELECT `users`.`id`, `users`.`name` FROM `users` LEFT SEMI JOIN `blacklist` AS `t1` ON `users`.`id` = `t1`.`user_id`", + wantQuery: "SELECT `users`.`id` AS `id`, `users`.`name` AS `name` FROM `users` LEFT SEMI JOIN `blacklist` AS `t1` ON `users`.`id` = `t1`.`user_id`", }, { input: func() Querier { @@ -1881,7 +1881,7 @@ AND "users"."id1" < "users"."id2") AND "users"."id1" <= "users"."id2"`, "\n", "" RightSemiJoin(t2). On(t1.C("user_id"), t2.C("id")) }(), - wantQuery: "SELECT `t1`.`id`, `t1`.`email` FROM `orders` RIGHT SEMI JOIN `active_users` AS `t1` ON `orders`.`user_id` = `t1`.`id`", + wantQuery: "SELECT `t1`.`id` AS `id`, `t1`.`email` AS `email` FROM `orders` RIGHT SEMI JOIN `active_users` AS `t1` ON `orders`.`user_id` = `t1`.`id`", }, { input: func() Querier { @@ -1893,7 +1893,7 @@ AND "users"."id1" < "users"."id2") AND "users"."id1" <= "users"."id2"`, "\n", "" LeftOnlyJoin(t2). On(t1.C("id"), t2.C("id")) }(), - wantQuery: "SELECT `users`.`id`, `users`.`name` FROM `users` LEFT ONLY JOIN `deleted_users` AS `t1` ON `users`.`id` = `t1`.`id`", + wantQuery: "SELECT `users`.`id` AS `id`, `users`.`name` AS `name` FROM `users` LEFT ONLY JOIN `deleted_users` AS `t1` ON `users`.`id` = `t1`.`id`", }, { input: func() Querier { @@ -1905,7 +1905,7 @@ AND "users"."id1" < "users"."id2") AND "users"."id1" <= "users"."id2"`, "\n", "" RightOnlyJoin(t2). On(t1.C("user_id"), t2.C("id")) }(), - wantQuery: "SELECT `t1`.`id`, `t1`.`status` FROM `archived` RIGHT ONLY JOIN `users` AS `t1` ON `archived`.`user_id` = `t1`.`id`", + wantQuery: "SELECT `t1`.`id` AS `id`, `t1`.`status` AS `status` FROM `archived` RIGHT ONLY JOIN `users` AS `t1` ON `archived`.`user_id` = `t1`.`id`", }, { input: func() Querier { @@ -1917,7 +1917,7 @@ AND "users"."id1" < "users"."id2") AND "users"."id1" <= "users"."id2"`, "\n", "" ExclusionJoin(t2). On(t1.C("key"), t2.C("key")) }(), - wantQuery: "SELECT `a`.`key`, `b`.`key` FROM `table_a` AS `a` EXCLUSION JOIN `table_b` AS `b` ON `a`.`key` = `b`.`key`", + wantQuery: "SELECT `a`.`key` AS `key`, `b`.`key` AS `key` FROM `table_a` AS `a` EXCLUSION JOIN `table_b` AS `b` ON `a`.`key` = `b`.`key`", }, { input: func() Querier { @@ -1941,7 +1941,7 @@ AND "users"."id1" < "users"."id2") AND "users"."id1" <= "users"."id2"`, "\n", "" LeftJoin(t3). OnP(And(EQ(t3.C("ref"), Expr(t1.C("key"))), EQ(t3.C("column1"), Expr(t2.C("value"))))) }(), - wantQuery: "SELECT `a`.`value`, `b`.`value`, `c`.`column2` FROM `a_table` AS `a` CROSS JOIN `b_table` AS `b` LEFT JOIN `c_table` AS `c` ON `c`.`ref` = `a`.`key` AND `c`.`column1` = `b`.`value`", + wantQuery: "SELECT `a`.`value` AS `value`, `b`.`value` AS `value`, `c`.`column2` AS `column2` FROM `a_table` AS `a` CROSS JOIN `b_table` AS `b` LEFT JOIN `c_table` AS `c` ON `c`.`ref` = `a`.`key` AND `c`.`column1` = `b`.`value`", }, { input: Dialect(dialect.YDB). @@ -2323,7 +2323,7 @@ func TestSelector_VIEW_SecondaryIndex_YDB(t *testing.T) { Where(EQ(users.C("name"), "John Doe")). Query() - require.Equal(t, "SELECT `t1`.`series_id`, `t1`.`title` FROM `series` VIEW `users_index` AS `t1` JOIN `users` VIEW `name_index` AS `t2` ON `t1`.`uploaded_user_id` = `t2`.`user_id` WHERE `t2`.`name` = $p0", query) + require.Equal(t, "SELECT `t1`.`series_id` AS `series_id`, `t1`.`title` AS `title` FROM `series` VIEW `users_index` AS `t1` JOIN `users` VIEW `name_index` AS `t2` ON `t1`.`uploaded_user_id` = `t2`.`user_id` WHERE `t2`.`name` = $p0", query) require.Equal(t, []any{driver.NamedValue{Name: "p0", Value: "John Doe"}}, args) }) @@ -2344,7 +2344,7 @@ func TestBatchUpdate_YDB(t *testing.T) { Set("Value2", 0). Where(GT("Key1", 1)). Query() - + require.Equal(t, "BATCH UPDATE `my_table` SET `Value1` = $p0, `Value2` = $p1 WHERE `Key1` > $p2", query) require.Equal(t, []any{ driver.NamedValue{Name: "p0", Value: "foo"}, @@ -2358,9 +2358,9 @@ func TestBatchUpdate_YDB(t *testing.T) { BatchUpdate("users"). Set("status", "active"). Where(GT("created_at", "2024-01-01")) - + query, args, err := builder.QueryErr() - + require.Empty(t, query) require.Empty(t, args) require.Error(t, err) @@ -2371,9 +2371,9 @@ func TestBatchUpdate_YDB(t *testing.T) { BatchUpdate("users"). Set("status", "active"). Returning("id", "status") - + query, args, err := builder.QueryErr() - + require.Empty(t, query) require.Empty(t, args) require.Error(t, err) @@ -2382,13 +2382,13 @@ func TestBatchUpdate_YDB(t *testing.T) { t.Run("BATCH UPDATE with UPDATE ON pattern should error", func(t *testing.T) { d := Dialect(dialect.YDB) subquery := d.Select("id").From(Table("orders")).Where(EQ("status", "pending")) - + builder := d.BatchUpdate("users"). Set("status", "active"). On(subquery) - + query, args, err := builder.QueryErr() - + require.Empty(t, query) require.Empty(t, args) require.Error(t, err) @@ -2402,7 +2402,7 @@ func TestBatchDelete_YDB(t *testing.T) { query, args := d.BatchDelete("my_table"). Where(And(GT("Key1", 1), GTE("Key2", "One"))). Query() - + require.Equal(t, "BATCH DELETE FROM `my_table` WHERE `Key1` > $p0 AND `Key2` >= $p1", query) require.Equal(t, []any{ driver.NamedValue{Name: "p0", Value: 1}, @@ -2414,9 +2414,9 @@ func TestBatchDelete_YDB(t *testing.T) { builder := Dialect(dialect.MySQL). BatchDelete("users"). Where(GT("id", 100)) - + query, args, err := builder.QueryErr() - + require.Empty(t, query) require.Empty(t, args) require.Error(t, err) @@ -2427,9 +2427,9 @@ func TestBatchDelete_YDB(t *testing.T) { BatchDelete("users"). Where(GT("id", 100)). Returning("id") - + query, args, err := builder.QueryErr() - + require.Empty(t, query) require.Empty(t, args) require.Error(t, err) @@ -2438,12 +2438,12 @@ func TestBatchDelete_YDB(t *testing.T) { t.Run("BATCH DELETE with DELETE ON pattern should error", func(t *testing.T) { d := Dialect(dialect.YDB) subquery := d.Select("id").From(Table("users")).Where(EQ("status", "deleted")) - + builder := d.BatchDelete("users"). On(subquery) - + query, args, err := builder.QueryErr() - + require.Empty(t, query) require.Empty(t, args) require.Error(t, err) diff --git a/dialect/sql/schema/atlas.go b/dialect/sql/schema/atlas.go index d0ce2ff502..c4145161b7 100644 --- a/dialect/sql/schema/atlas.go +++ b/dialect/sql/schema/atlas.go @@ -1103,6 +1103,11 @@ func (a *Atlas) aIndexes(et *Table, at *schema.Table) error { if err := a.sqlDialect.atIndex(idx1, at, idx2); err != nil { return err } + + if len(idx2.Parts) == 0 { + continue + } + desc := descIndexes(idx1) for _, p := range idx2.Parts { p.Desc = desc[p.C.Name] diff --git a/dialect/sql/schema/ydb.go b/dialect/sql/schema/ydb.go index c37218ffc9..df835b415c 100644 --- a/dialect/sql/schema/ydb.go +++ b/dialect/sql/schema/ydb.go @@ -6,7 +6,6 @@ package schema import ( "context" - "errors" "fmt" "strings" @@ -68,22 +67,12 @@ func (d *YDB) tableExist(ctx context.Context, conn dialect.ExecQuerier, name str // atOpen returns a custom Atlas migrate.Driver for YDB. func (d *YDB) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) { - var ydbDriver *entdrv.YDBDriver - - switch drv := conn.(type) { - case *entdrv.YDBDriver: - ydbDriver = drv - case *YDB: - if ydb, ok := drv.Driver.(*entdrv.YDBDriver); ok { - ydbDriver = ydb - } + ydbDriver := unwrapYDBDriver(conn) + if ydbDriver == nil { + ydbDriver = unwrapYDBDriver(d.Driver) } if ydbDriver == nil { - if ydb, ok := d.Driver.(*entdrv.YDBDriver); ok { - ydbDriver = ydb - } else { - return nil, fmt.Errorf("expected dialect/ydb.YDBDriver, but got %T", conn) - } + return nil, fmt.Errorf("expected dialect/ydb.YDBDriver, but got %T", conn) } return atlas.Open( @@ -92,6 +81,20 @@ func (d *YDB) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) { ) } +func unwrapYDBDriver(driver any) *entdrv.YDBDriver { + switch drv := driver.(type) { + case *entdrv.YDBDriver: + return drv + case *YDB: + return unwrapYDBDriver(drv.Driver) + case *WriteDriver: + return unwrapYDBDriver(drv.Driver) + case *dialect.DebugDriver: + return unwrapYDBDriver(drv.Driver) + } + return nil +} + func (d *YDB) atTable(table1 *Table, table2 *schema.Table) { if table1.Annotation != nil { setAtChecks(table1, table2) @@ -163,7 +166,9 @@ func (d *YDB) atTypeC(column1 *Column, column2 *schema.Column) error { case field.TypeUUID: typ = &schema.UUIDType{T: atlas.TypeUUID} case field.TypeEnum: - err = errors.New("ydb: Enum can't be used as column data type for tables") + // YDB doesn't support enum types in DDL statements + // But ent can handle enum validation, so we just map it to Utf8 + typ = &schema.StringType{T: atlas.TypeUtf8} case field.TypeOther: typ = &schema.UnsupportedType{T: column1.typ} default: @@ -179,13 +184,17 @@ func (d *YDB) atTypeC(column1 *Column, column2 *schema.Column) error { } // atUniqueC adds a unique constraint for a column. -// In YDB, unique constraints are implemented as GLOBAL UNIQUE SYNC indexes. +// In YDB, unique constraints are implemented as GLOBAL UNIQUE indexes. func (d *YDB) atUniqueC( table1 *Table, column1 *Column, table2 *schema.Table, column2 *schema.Column, ) { + if !canBeIndexKey(column1) { + return + } + // Check if there's already an explicit unique index defined for this column. for _, idx := range table1.Indexes { if idx.Unique && len(idx.Columns) == 1 && idx.Columns[0].Name == column1.Name { @@ -194,11 +203,8 @@ func (d *YDB) atUniqueC( } } // Create a unique index for this column. - idxName := fmt.Sprintf("%s_%s_index", table1.Name, column1.Name) - index := schema.NewUniqueIndex(idxName).AddColumns(column2) - - // Add YDB-specific attribute for GLOBAL SYNC index type. - index.AddAttrs(&atlas.IndexAttributes{Global: true, Sync: true}) + indexName := fmt.Sprintf("%s_%s_uniq_idx", table1.Name, column1.Name) + index := schema.NewUniqueIndex(indexName).AddParts(&schema.IndexPart{C: column2}) table2.AddIndexes(index) } @@ -226,34 +232,65 @@ func (d *YDB) atIndex( table2 *schema.Table, index2 *schema.Index, ) error { + indexColumns := make([]string, 0) for _, column1 := range index1.Columns { + if isPrimaryKeyColumn(table2, column1.Name) { + continue + } + column2, ok := table2.Column(column1.Name) if !ok { return fmt.Errorf("unexpected index %q column: %q", index1.Name, column1.Name) } + + if !canBeIndexKeyBySchema(column2) { + continue + } + index2.AddParts(&schema.IndexPart{C: column2}) + indexColumns = append(indexColumns, column2.Name) } // Set YDB-specific index attributes. // By default, use GLOBAL SYNC for consistency. - idxType := &atlas.IndexAttributes{Global: true, Sync: true} + idxAttrs := &atlas.IndexAttributes{} - // Check for annotation overrides. if index1.Annotation != nil { + annotation := index1.Annotation + + if len(annotation.IncludeColumns) > 0 { + columns := make([]*schema.Column, len(annotation.IncludeColumns)) + + for i, include := range annotation.IncludeColumns { + column, ok := table2.Column(include) + if !ok { + return fmt.Errorf("include column %q was not found for index %q", include, index1.Name) + } + columns[i] = column + } + + idxAttrs.CoverColumns = columns + } + if indexType, ok := indexType(index1, dialect.YDB); ok { - // Parse YDB-specific index type from annotation. - switch strings.ToUpper(indexType) { - case "GLOBAL ASYNC", "ASYNC": - idxType.Sync = false - case "LOCAL": - idxType.Global = false - case "LOCAL ASYNC": - idxType.Global = false - idxType.Sync = false + upperIndexType := strings.ToUpper(indexType) + + if strings.Contains(upperIndexType, "ASYNC") { + idxAttrs.Async = true + } + if strings.Contains(upperIndexType, "UNIQUE") { + index2.Unique = true } } } - index2.AddAttrs(idxType) + + index2.Name = fmt.Sprintf( + "%s_%s_idx", + table2.Name, + strings.Join(indexColumns, "_"), + ) + + index2.AddAttrs(idxAttrs) return nil } @@ -269,3 +306,41 @@ func (*YDB) atTypeRangeSQL(ts ...string) string { strings.Join(values, ", "), ) } + +// canBeIndexKey checks if a column type can be used as an index key in YDB. +// YDB doesn't allow Float/Double types as index keys. +func canBeIndexKey(column *Column) bool { + switch column.Type { + case field.TypeFloat32, field.TypeFloat64: + return false + default: + return true + } +} + +// canBeIndexKeyBySchema checks if a column type can be used as an index key in YDB +// by checking the Atlas schema column type. +func canBeIndexKeyBySchema(column *schema.Column) bool { + if column.Type == nil || column.Type.Type == nil { + return true + } + switch column.Type.Type.(type) { + case *schema.FloatType: + return false + default: + return true + } +} + +// isPrimaryKeyColumn checks if a column is part of the table's primary key. +func isPrimaryKeyColumn(table *schema.Table, columnName string) bool { + if table.PrimaryKey == nil { + return false + } + for _, primaryKeyPart := range table.PrimaryKey.Parts { + if primaryKeyPart.C != nil && primaryKeyPart.C.Name == columnName { + return true + } + } + return false +} diff --git a/dialect/sql/sqlgraph/errors.go b/dialect/sql/sqlgraph/errors.go index 863cdc003f..fb63bf0ed0 100644 --- a/dialect/sql/sqlgraph/errors.go +++ b/dialect/sql/sqlgraph/errors.go @@ -28,6 +28,7 @@ func IsUniqueConstraintError(err error) bool { "Error 1062", // MySQL "violates unique constraint", // Postgres "UNIQUE constraint failed", // SQLite + "PRECONDITION_FAILED", // YDB unique index violation } { if strings.Contains(err.Error(), s) { return true diff --git a/dialect/sql/sqlgraph/graph.go b/dialect/sql/sqlgraph/graph.go index 6cb9d04a1c..250867c5a4 100644 --- a/dialect/sql/sqlgraph/graph.go +++ b/dialect/sql/sqlgraph/graph.go @@ -162,124 +162,146 @@ func (s *Step) ThroughEdgeTable() bool { // Neighbors returns a Selector for evaluating the path-step // and getting the neighbors of one vertex. -func Neighbors(dialect string, s *Step) (q *sql.Selector) { +func Neighbors(dialect string, step *Step) (query *sql.Selector) { builder := sql.Dialect(dialect) + switch { - case s.ThroughEdgeTable(): - pk1, pk2 := s.Edge.Columns[1], s.Edge.Columns[0] - if s.Edge.Inverse { - pk1, pk2 = pk2, pk1 + case step.ThroughEdgeTable(): + pk1, pk2 := step.Edge.Columns[0], step.Edge.Columns[1] + if step.Edge.Inverse { + pk2, pk1 = pk1, pk2 } - to := builder.Table(s.To.Table).Schema(s.To.Schema) - join := builder.Table(s.Edge.Table).Schema(s.Edge.Schema) - match := builder.Select(join.C(pk1)). + + to := builder.Table(step.To.Table).Schema(step.To.Schema) + join := builder.Table(step.Edge.Table).Schema(step.Edge.Schema) + + match := builder.Select(join.C(pk2)). From(join). - Where(sql.EQ(join.C(pk2), s.From.V)) - q = builder.Select(). + Where(sql.EQ(join.C(pk1), step.From.V)) + + query = builder.Select(). From(to). Join(match). - On(to.C(s.To.Column), match.C(pk1)) - case s.FromEdgeOwner(): - t1 := builder.Table(s.To.Table).Schema(s.To.Schema) - t2 := builder.Select(s.Edge.Columns[0]). - From(builder.Table(s.Edge.Table).Schema(s.Edge.Schema)). - Where(sql.EQ(s.From.Column, s.From.V)) - q = builder.Select(). - From(t1). - Join(t2). - On(t1.C(s.To.Column), t2.C(s.Edge.Columns[0])) - case s.ToEdgeOwner(): - q = builder.Select(). - From(builder.Table(s.To.Table).Schema(s.To.Schema)). - Where(sql.EQ(s.Edge.Columns[0], s.From.V)) + On(to.C(step.To.Column), match.C(pk2)) + + case step.FromEdgeOwner(): + table1 := builder.Table(step.To.Table).Schema(step.To.Schema) + + table2 := builder.Select(step.Edge.Columns[0]). + From(builder.Table(step.Edge.Table).Schema(step.Edge.Schema)). + Where(sql.EQ(step.From.Column, step.From.V)) + + query = builder.Select(). + From(table1). + Join(table2). + On(table1.C(step.To.Column), table2.C(step.Edge.Columns[0])) + + case step.ToEdgeOwner(): + query = builder.Select(). + From(builder.Table(step.To.Table).Schema(step.To.Schema)). + Where(sql.EQ(step.Edge.Columns[0], step.From.V)) } - return q + return query } // SetNeighbors returns a Selector for evaluating the path-step // and getting the neighbors of set of vertices. -func SetNeighbors(dialect string, s *Step) (q *sql.Selector) { - set := s.From.V.(*sql.Selector) +func SetNeighbors(dialect string, step *Step) (query *sql.Selector) { + set := step.From.V.(*sql.Selector) builder := sql.Dialect(dialect) + switch { - case s.ThroughEdgeTable(): - pk1, pk2 := s.Edge.Columns[1], s.Edge.Columns[0] - if s.Edge.Inverse { - pk1, pk2 = pk2, pk1 + case step.ThroughEdgeTable(): + pk1, pk2 := step.Edge.Columns[0], step.Edge.Columns[1] + if step.Edge.Inverse { + pk2, pk1 = pk1, pk2 } - to := builder.Table(s.To.Table).Schema(s.To.Schema) - set.Select(set.C(s.From.Column)) - join := builder.Table(s.Edge.Table).Schema(s.Edge.Schema) - match := builder.Select(join.C(pk1)). + + to := builder.Table(step.To.Table).Schema(step.To.Schema) + join := builder.Table(step.Edge.Table).Schema(step.Edge.Schema) + + set.Select(set.C(step.From.Column)) + + match := builder.Select(join.C(pk2)). From(join). Join(set). - On(join.C(pk2), set.C(s.From.Column)) - q = builder.Select(). + On(join.C(pk1), set.C(step.From.Column)) + + query = builder.Select(). From(to). Join(match). - On(to.C(s.To.Column), match.C(pk1)) - case s.FromEdgeOwner(): - t1 := builder.Table(s.To.Table).Schema(s.To.Schema) - set.Select(set.C(s.Edge.Columns[0])) - q = builder.Select(). - From(t1). + On(to.C(step.To.Column), match.C(pk2)) + + case step.FromEdgeOwner(): + table1 := builder.Table(step.To.Table).Schema(step.To.Schema) + set.Select(set.C(step.Edge.Columns[0])) + + query = builder.Select(). + From(table1). Join(set). - On(t1.C(s.To.Column), set.C(s.Edge.Columns[0])) - case s.ToEdgeOwner(): - t1 := builder.Table(s.To.Table).Schema(s.To.Schema) - set.Select(set.C(s.From.Column)) - q = builder.Select(). - From(t1). + On(table1.C(step.To.Column), set.C(step.Edge.Columns[0])) + + case step.ToEdgeOwner(): + table1 := builder.Table(step.To.Table).Schema(step.To.Schema) + set.Select(set.C(step.From.Column)) + + query = builder.Select(). + From(table1). Join(set). - On(t1.C(s.Edge.Columns[0]), set.C(s.From.Column)) + On(table1.C(step.Edge.Columns[0]), set.C(step.From.Column)) } - return q + return query } // HasNeighbors applies on the given Selector a neighbors check. -func HasNeighbors(q *sql.Selector, s *Step) { - builder := sql.Dialect(q.Dialect()) +func HasNeighbors(query *sql.Selector, step *Step) { + builder := sql.Dialect(query.Dialect()) + switch { - case s.ThroughEdgeTable(): - pk1 := s.Edge.Columns[0] - if s.Edge.Inverse { - pk1 = s.Edge.Columns[1] + case step.ThroughEdgeTable(): + pk1 := step.Edge.Columns[0] + if step.Edge.Inverse { + pk1 = step.Edge.Columns[1] } - join := builder.Table(s.Edge.Table).Schema(s.Edge.Schema) - q.Where( + + join := builder.Table(step.Edge.Table).Schema(step.Edge.Schema) + query.Where( sql.In( - q.C(s.From.Column), + query.C(step.From.Column), builder.Select(join.C(pk1)).From(join), ), ) - case s.FromEdgeOwner(): - q.Where(sql.NotNull(q.C(s.Edge.Columns[0]))) - case s.ToEdgeOwner(): - to := builder.Table(s.Edge.Table).Schema(s.Edge.Schema) + + case step.FromEdgeOwner(): + query.Where(sql.NotNull(query.C(step.Edge.Columns[0]))) + + case step.ToEdgeOwner(): + to := builder.Table(step.Edge.Table).Schema(step.Edge.Schema) // In case the edge reside on the same table, give // the edge an alias to make qualifier different. - if s.From.Table == s.Edge.Table { - to.As(fmt.Sprintf("%s_edge", s.Edge.Table)) - } - - // YDB doesn't support correlated EXISTS subqueries. - // Use IN subquery instead for YDB dialect. - if q.Dialect() == dialect.YDB { - q.Where( + if step.From.Table == step.Edge.Table { + to.As(fmt.Sprintf("%s_edge", step.Edge.Table)) + } + + if query.Dialect() == dialect.YDB { + // YDB doesn't support correlated subqueries, use IN with subquery instead. + query.Where( sql.In( - q.C(s.From.Column), - builder.Select(to.C(s.Edge.Columns[0])).From(to), + query.C(step.From.Column), + builder.Select(to.C(step.Edge.Columns[0])). + From(to). + Where(sql.NotNull(to.C(step.Edge.Columns[0]))), ), ) } else { - q.Where( + query.Where( sql.Exists( - builder.Select(to.C(s.Edge.Columns[0])). + builder.Select(to.C(step.Edge.Columns[0])). From(to). Where( sql.ColumnsEQ( - q.C(s.From.Column), - to.C(s.Edge.Columns[0]), + query.C(step.From.Column), + to.C(step.Edge.Columns[0]), ), ), ), @@ -290,98 +312,105 @@ func HasNeighbors(q *sql.Selector, s *Step) { // HasNeighborsWith applies on the given Selector a neighbors check. // The given predicate applies its filtering on the selector. -func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) { - builder := sql.Dialect(q.Dialect()) +func HasNeighborsWith( + query *sql.Selector, + step *Step, + predicate func(*sql.Selector), +) { + builder := sql.Dialect(query.Dialect()) + switch { - case s.ThroughEdgeTable(): - pk1, pk2 := s.Edge.Columns[1], s.Edge.Columns[0] - if s.Edge.Inverse { - pk1, pk2 = pk2, pk1 + case step.ThroughEdgeTable(): + pk1, pk2 := step.Edge.Columns[0], step.Edge.Columns[1] + if step.Edge.Inverse { + pk2, pk1 = pk1, pk2 } - to := builder.Table(s.To.Table).Schema(s.To.Schema) - edge := builder.Table(s.Edge.Table).Schema(s.Edge.Schema) - join := builder.Select(edge.C(pk2)). + + to := builder.Table(step.To.Table).Schema(step.To.Schema) + edge := builder.Table(step.Edge.Table).Schema(step.Edge.Schema) + + join := builder.Select(edge.C(pk1)). From(edge). Join(to). - On(edge.C(pk1), to.C(s.To.Column)) + On(edge.C(pk2), to.C(step.To.Column)) + matches := builder.Select().From(to) - matches.WithContext(q.Context()) - pred(matches) + matches.WithContext(query.Context()) + predicate(matches) join.FromSelect(matches) - q.Where(sql.In(q.C(s.From.Column), join)) - case s.FromEdgeOwner(): - to := builder.Table(s.To.Table).Schema(s.To.Schema) + query.Where(sql.In(query.C(step.From.Column), join)) + + case step.FromEdgeOwner(): + to := builder.Table(step.To.Table).Schema(step.To.Schema) // Avoid ambiguity in case both source // and edge tables are the same. - if s.To.Table == q.TableName() { - to.As(fmt.Sprintf("%s_edge", s.To.Table)) + if step.To.Table == query.TableName() { + to.As(fmt.Sprintf("%s_edge", step.To.Table)) // Choose the alias name until we do not // have a collision. Limit to 5 iterations. for i := 1; i <= 5; i++ { - if to.C("c") != q.C("c") { + if to.C("c") != query.C("c") { break } - to.As(fmt.Sprintf("%s_edge_%d", s.To.Table, i)) + to.As(fmt.Sprintf("%s_edge_%d", step.To.Table, i)) } } - // YDB doesn't support correlated EXISTS subqueries. - // Use IN subquery instead for YDB dialect. - if q.Dialect() == dialect.YDB { - matches := builder.Select(to.C(s.To.Column)).From(to) - matches.WithContext(q.Context()) - pred(matches) - q.Where(sql.In(q.C(s.Edge.Columns[0]), matches)) + if query.Dialect() == dialect.YDB { + // YDB doesn't support correlated subqueries, use IN with subquery instead + matches := builder.Select(to.C(step.To.Column)).From(to) + matches.WithContext(query.Context()) + predicate(matches) + query.Where(sql.In(query.C(step.Edge.Columns[0]), matches)) } else { - matches := builder.Select(to.C(s.To.Column)). + matches := builder.Select(to.C(step.To.Column)). From(to) - matches.WithContext(q.Context()) + matches.WithContext(query.Context()) matches.Where( sql.ColumnsEQ( - q.C(s.Edge.Columns[0]), - to.C(s.To.Column), + query.C(step.Edge.Columns[0]), + to.C(step.To.Column), ), ) - pred(matches) - q.Where(sql.Exists(matches)) + predicate(matches) + query.Where(sql.Exists(matches)) } - case s.ToEdgeOwner(): - to := builder.Table(s.Edge.Table).Schema(s.Edge.Schema) + case step.ToEdgeOwner(): + to := builder.Table(step.Edge.Table).Schema(step.Edge.Schema) // Avoid ambiguity in case both source // and edge tables are the same. - if s.Edge.Table == q.TableName() { - to.As(fmt.Sprintf("%s_edge", s.Edge.Table)) + if step.Edge.Table == query.TableName() { + to.As(fmt.Sprintf("%s_edge", step.Edge.Table)) // Choose the alias name until we do not // have a collision. Limit to 5 iterations. for i := 1; i <= 5; i++ { - if to.C("c") != q.C("c") { + if to.C("c") != query.C("c") { break } - to.As(fmt.Sprintf("%s_edge_%d", s.Edge.Table, i)) + to.As(fmt.Sprintf("%s_edge_%d", step.Edge.Table, i)) } } - - // YDB doesn't support correlated EXISTS subqueries. - // Use IN subquery instead for YDB dialect. - if q.Dialect() == dialect.YDB { - matches := builder.Select(to.C(s.Edge.Columns[0])).From(to) - matches.WithContext(q.Context()) - pred(matches) - q.Where(sql.In(q.C(s.From.Column), matches)) + + if query.Dialect() == dialect.YDB { + // YDB doesn't support correlated subqueries, using IN with subquery instead + matches := builder.Select(to.C(step.Edge.Columns[0])).From(to) + matches.WithContext(query.Context()) + predicate(matches) + query.Where(sql.In(query.C(step.From.Column), matches)) } else { - matches := builder.Select(to.C(s.Edge.Columns[0])). + matches := builder.Select(to.C(step.Edge.Columns[0])). From(to) - matches.WithContext(q.Context()) + matches.WithContext(query.Context()) matches.Where( sql.ColumnsEQ( - q.C(s.From.Column), - to.C(s.Edge.Columns[0]), + query.C(step.From.Column), + to.C(step.Edge.Columns[0]), ), ) - pred(matches) - q.Where(sql.Exists(matches)) + predicate(matches) + query.Where(sql.Exists(matches)) } } } @@ -573,11 +602,12 @@ func OrderByNeighborTerms(q *sql.Selector, s *Step, opts ...sql.OrderTerm) { } toT := build.Table(s.To.Table).Schema(s.To.Schema) joinT := build.Table(s.Edge.Table).Schema(s.Edge.Schema) - join = build.Select(pk2). + join = build.SelectExpr(). From(toT). Join(joinT). - On(toT.C(s.To.Column), joinT.C(pk1)). - GroupBy(pk2) + On(toT.C(s.To.Column), joinT.C(pk1)) + join.AppendSelect(joinT.C(pk2)). + GroupBy(joinT.C(pk2)) selectTerms(join, opts) q.LeftJoin(join). On(q.C(s.From.Column), join.C(pk2)) @@ -965,17 +995,32 @@ func NewDeleteSpec(table string, id *FieldSpec) *DeleteSpec { func DeleteNodes(ctx context.Context, drv dialect.Driver, spec *DeleteSpec) (int, error) { var affected int op := func(ctx context.Context, d dialect.Driver) error { - var ( - res sql.Result - builder = sql.Dialect(drv.Dialect()) - ) + builder := sql.Dialect(drv.Dialect()) selector := builder.Select(). From(builder.Table(spec.Node.Table).Schema(spec.Node.Schema)). WithContext(ctx) if pred := spec.Predicate; pred != nil { pred(selector) } - query, args := builder.Delete(spec.Node.Table).Schema(spec.Node.Schema).FromSelect(selector).Query() + delete := builder.Delete(spec.Node.Table).Schema(spec.Node.Schema).FromSelect(selector) + + // YDB doesn't return accurate RowsAffected(), so we use RETURNING to count deleted rows. + if drv.Dialect() == dialect.YDB { + delete.Returning(spec.Node.ID.Column) + query, args := delete.Query() + rows := &sql.Rows{} + if err := d.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + for rows.Next() { + affected++ + } + return rows.Err() + } + + query, args := delete.Query() + var res sql.Result if err := d.Exec(ctx, query, args, &res); err != nil { return err } @@ -1261,12 +1306,13 @@ func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error { return err } if !update.Empty() { - var res sql.Result - query, args := update.Query() - if err := tx.Exec(ctx, query, args, &res); err != nil { - return err + var returningColumn string + if u.Node.ID != nil { + returningColumn = u.Node.ID.Column + } else { + returningColumn = u.Node.CompositeID[0].Column } - affected, err := res.RowsAffected() + affected, err := execUpdate(ctx, tx, update, returningColumn) if err != nil { return err } @@ -1497,45 +1543,38 @@ func (u *updater) scan(rows *sql.Rows) error { func (u *updater) ensureExists(ctx context.Context) error { selector := u.builder. Select(). - From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema)). - Where(sql.EQ(u.Node.ID.Column, u.Node.ID.Value)) + From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema)) + + var idValue any + if u.Node.ID != nil { + selector.Where(sql.EQ(u.Node.ID.Column, u.Node.ID.Value)) + idValue = u.Node.ID.Value + } else { + selector.Where(sql.And( + sql.EQ(u.Node.CompositeID[0].Column, u.Node.CompositeID[0].Value), + sql.EQ(u.Node.CompositeID[1].Column, u.Node.CompositeID[1].Value), + )) + idValue = []any{u.Node.CompositeID[0].Value, u.Node.CompositeID[1].Value} + } u.Predicate(selector) - + var query string var args []any - - // YDB doesn't fully support EXISTS in all contexts. - // Use COUNT(*) > 0 approach instead for better compatibility. - if selector.Dialect() == dialect.YDB { - selector.Count("*") - query, args = selector.Query() - } else { - query, args = u.builder.SelectExpr(sql.Exists(selector)).Query() - } - + + query, args = u.builder.SelectExpr(sql.Exists(selector)).Query() + rows := &sql.Rows{} if err := u.tx.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() - - var found bool - if selector.Dialect() == dialect.YDB { - count, err := sql.ScanInt(rows) - if err != nil { - return err - } - found = count > 0 - } else { - var err error - found, err = sql.ScanBool(rows) - if err != nil { - return err - } + + found, err := sql.ScanBool(rows) + if err != nil { + return err } - if !found { - return &NotFoundError{table: u.Node.Table, id: u.Node.ID.Value} + return &NotFoundError{table: u.Node.Table, id: idValue} } return nil } @@ -1548,8 +1587,19 @@ type creator struct { func (c *creator) node(ctx context.Context, drv dialect.Driver) error { var ( edges = EdgeSpecs(c.Edges).GroupRel() - insert = c.builder.Insert(c.Table).Schema(c.Schema).Default() + insert *sql.InsertBuilder ) + + if c.builder.Dialect() == dialect.YDB { + // For YDB: use UPSERT only when OnConflict options are specified, + if len(c.CreateSpec.OnConflict) > 0 { + insert = c.builder.Upsert(c.Table).Schema(c.Schema) + } else { + insert = c.builder.Insert(c.Table).Schema(c.Schema) + } + } else { + insert = c.builder.Insert(c.Table).Schema(c.Schema).Default() + } if err := c.setTableColumns(insert, edges); err != nil { return err } @@ -1623,6 +1673,10 @@ func (c *creator) insert(ctx context.Context, insert *sql.InsertBuilder) error { // ensureConflict ensures the ON CONFLICT is added to the insert statement. func (c *creator) ensureConflict(insert *sql.InsertBuilder) { + // YDB doesn't support ON CONFLICT clause - UPSERT handles conflicts implicitly. + if insert.Dialect() == dialect.YDB { + return + } if opts := c.CreateSpec.OnConflict; len(opts) > 0 { insert.OnConflict(opts...) c.ensureLastInsertID(insert) @@ -1689,7 +1743,19 @@ func (c *batchCreator) nodes(ctx context.Context, drv dialect.Driver) error { } } sorted := keys(columns) - insert := c.builder.Insert(c.Nodes[0].Table).Schema(c.Nodes[0].Schema).Default().Columns(sorted...) + + var insert *sql.InsertBuilder + if c.builder.Dialect() == dialect.YDB { + // For YDB: use UPSERT only when OnConflict options are specified, + if len(c.BatchCreateSpec.OnConflict) > 0 { + insert = c.builder.Upsert(c.Nodes[0].Table).Schema(c.Nodes[0].Schema).Columns(sorted...) + } else { + insert = c.builder.Insert(c.Nodes[0].Table).Schema(c.Nodes[0].Schema).Columns(sorted...) + } + } else { + insert = c.builder.Insert(c.Nodes[0].Table).Schema(c.Nodes[0].Schema).Default().Columns(sorted...) + } + for i := range values { vs := make([]any, len(sorted)) for j, c := range sorted { @@ -1751,6 +1817,10 @@ func (c *batchCreator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, // ensureConflict ensures the ON CONFLICT is added to the insert statement. func (c *batchCreator) ensureConflict(insert *sql.InsertBuilder) { + // YDB doesn't support ON CONFLICT clause - UPSERT handles conflicts implicitly. + if insert.Dialect() == dialect.YDB { + return + } if opts := c.BatchCreateSpec.OnConflict; len(opts) > 0 { insert.OnConflict(opts...) } @@ -1854,7 +1924,15 @@ func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeS values = append(values, f.Value) columns = append(columns, f.Column) } - insert := g.builder.Insert(table).Columns(columns...) + + // YDB doesn't support ON CONFLICT clause. Use UPSERT for M2M edges without extra fields + var insert *sql.InsertBuilder + if len(edges[0].Target.Fields) == 0 && g.builder.Dialect() == dialect.YDB { + insert = g.builder.Upsert(table).Columns(columns...) + } else { + insert = g.builder.Insert(table).Columns(columns...) + } + if edges[0].Schema != "" { // If the Schema field was provided to the EdgeSpec (by the // generated code), it should be the same for all EdgeSpecs. @@ -1902,7 +1980,14 @@ func (g *graph) batchAddM2M(ctx context.Context, spec *BatchCreateSpec) error { for _, f := range edge.Target.Fields { columns = append(columns, f.Column) } - insert = g.builder.Insert(name).Columns(columns...) + + // YDB doesn't support ON CONFLICT clause. Use UPSERT for M2M edges without extra fields. + if len(edge.Target.Fields) == 0 && g.builder.Dialect() == dialect.YDB { + insert = g.builder.Upsert(name).Columns(columns...) + } else { + insert = g.builder.Insert(name).Columns(columns...) + } + if edge.Schema != "" { // If the Schema field was provided to the EdgeSpec (by the // generated code), it should be the same for all EdgeSpecs. @@ -1976,19 +2061,17 @@ func (g *graph) addFKEdges(ctx context.Context, ids []driver.Value, edges []*Edg if len(edge.Target.Nodes) > 1 { p = sql.InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...) } - query, args := g.builder.Update(edge.Table). + + update := g.builder.Update(edge.Table). Schema(edge.Schema). Set(edge.Columns[0], id). - Where(sql.And(p, sql.IsNull(edge.Columns[0]))). - Query() - var res sql.Result - if err := g.tx.Exec(ctx, query, args, &res); err != nil { - return fmt.Errorf("add %s edge for table %s: %w", edge.Rel, edge.Table, err) - } - affected, err := res.RowsAffected() + Where(sql.And(p, sql.IsNull(edge.Columns[0]))) + + affected, err := execUpdate(ctx, g.tx, update, edge.Target.IDSpec.Column) if err != nil { - return err + return fmt.Errorf("add %s edge for table %s: %w", edge.Rel, edge.Table, err) } + // Setting the FK value of the "other" table without clearing it before, is not allowed. // Including no-op (same id), because we rely on "affected" to determine if the FK set. if ids := edge.Target.Nodes; int(affected) < len(ids) { @@ -2015,6 +2098,44 @@ func hasExternalEdges(addEdges, clearEdges map[Rel][]*EdgeSpec) bool { return false } +// execUpdate executes an UPDATE and returns the number of affected rows. +// For YDB, it uses RETURNING clause since RowsAffected() is unreliable. +func execUpdate( + ctx context.Context, + tx dialect.ExecQuerier, + update *sql.UpdateBuilder, + returningColumn string, +) (int64, error) { + if update.Dialect() == dialect.YDB { + update.Returning(returningColumn) + + query, args := update.Query() + rows := &sql.Rows{} + if err := tx.Query(ctx, query, args, rows); err != nil { + return 0, err + } + defer rows.Close() + + var affected int64 + for rows.Next() { + affected++ + } + + if err := rows.Err(); err != nil { + return 0, err + } + + return affected, nil + } else { + query, args := update.Query() + var res sql.Result + if err := tx.Exec(ctx, query, args, &res); err != nil { + return 0, err + } + return res.RowsAffected() + } +} + // isExternalEdge reports if the given edge requires an UPDATE // or an INSERT to other table. func isExternalEdge(e *EdgeSpec) bool { diff --git a/dialect/sql/sqlgraph/graph_test.go b/dialect/sql/sqlgraph/graph_test.go index 13bcf87ae2..64782f6ec8 100644 --- a/dialect/sql/sqlgraph/graph_test.go +++ b/dialect/sql/sqlgraph/graph_test.go @@ -1072,7 +1072,7 @@ func TestOrderByNeighborTerms(t *testing.T) { ) query, args := s.Query() require.Empty(t, args) - require.Equal(t, `SELECT "users"."name", "t1"."total_users" FROM "users" LEFT JOIN (SELECT "user_id", SUM("group"."num_users") AS "total_users" FROM "group" JOIN "user_groups" AS "t1" ON "group"."id" = "t1"."group_id" GROUP BY "user_id") AS "t1" ON "users"."id" = "t1"."user_id" ORDER BY "t1"."total_users" NULLS FIRST`, query) + require.Equal(t, `SELECT "users"."name", "t1"."total_users" FROM "users" LEFT JOIN (SELECT "t1"."user_id", SUM("group"."num_users") AS "total_users" FROM "group" JOIN "user_groups" AS "t1" ON "group"."id" = "t1"."group_id" GROUP BY "t1"."user_id") AS "t1" ON "users"."id" = "t1"."user_id" ORDER BY "t1"."total_users" NULLS FIRST`, query) }) t.Run("M2M/NullsLast", func(t *testing.T) { s := s.Clone() @@ -1090,7 +1090,7 @@ func TestOrderByNeighborTerms(t *testing.T) { ) query, args := s.Query() require.Empty(t, args) - require.Equal(t, `SELECT "users"."name" FROM "users" LEFT JOIN (SELECT "user_id", SUM("group"."num_users") AS "total_users" FROM "group" JOIN "user_groups" AS "t1" ON "group"."id" = "t1"."group_id" GROUP BY "user_id") AS "t1" ON "users"."id" = "t1"."user_id" ORDER BY "t1"."total_users" NULLS LAST`, query) + require.Equal(t, `SELECT "users"."name" FROM "users" LEFT JOIN (SELECT "t1"."user_id", SUM("group"."num_users") AS "total_users" FROM "group" JOIN "user_groups" AS "t1" ON "group"."id" = "t1"."group_id" GROUP BY "t1"."user_id") AS "t1" ON "users"."id" = "t1"."user_id" ORDER BY "t1"."total_users" NULLS LAST`, query) }) } diff --git a/entc/integration/docker-compose.yaml b/entc/integration/docker-compose.yaml index 32ac1af284..4357026d1c 100644 --- a/entc/integration/docker-compose.yaml +++ b/entc/integration/docker-compose.yaml @@ -167,3 +167,21 @@ services: restart: on-failure ports: - 8182:8182 + + ydb: + image: ydbplatform/local-ydb:trunk + platform: linux/amd64 + hostname: localhost + environment: + - GRPC_TLS_PORT=2135 + - GRPC_PORT=2136 + - MON_PORT=8765 + ports: + - 2136:2136 + - 8765:8765 + healthcheck: + test: ["CMD-SHELL", "nc -z localhost 2136"] + interval: 10s + timeout: 5s + retries: 10 + start_period: 30s diff --git a/entc/integration/ent/migrate/schema.go b/entc/integration/ent/migrate/schema.go index 3704811a3a..e0051670d9 100644 --- a/entc/integration/ent/migrate/schema.go +++ b/entc/integration/ent/migrate/schema.go @@ -138,10 +138,10 @@ var ( {Name: "text", Type: field.TypeString, Nullable: true, Size: 2147483647, SchemaType: map[string]string{"mysql": "mediumtext"}}, {Name: "datetime", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime", "postgres": "date"}}, {Name: "decimal", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"mysql": "decimal(6,2)", "postgres": "numeric"}}, - {Name: "link_other", Type: field.TypeOther, Nullable: true, SchemaType: map[string]string{"mysql": "varchar(255)", "postgres": "varchar", "sqlite3": "varchar(255)"}}, - {Name: "link_other_func", Type: field.TypeOther, Nullable: true, SchemaType: map[string]string{"mysql": "varchar(255)", "postgres": "varchar", "sqlite3": "varchar(255)"}}, + {Name: "link_other", Type: field.TypeOther, Nullable: true, SchemaType: map[string]string{"mysql": "varchar(255)", "postgres": "varchar", "sqlite3": "varchar(255)", "ydb": "Utf8"}}, + {Name: "link_other_func", Type: field.TypeOther, Nullable: true, SchemaType: map[string]string{"mysql": "varchar(255)", "postgres": "varchar", "sqlite3": "varchar(255)", "ydb": "Utf8"}}, {Name: "mac", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "macaddr"}}, - {Name: "string_array", Type: field.TypeOther, Nullable: true, SchemaType: map[string]string{"mysql": "blob", "postgres": "text[]", "sqlite3": "json"}}, + {Name: "string_array", Type: field.TypeOther, Nullable: true, SchemaType: map[string]string{"mysql": "blob", "postgres": "text[]", "sqlite3": "json", "ydb": "Utf8"}}, {Name: "password", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"mysql": "char(32)"}}, {Name: "string_scanner", Type: field.TypeString, Nullable: true}, {Name: "duration", Type: field.TypeInt64, Nullable: true}, @@ -174,8 +174,8 @@ var ( {Name: "nil_pair", Type: field.TypeBytes, Nullable: true}, {Name: "vstring", Type: field.TypeString}, {Name: "triple", Type: field.TypeString}, - {Name: "big_int", Type: field.TypeInt, Nullable: true}, - {Name: "password_other", Type: field.TypeOther, Nullable: true, SchemaType: map[string]string{"mysql": "char(32)", "postgres": "varchar", "sqlite3": "char(32)"}}, + {Name: "big_int", Type: field.TypeInt, Nullable: true, SchemaType: map[string]string{"ydb": "Utf8"}}, + {Name: "password_other", Type: field.TypeOther, Nullable: true, SchemaType: map[string]string{"mysql": "char(32)", "postgres": "varchar", "sqlite3": "char(32)", "ydb": "Utf8"}}, {Name: "file_field", Type: field.TypeInt, Nullable: true}, } // FieldTypesTable holds the schema information for the "field_types" table. diff --git a/entc/integration/ent/schema/fieldtype.go b/entc/integration/ent/schema/fieldtype.go index 7ca26242d6..44fd5d3189 100644 --- a/entc/integration/ent/schema/fieldtype.go +++ b/entc/integration/ent/schema/fieldtype.go @@ -116,6 +116,7 @@ func (FieldType) Fields() []ent.Field { //nolint:funlen dialect.Postgres: "varchar", dialect.MySQL: "varchar(255)", dialect.SQLite: "varchar(255)", + dialect.YDB: "Utf8", }). Optional(). Default(DefaultLink()), @@ -124,6 +125,7 @@ func (FieldType) Fields() []ent.Field { //nolint:funlen dialect.Postgres: "varchar", dialect.MySQL: "varchar(255)", dialect.SQLite: "varchar(255)", + dialect.YDB: "Utf8", }). Optional(). Default(DefaultLink), @@ -143,6 +145,7 @@ func (FieldType) Fields() []ent.Field { //nolint:funlen dialect.Postgres: "text[]", dialect.SQLite: "json", dialect.MySQL: "blob", + dialect.YDB: "Utf8", }), field.String("password"). Optional(). @@ -289,7 +292,10 @@ func (FieldType) Fields() []ent.Field { //nolint:funlen }), field.Int("big_int"). Optional(). - GoType(BigInt{}), + GoType(BigInt{}). + SchemaType(map[string]string{ + dialect.YDB: "Utf8", + }), field.Other("password_other", Password("")). Optional(). Sensitive(). @@ -297,6 +303,7 @@ func (FieldType) Fields() []ent.Field { //nolint:funlen dialect.MySQL: "char(32)", dialect.SQLite: "char(32)", dialect.Postgres: "varchar", + dialect.YDB: "Utf8", }), } } diff --git a/entc/integration/go.mod b/entc/integration/go.mod index 6ac5a6f8b2..7b1caee556 100644 --- a/entc/integration/go.mod +++ b/entc/integration/go.mod @@ -50,4 +50,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -replace ariga.io/atlas => github.com/LostImagin4tion/atlas v0.0.18 +replace ariga.io/atlas => github.com/LostImagin4tion/atlas v0.0.33 diff --git a/entc/integration/go.sum b/entc/integration/go.sum index 350fb9eb87..48b6a96115 100644 --- a/entc/integration/go.sum +++ b/entc/integration/go.sum @@ -5,8 +5,8 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= -github.com/LostImagin4tion/atlas v0.0.18 h1:RrLLU6zEXuRZAih3slblKFZ/lPLUDZ+wrHaTrxILqrA= -github.com/LostImagin4tion/atlas v0.0.18/go.mod h1:Rco1malutATQGeWEoYFzurfzIvs+galayoZ0+Pz4als= +github.com/LostImagin4tion/atlas v0.0.33 h1:RgcQhGG0MZDwheuFRiZu47ihFRDhtcYmbAT6KU3J3v0= +github.com/LostImagin4tion/atlas v0.0.33/go.mod h1:FtOd0Ry45l3FeDVGVm8tf2SFWg3vHDztylE0eE3EWQ8= github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo= github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= diff --git a/entc/integration/integration_test.go b/entc/integration/integration_test.go index 091c125ecf..6bdf330b34 100644 --- a/entc/integration/integration_test.go +++ b/entc/integration/integration_test.go @@ -132,6 +132,29 @@ func TestPostgres(t *testing.T) { } } +func TestYDB(t *testing.T) { + t.Parallel() + + ydbOpts := enttest.WithMigrateOptions( + migrate.WithDropIndex(true), + migrate.WithDropColumn(true), + migrate.WithForeignKeys(false), + sqlschema.WithSkipChanges(sqlschema.ModifyColumn), + ) + + client := enttest.Open(t, dialect.YDB, "grpc://localhost:2136/local", ydbOpts) + client = client.Debug() + defer client.Close() + + for _, tt := range tests { + name := runtime.FuncForPC(reflect.ValueOf(tt).Pointer()).Name() + t.Run(name[strings.LastIndex(name, ".")+1:], func(t *testing.T) { + drop(t, client) + tt(t, client) + }) + } +} + var ( opts = enttest.WithMigrateOptions( migrate.WithDropIndex(true), @@ -328,6 +351,10 @@ func Sanity(t *testing.T, client *ent.Client) { } func Upsert(t *testing.T, client *ent.Client) { + // YDB's UPSERT has different semantics. + // it only resolves conflicts on primary key and replaces the entire row. + skip(t, "YDB") + ctx := context.Background() u := client.User.Create().SetName("Ariel").SetAge(30).SetPhone("0000").SaveX(ctx) require.Equal(t, "static", u.Address, "address was set by default func") @@ -740,8 +767,21 @@ func Select(t *testing.T, client *ent.Client) { // Append the "users_count" column to the selected columns. AppendSelect( sql.As(sql.Count(t.C(group.UsersPrimaryKey[1])), "users_count"), - ). - GroupBy(s.C(group.FieldID)) + ) + + // YDB requires all non-key and non-aggregated columns to be in GROUP BY + if s.Dialect() == dialect.YDB { + s.GroupBy( + s.C(group.FieldID), + s.C(group.FieldActive), + s.C(group.FieldExpire), + s.C(group.FieldType), + s.C(group.FieldMaxUsers), + s.C(group.FieldName), + ) + } else { + s.GroupBy(s.C(group.FieldID)) + } }). ScanX(ctx, &gs) require.Len(gs, 2) @@ -786,7 +826,7 @@ func Select(t *testing.T, client *ent.Client) { // Execute custom update modifier. client.User.Update(). Modify(func(u *sql.UpdateBuilder) { - u.Set(user.FieldName, sql.Expr(fmt.Sprintf("UPPER(%s)", user.FieldName))) + u.Set(user.FieldName, sql.UpperExpr(user.FieldName)) }). ExecX(ctx) require.True(allUpper(), "at names must be upper-cased") @@ -827,8 +867,11 @@ func Select(t *testing.T, client *ent.Client) { } // Order by random value should compile a valid query. - _, err = client.User.Query().Order(sql.OrderByRand()).All(ctx) - require.NoError(err) + // YDB doesn't support ORDER BY with constant expressions like Random(seed). + if client.Dialect() != dialect.YDB { + _, err = client.User.Query().Order(sql.OrderByRand()).All(ctx) + require.NoError(err) + } } func Aggregate(t *testing.T, client *ent.Client) { @@ -1069,8 +1112,12 @@ func Delete(t *testing.T, client *ent.Client) { info := client.GroupInfo.Create().SetDesc("group info").SaveX(ctx) hub := client.Group.Create().SetInfo(info).SetName("GitHub").SetExpire(time.Now().Add(time.Hour)).SaveX(ctx) - err = client.GroupInfo.DeleteOne(info).Exec(ctx) - require.True(ent.IsConstraintError(err)) + + // YDB doesn't have foreign keys constraints + if client.Dialect() != dialect.YDB { + err = client.GroupInfo.DeleteOne(info).Exec(ctx) + require.True(ent.IsConstraintError(err)) + } // Group.DeleteOneID(id).Where(...), is identical to Group.Delete().Where(group.ID(id), ...), // but, in case the OpDelete is not an allowed operation, the DeleteOne can be used with Where. @@ -1086,18 +1133,26 @@ func Delete(t *testing.T, client *ent.Client) { Where(group.ExpireLT(time.Now())). Exec(ctx) require.True(ent.IsNotFound(err)) + hub.Update().SetExpire(time.Now().Add(-time.Hour)).ExecX(ctx) + client.Group.DeleteOne(hub). Where(group.ExpireLT(time.Now())). ExecX(ctx) // The behavior described above it also applied to UpdateOne. - hub = client.Group.Create().SetInfo(info).SetName("GitHub").SetExpire(time.Now().Add(time.Hour)).SaveX(ctx) + hub = client.Group.Create(). + SetInfo(info). + SetName("GitHub"). + SetExpire(time.Now().Add(time.Hour)). + SaveX(ctx) + err = hub.Update(). SetActive(false). - SetExpire(time.Time{}). + SetExpire(time.Unix(0, 0)). Where(group.ExpireLT(time.Now())). // Expired. Exec(ctx) + require.True(ent.IsNotFound(err)) } @@ -1625,13 +1680,22 @@ func UniqueConstraint(t *testing.T, client *ent.Client) { cm1 := client.Comment.Create().SetUniqueInt(42).SetUniqueFloat(math.Pi).SaveX(ctx) err = client.Comment.Create().SetUniqueInt(42).SetUniqueFloat(math.E).Exec(ctx) require.Error(err) - err = client.Comment.Create().SetUniqueInt(7).SetUniqueFloat(math.Pi).Exec(ctx) - require.Error(err) + + // YDB doesn't support unique indexes on float columns. + if client.Dialect() != dialect.YDB { + err = client.Comment.Create().SetUniqueInt(7).SetUniqueFloat(math.Pi).Exec(ctx) + require.Error(err) + } + client.Comment.Create().SetUniqueInt(7).SetUniqueFloat(math.E).ExecX(ctx) err = cm1.Update().SetUniqueInt(7).Exec(ctx) require.Error(err) - err = cm1.Update().SetUniqueFloat(math.E).Exec(ctx) - require.Error(err) + + // YDB doesn't support unique indexes on float columns. + if client.Dialect() != dialect.YDB { + err = cm1.Update().SetUniqueFloat(math.E).Exec(ctx) + require.Error(err) + } t.Log("unique constraint on time fields") now := time.Now() @@ -1666,7 +1730,7 @@ func Tx(t *testing.T, client *ent.Client) { m.On("onRollback", nil).Once() defer m.AssertExpectations(t) tx.OnRollback(m.rHook()) - tx.Node.Create().ExecX(ctx) + tx.Node.Create().SetValue(0).ExecX(ctx) require.NoError(t, tx.Rollback()) require.Zero(t, client.Node.Query().CountX(ctx), "rollback should discard all changes") }) @@ -1683,13 +1747,13 @@ func Tx(t *testing.T, client *ent.Client) { return err }) }) - nde := tx.Node.Create().SaveX(ctx) + node := tx.Node.Create().SetValue(0).SaveX(ctx) require.NoError(t, tx.Commit()) require.Error(t, tx.Commit(), "should return an error on the second call") require.NotZero(t, client.Node.Query().CountX(ctx), "commit should save all changes") - _, err = nde.QueryNext().Count(ctx) + _, err = node.QueryNext().Count(ctx) require.Error(t, err, "should not be able to query after tx was closed") - require.Zero(t, nde.Unwrap().QueryNext().CountX(ctx), "should be able to query the entity after wrap") + require.Zero(t, node.Unwrap().QueryNext().CountX(ctx), "should be able to query the entity after wrap") }) t.Run("Nested", func(t *testing.T) { tx, err := client.Tx(ctx) @@ -1703,8 +1767,19 @@ func Tx(t *testing.T, client *ent.Client) { require.NoError(t, tx.Rollback()) }) t.Run("TxOptions Rollback", func(t *testing.T) { - skip(t, "SQLite") - tx, err := client.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + skip(t, "SQLite", "YDB") + + var txOptions sql.TxOptions + if client.Dialect() == dialect.YDB { + txOptions = sql.TxOptions{ + Isolation: stdsql.LevelSnapshot, + ReadOnly: true, + } + } else { + txOptions = sql.TxOptions{ReadOnly: true} + } + + tx, err := client.BeginTx(ctx, &txOptions) require.NoError(t, err) var m mocker m.On("onRollback", nil).Once() @@ -1721,10 +1796,38 @@ func Tx(t *testing.T, client *ent.Client) { require.Error(t, err, "expect creation to fail in read-only tx") require.NoError(t, tx.Rollback()) }) + t.Run("YDB TxOptions Rollback", func(t *testing.T) { + if client.Dialect() != dialect.YDB { + t.Skip("YDB-specific test") + } + + tx, err := client.BeginTx(ctx, &sql.TxOptions{ + Isolation: stdsql.LevelSnapshot, + ReadOnly: true, + }) + require.NoError(t, err) + + err = tx.Item.Create().Exec(ctx) + require.Error(t, err, "expect creation to fail in read-only tx") + + // YDB implicitly invalidates transaction so Rollback() should return err + err = tx.Rollback() + require.Error(t, err) + require.Contains(t, err.Error(), "Transaction not found") + }) t.Run("TxOptions Commit", func(t *testing.T) { skip(t, "SQLite") - tx, err := client.BeginTx(ctx, &sql.TxOptions{Isolation: stdsql.LevelReadCommitted}) + + var txOptions sql.TxOptions + if client.Dialect() == dialect.YDB { + txOptions = sql.TxOptions{Isolation: stdsql.LevelSerializable} + } else { + txOptions = sql.TxOptions{Isolation: stdsql.LevelReadCommitted} + } + + tx, err := client.BeginTx(ctx, &txOptions) require.NoError(t, err) + var m mocker m.On("onCommit", nil).Once() defer m.AssertExpectations(t) @@ -2164,12 +2267,29 @@ func NoSchemaChanges(t *testing.T, client *ent.Client) { }) tables, err := sqlschema.CopyTables(migrate.Tables) require.NoError(t, err) + + opts := []sqlschema.MigrateOption{ + migrate.WithDropIndex(true), + migrate.WithDropColumn(true), + } + if strings.Contains(t.Name(), "YDB") { + opts = append( + opts, + migrate.WithForeignKeys(false), + sqlschema.WithSkipChanges(sqlschema.ModifyColumn), + ) + } + err = migrate.Create( context.Background(), - migrate.NewSchema(&sqlschema.WriteDriver{Writer: w, Driver: client.Driver()}), + migrate.NewSchema( + &sqlschema.WriteDriver{ + Driver: client.Driver(), + Writer: w, + }, + ), tables, - migrate.WithDropIndex(true), - migrate.WithDropColumn(true), + opts..., ) require.NoError(t, err) } @@ -2326,6 +2446,8 @@ func CreateBulk(t *testing.T, client *ent.Client) { } func ConstraintChecks(t *testing.T, client *ent.Client) { + skip(t, "YDB") + var cerr *ent.ConstraintError err := client.Pet.Create().SetName("orphan").SetOwnerID(0).Exec(context.Background()) require.True(t, errors.As(err, &cerr)) @@ -2340,7 +2462,7 @@ func ConstraintChecks(t *testing.T, client *ent.Client) { } func Lock(t *testing.T, client *ent.Client) { - skip(t, "SQLite", "MySQL/5", "Maria/10.2") + skip(t, "SQLite", "MySQL/5", "Maria/10.2", "YDB") ctx := context.Background() xabi := client.Pet.Create().SetName("Xabi").SaveX(ctx) @@ -2399,6 +2521,7 @@ func Lock(t *testing.T, client *ent.Client) { } func ExtValueScan(t *testing.T, client *ent.Client) { + skip(t, "YDB") ctx := context.Background() u, err := url.Parse("https://entgo.io") require.NoError(t, err) diff --git a/entc/integration/relation_test.go b/entc/integration/relation_test.go index 9350d6d060..9c852230f8 100644 --- a/entc/integration/relation_test.go +++ b/entc/integration/relation_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "entgo.io/ent/dialect" "entgo.io/ent/entc/integration/ent" "entgo.io/ent/entc/integration/ent/card" "entgo.io/ent/entc/integration/ent/group" @@ -166,6 +167,12 @@ func O2OSameType(t *testing.T, client *ent.Client) { require.Equal(2, client.Node.Query().CountX(ctx), "linked-list should have 2 nodes") t.Log("delete assoc should delete inverse edge") + + // YDB doesn't have FK constraints, manually clear reference before delete. + if client.Dialect() == dialect.YDB { + sec.Update().ClearPrev().ExecX(ctx) + } + client.Node.DeleteOne(head).ExecX(ctx) require.Zero(sec.QueryPrev().CountX(ctx), "second node should be the head now") require.Zero(sec.QueryNext().CountX(ctx), "second node should be the head now") @@ -245,16 +252,24 @@ func O2OSameType(t *testing.T, client *ent.Client) { require.Zero(head.QueryNext().QueryNext().Where(node.ValueGT(10)).QueryNext().QueryNext().QueryNext().CountX(ctx)) t.Log("delete all nodes except the head") + + // YDB doesn't have FK constraints, clear stale reference before delete. + if client.Dialect() == dialect.YDB { + head.Update().ClearNext().ExecX(ctx) + } + client.Node.Delete().Where(node.ValueGT(1)).ExecX(ctx) head = client.Node.Query().OnlyX(ctx) - t.Log("node points to itself (circular linked-list with 1 node)") - head.Update().SetNext(head).SaveX(ctx) - require.Equal(head.ID, head.QueryPrev().OnlyIDX(ctx)) - require.Equal(head.ID, head.QueryNext().OnlyIDX(ctx)) - head.Update().ClearNext().SaveX(ctx) - require.Zero(head.QueryPrev().CountX(ctx)) - require.Zero(head.QueryNext().CountX(ctx)) + if client.Dialect() != dialect.YDB { + t.Log("node points to itself (circular linked-list with 1 node)") + head.Update().SetNext(head).SaveX(ctx) + require.Equal(head.ID, head.QueryPrev().OnlyIDX(ctx)) + require.Equal(head.ID, head.QueryNext().OnlyIDX(ctx)) + head.Update().ClearNext().SaveX(ctx) + require.Zero(head.QueryPrev().CountX(ctx)) + require.Zero(head.QueryNext().CountX(ctx)) + } } // Demonstrate a O2O relation between two instances of the same type, where the relation @@ -785,6 +800,12 @@ func M2MSelfRef(t *testing.T, client *ent.Client) { require.Equal(2, client.User.Query().Where(user.HasFriends()).CountX(ctx)) t.Log("delete inverse should delete association") + + // YDB doesn't have FK constraints, manually clear references before delete. + if client.Dialect() == dialect.YDB { + bar.Update().ClearFriends().ExecX(ctx) + } + client.User.DeleteOne(bar).ExecX(ctx) require.False(foo.QueryFriends().ExistX(ctx)) require.Zero(client.User.Query().Where(user.HasFriends()).CountX(ctx)) @@ -925,6 +946,12 @@ func M2MSameType(t *testing.T, client *ent.Client) { require.Equal(1, client.User.Query().Where(user.HasFollowing()).CountX(ctx)) t.Log("delete inverse should delete association") + + // YDB doesn't have FK constraints, manually clear M2M references before delete. + if client.Dialect() == dialect.YDB { + bar.Update().ClearFollowing().ExecX(ctx) + } + client.User.DeleteOne(bar).ExecX(ctx) require.False(foo.QueryFollowers().ExistX(ctx)) require.Zero(client.User.Query().Where(user.HasFollowers()).CountX(ctx)) @@ -1067,6 +1094,12 @@ func M2MTwoTypes(t *testing.T, client *ent.Client) { require.Equal(1, hub.QueryUsers().CountX(ctx)) t.Log("delete inverse should delete association") + + // YDB doesn't have FK constraints, manually clear M2M references before delete. + if client.Dialect() == dialect.YDB { + hub.Update().ClearUsers().ExecX(ctx) + } + client.Group.DeleteOne(hub).ExecX(ctx) require.False(foo.QueryGroups().ExistX(ctx)) require.Zero(client.User.Query().Where(user.HasGroups()).CountX(ctx)) @@ -1082,6 +1115,12 @@ func M2MTwoTypes(t *testing.T, client *ent.Client) { require.Equal(1, client.Group.Query().Where(group.HasUsers()).CountX(ctx)) t.Log("delete assoc should delete inverse as well") + + // YDB doesn't have FK constraints, manually clear M2M references before delete. + if client.Dialect() == dialect.YDB { + foo.Update().ClearGroups().ExecX(ctx) + } + client.User.DeleteOne(foo).ExecX(ctx) require.False(hub.QueryUsers().ExistX(ctx)) require.Zero(client.User.Query().Where(user.HasGroups()).CountX(ctx)) diff --git a/entc/integration/type_test.go b/entc/integration/type_test.go index 62c0693d1a..611a441ddb 100644 --- a/entc/integration/type_test.go +++ b/entc/integration/type_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/entc/integration/ent" "entgo.io/ent/entc/integration/ent/fieldtype" @@ -105,7 +106,7 @@ func Types(t *testing.T, client *ent.Client) { require.Equal(role.Admin, ft.Role) require.Equal(role.High, ft.Priority) require.NoError(err) - dt, err := time.Parse(time.RFC3339, "1906-01-02T00:00:00+00:00") + dt, err := time.Parse(time.RFC3339, "1976-01-02T00:00:00+00:00") require.NoError(err) require.Equal(schema.Pair{K: []byte("K"), V: []byte("V")}, ft.Pair) require.Equal(&schema.Pair{K: []byte("K"), V: []byte("V")}, ft.NilPair) @@ -120,12 +121,15 @@ func Types(t *testing.T, client *ent.Client) { require.Equal("127.0.0.1", ft.LinkOtherFunc.String()) require.False(ft.DeletedAt.Time.IsZero()) - ft = client.FieldType.UpdateOne(ft).AddOptionalUint64(10).SaveX(ctx) - require.EqualValues(10, ft.OptionalUint64) - ft = client.FieldType.UpdateOne(ft).AddOptionalUint64(20).SetOptionalUint64(5).SaveX(ctx) - require.EqualValues(5, ft.OptionalUint64) - ft = client.FieldType.UpdateOne(ft).AddOptionalUint64(-5).SaveX(ctx) - require.Zero(ft.OptionalUint64) + // YDB: Add operations on unsigned fields use int64 values, but YDB requires exact type match. + if client.Dialect() != dialect.YDB { + ft = client.FieldType.UpdateOne(ft).AddOptionalUint64(10).SaveX(ctx) + require.EqualValues(10, ft.OptionalUint64) + ft = client.FieldType.UpdateOne(ft).AddOptionalUint64(20).SetOptionalUint64(5).SaveX(ctx) + require.EqualValues(5, ft.OptionalUint64) + ft = client.FieldType.UpdateOne(ft).AddOptionalUint64(-5).SaveX(ctx) + require.Zero(ft.OptionalUint64) + } err = client.FieldType.Create(). SetInt(1). @@ -136,6 +140,7 @@ func Types(t *testing.T, client *ent.Client) { SetRawData(make([]byte, 40)). Exec(ctx) require.Error(err, "MaxLen validator should reject this operation") + err = client.FieldType.Create(). SetInt(1). SetInt8(8). @@ -145,7 +150,8 @@ func Types(t *testing.T, client *ent.Client) { SetRawData(make([]byte, 2)). Exec(ctx) require.Error(err, "MinLen validator should reject this operation") - ft = ft.Update(). + + ftUpdate := ft.Update(). SetInt(1). SetInt8(math.MaxInt8). SetInt16(math.MaxInt16). @@ -173,9 +179,13 @@ func Types(t *testing.T, client *ent.Client) { SetMAC(schema.MAC{HardwareAddr: mac}). SetPair(schema.Pair{K: []byte("K1"), V: []byte("V1")}). SetNilPair(&schema.Pair{K: []byte("K1"), V: []byte("V1")}). - SetStringArray([]string{"qux"}). - AddBigInt(bigint). - SaveX(ctx) + SetStringArray([]string{"qux"}) + + // YDB: big_int is stored as Utf8, so Add operation doesn't work + if client.Dialect() != dialect.YDB { + ftUpdate.AddBigInt(bigint) + } + ft = ftUpdate.SaveX(ctx) require.Equal(int8(math.MaxInt8), ft.OptionalInt8) require.Equal(int16(math.MaxInt16), ft.OptionalInt16) @@ -204,7 +214,11 @@ func Types(t *testing.T, client *ent.Client) { require.EqualValues([]string{"qux"}, ft.StringArray) require.Nil(ft.NillableUUID) require.Equal(uuid.UUID{}, ft.OptionalUUID) - require.Equal("2000", ft.BigInt.String()) + + if client.Dialect() != dialect.YDB { + require.Equal("2000", ft.BigInt.String()) + } + require.EqualValues(100, ft.Int64, "UpdateDefault sets the value to 100") require.EqualValues(100, ft.Duration, "UpdateDefault sets the value to 100ns") require.False(ft.DeletedAt.Time.IsZero()) @@ -215,6 +229,7 @@ func Types(t *testing.T, client *ent.Client) { client.Task.Create().SetPriority(task.PriorityHigh), ).Exec(ctx) require.NoError(err) + err = client.Task.Create().SetPriority(task.Priority(10)).Exec(ctx) require.Error(err) diff --git a/examples/go.mod b/examples/go.mod index 88748103b7..13d183ff1c 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -50,4 +50,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -replace ariga.io/atlas => github.com/LostImagin4tion/atlas v0.0.18 +replace ariga.io/atlas => github.com/LostImagin4tion/atlas v0.0.33 diff --git a/examples/go.sum b/examples/go.sum index 3f5bdb7f00..e84eec833f 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -449,8 +449,8 @@ github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= github.com/GoogleCloudPlatform/cloudsql-proxy v1.33.1/go.mod h1:n3KDPrdaY2p9Nr0B1allAdjYArwIpXQcitNbsS/Qiok= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= -github.com/LostImagin4tion/atlas v0.0.18 h1:RrLLU6zEXuRZAih3slblKFZ/lPLUDZ+wrHaTrxILqrA= -github.com/LostImagin4tion/atlas v0.0.18/go.mod h1:Rco1malutATQGeWEoYFzurfzIvs+galayoZ0+Pz4als= +github.com/LostImagin4tion/atlas v0.0.33 h1:RgcQhGG0MZDwheuFRiZu47ihFRDhtcYmbAT6KU3J3v0= +github.com/LostImagin4tion/atlas v0.0.33/go.mod h1:FtOd0Ry45l3FeDVGVm8tf2SFWg3vHDztylE0eE3EWQ8= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/Microsoft/go-winio v0.4.11/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA= github.com/Microsoft/go-winio v0.4.14/go.mod h1:qXqCSQ3Xa7+6tgxaGTIe4Kpcdsi+P8jBhyzoq1bpyYA= diff --git a/go.mod b/go.mod index b5843804c1..c3d32e35ae 100644 --- a/go.mod +++ b/go.mod @@ -64,4 +64,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -replace ariga.io/atlas => github.com/LostImagin4tion/atlas v0.0.18 +replace ariga.io/atlas => github.com/LostImagin4tion/atlas v0.0.33 diff --git a/go.sum b/go.sum index 6a778248bb..5f73949be5 100644 --- a/go.sum +++ b/go.sum @@ -3,8 +3,8 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= -github.com/LostImagin4tion/atlas v0.0.18 h1:RrLLU6zEXuRZAih3slblKFZ/lPLUDZ+wrHaTrxILqrA= -github.com/LostImagin4tion/atlas v0.0.18/go.mod h1:Rco1malutATQGeWEoYFzurfzIvs+galayoZ0+Pz4als= +github.com/LostImagin4tion/atlas v0.0.33 h1:RgcQhGG0MZDwheuFRiZu47ihFRDhtcYmbAT6KU3J3v0= +github.com/LostImagin4tion/atlas v0.0.33/go.mod h1:FtOd0Ry45l3FeDVGVm8tf2SFWg3vHDztylE0eE3EWQ8= github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo= github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=