diff --git a/builder.go b/builder.go index faf5276..a26587f 100644 --- a/builder.go +++ b/builder.go @@ -6,11 +6,13 @@ import ( "io" "slices" + "github.com/gomlx/stablehlo/internal/utils" "github.com/gomlx/stablehlo/types" + "github.com/gomlx/stablehlo/types/shardy" "github.com/pkg/errors" ) -// Builder is used to construct a StableHLO program. +// Builder is used to construct a StableHLO program (or "Module") // See details in New. type Builder struct { name string @@ -22,38 +24,20 @@ type Builder struct { // inlineUniqueID is a counter used to generate unique names for inlined functions values. inlineUniqueID int - // NumReplicas is the number of replicas for data parallelism. - NumReplicas int - // NumPartitions is the number of partitions for model parallelism. - NumPartitions int + // meshes used for Shardy. + meshes []*shardy.DeviceMesh + + // numReplicas is the number of replicas for data parallelism. + numReplicas int + + // numPartitions is the number of partitions for model parallelism. + numPartitions int // nextChannelID is the next ID to be assigned in channel handles. // It is just a Unique ID. nextChannelID int } -// NormalizeIdentifier converts the name of an identifier (function name or function input parameter -// name) to a valid one: only letters, digits and underscores are allowed. -// -// Invalid characters are replaced with underscores. -// If the name starts with a digit, it is prefixed with an underscore. -// -// The name is normalized in place. -func NormalizeIdentifier(name string) string { - result := make([]rune, 0, len(name)+1) - if name[0] >= '0' && name[0] <= '9' { - result = append(result, '_') - } - for _, r := range name { - if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' { - result = append(result, r) - } else { - result = append(result, '_') - } - } - return string(result) -} - // New creates a new Builder object holding a computation graph in construction. 
// // From a builder you can create functions. @@ -62,7 +46,7 @@ func NormalizeIdentifier(name string) string { // You have to define the "main" function for your StableHLO program: you can use Builder.Main to do so, or // Builder.NewFunction("main",...), it's the same. // -// Once you are all set, call Builder.Build and it will return the StableHLO program as a []byte that can +// Once you are all set, call Builder.Build and it will return the StableHLO program (or "Module") as a []byte that can // be used with PJRT. // // See github.com/gomlx/gopjrt for a Go API to PJRT. @@ -72,20 +56,6 @@ func New(name string) *Builder { } } -// WithNumReplicas sets the number of replicas (for data parallelism). -// This is added as an attribute to the StableHLO module. -func (b *Builder) WithNumReplicas(n int) *Builder { - b.NumReplicas = n - return b -} - -// WithNumPartitions sets the number of partitions (for model parallelism). -// This is added as an attribute to the StableHLO module. -func (b *Builder) WithNumPartitions(n int) *Builder { - b.NumPartitions = n - return b -} - // elementWriter represents elements of ToStableHLO that know how to write themselves. type elementWriter interface { Write(w io.Writer, indentation string) error @@ -134,11 +104,11 @@ const IndentationStep = " " // getModuleAttributes returns the attributes for the StableHLO module (StableHLO code) generated. 
func (b *Builder) getModuleAttributes() []string { var attributes []string - if b.NumReplicas > 0 { - attributes = append(attributes, fmt.Sprintf("stablehlo.num_replicas = %d", b.NumReplicas)) + if b.numReplicas > 0 { + attributes = append(attributes, fmt.Sprintf("stablehlo.num_replicas = %d", b.numReplicas)) } - if b.NumPartitions > 0 { - attributes = append(attributes, fmt.Sprintf(" stablehlo.num_partitions = %d", b.NumPartitions)) + if b.numPartitions > 0 { + attributes = append(attributes, fmt.Sprintf(" stablehlo.num_partitions = %d", b.numPartitions)) } return attributes } @@ -177,10 +147,22 @@ func (b *Builder) Write(writer io.Writer) error { } w("%s", attr) } - w(" }") + w("}") } w(" {\n") + // Write Shardy meshes if needed: + if len(b.meshes) > 0 { + namesUsed := utils.MakeSet[string](len(b.meshes)) + for _, mesh := range b.meshes { + if namesUsed.Has(mesh.Name()) { + return errors.Errorf("duplicate mesh name %q", mesh.Name()) + } + namesUsed.Insert(mesh.Name()) + w("%s%s\n", IndentationStep, mesh.ToStableHLO()) + } + } + // Write non-inline functions: var count int for _, fn := range b.functions { @@ -248,3 +230,72 @@ func (b *Builder) getChannelHandle(config *types.CollectiveConfig) literalStr { return literalStrF("#stablehlo.channel_handle", id, typ) } + +// WithNumReplicas sets the number of replicas (for data parallelism). +// This is added as an attribute to the StableHLO module. +// +// Consider using WithShardy for distributed computation instead: other forms of distributed +// (collective) computation across devices are not tested and may not work. +func (b *Builder) WithNumReplicas(n int) *Builder { + b.numReplicas = n + return b +} + +// WithNumPartitions sets the number of partitions (for model parallelism). +// This is added as an attribute to the StableHLO module. +// +// Consider using WithShardy for distributed computation instead: other forms of distributed +// (collective) computation across devices are not tested and may not work. 
+func (b *Builder) WithNumPartitions(n int) *Builder { + b.numPartitions = n + return b +} + +// WithShardy enables distributed computation across the devices selected by the given meshes. +// +// This is the recommended way to do distributed (across devices) computation, and given the inputs +// with sharded information, Shardy will automatically distribute the computation, without you needing +// to specify any of the collective operations. +// +// Usually, there is only one mesh. But one can split the devices in different meshes. The meshes overlap +// the concrete devices used. +// +// See details of XLA Shardy in [1] +// +// [1] https://github.com/openxla/shardy +func (b *Builder) WithShardy(meshes ...*shardy.DeviceMesh) *Builder { + b.meshes = meshes + b.WithNumReplicas(1) + numDevices := 0 + for _, mesh := range meshes { + numDevices = max(numDevices, mesh.NumDevices()) + } + b.WithNumPartitions(numDevices) + return b +} + +// Meshes returns the meshes configured with WithShardy. +func (b *Builder) Meshes() []*shardy.DeviceMesh { + return b.meshes +} + +// NewShardingSpec creates a new ShardingSpec using the first mesh configured with WithShardy. +// It returns nil if no mesh was configured. +// +// This is a shortcut to NewShardingSpecByMeshIx(0). +func (b *Builder) NewShardingSpec() *shardy.ShardingSpec { + if len(b.meshes) == 0 { + return nil + } + return shardy.NewShardingSpec(b.meshes[0]) +} + +// NewShardingSpecByMeshIx creates a new ShardingSpec for the meshIdx (the order given by WithShardy). +// +// It may return nil if meshIdx is out of range. 
+func (b *Builder) NewShardingSpecByMeshIx(meshIdx int) *shardy.ShardingSpec { + if meshIdx < 0 || meshIdx >= len(b.meshes) { + return nil + } + return shardy.NewShardingSpec(b.meshes[meshIdx]) +} diff --git a/collective.go b/collective.go index 93a23c2..6f8dc3d 100644 --- a/collective.go +++ b/collective.go @@ -49,7 +49,10 @@ func formatReplicaGroups(groups [][]int) literalStr { // notice it's not the device numbers by the replica numbers (there is an indirection). // Except if the config sets UseGlobalDeviceIDs, in which case they are interpreted as device // numbers. E.g., `[[0, 1, 2, 3]]`. -// - config: Optional configuration of the channels to be used. This is not needed for SPMD programs. +// - config: Optional configuration of the channels to be used. This shouldn't be used for SPMD programs. +// +// Consider using Builder.WithShardy for distributed computation instead: other forms of distributed +// (collective) computation across devices are not tested and may not work. func CollectiveBroadcast(operand *Value, replicaGroups [][]int, config ...*types.CollectiveConfig) (*Value, error) { op := optypes.CollectiveBroadcast fn := operand.fn @@ -77,7 +80,9 @@ func CollectiveBroadcast(operand *Value, replicaGroups [][]int, config ...*types stmt := fn.addOp(op, outputShape, operand) stmt.Attributes = map[string]any{ "replica_groups": formatReplicaGroups(replicaGroups), - "channel_handle": fn.Builder.getChannelHandle(cfg), + } + if cfg != nil { + stmt.Attributes["channel_handle"] = fn.Builder.getChannelHandle(cfg) } return stmt.Outputs[0], nil } @@ -95,6 +100,9 @@ func CollectiveBroadcast(operand *Value, replicaGroups [][]int, config ...*types // Except if the config sets UseGlobalDeviceIDs, in which case they are interpreted as device // numbers. E.g., `[[0, 1, 2, 3]]`. // - config: Optional configuration of the channels to be used. This is not needed for SPMD programs. 
+// +// Consider using Builder.WithShardy for distributed computation instead: other forms of distributed +// (collective) computation across devices are not tested and may not work. func AllReduce(operands []*Value, replicaGroups [][]int, computation *Function, config ...*types.CollectiveConfig) ( []*Value, error) { op := optypes.AllReduce @@ -122,7 +130,7 @@ func AllReduce(operands []*Value, replicaGroups [][]int, computation *Function, outputShapes, err := shapeinference.AllReduce( valuesToShapes(operands), valuesToShapes(computation.Inputs), - computation.Outputs, + valuesToShapes(computation.Outputs), replicaGroups) if err != nil { return nil, err @@ -138,7 +146,9 @@ func AllReduce(operands []*Value, replicaGroups [][]int, computation *Function, stmt := fn.addMultiOp(op, outputShapes, operands) stmt.Attributes = map[string]any{ "replica_groups": formatReplicaGroups(replicaGroups), - "channel_handle": fn.Builder.getChannelHandle(cfg), + } + if cfg != nil { + stmt.Attributes["channel_handle"] = fn.Builder.getChannelHandle(cfg) } if cfg != nil && cfg.UseGlobalDeviceIDs { stmt.Attributes["use_global_device_ids"] = true @@ -153,6 +163,9 @@ func AllReduce(operands []*Value, replicaGroups [][]int, computation *Function, // - replicaGroups: A 2D array defining the communicating device groups. // - allGatherDim: The dimension along which to concatenate the operands. // - config: Optional configuration of the channels to be used. +// +// Consider using Builder.WithShardy for distributed computation instead: other forms of distributed +// (collective) computation across devices are not tested and may not work. func AllGather(operand *Value, replicaGroups [][]int, allGatherDim int, config ...*types.CollectiveConfig) (*Value, error) { op := optypes.AllGather fn := operand.fn @@ -176,7 +189,9 @@ func AllGather(operand *Value, replicaGroups [][]int, allGatherDim int, config . 
stmt.Attributes = map[string]any{ "replica_groups": formatReplicaGroups(replicaGroups), "all_gather_dim": int64(allGatherDim), - "channel_handle": fn.Builder.getChannelHandle(cfg), + } + if cfg != nil { + stmt.Attributes["channel_handle"] = fn.Builder.getChannelHandle(cfg) } if cfg != nil && cfg.UseGlobalDeviceIDs { stmt.Attributes["use_global_device_ids"] = true @@ -193,6 +208,9 @@ func AllGather(operand *Value, replicaGroups [][]int, allGatherDim int, config . // - concatDimension: The dimension along which to concatenate the received chunks. // - splitCount: The number of chunks to split the operand into. This must match the size of the replica groups. // - config: Optional configuration of the channels to be used. +// +// Consider using Builder.WithShardy for distributed computation instead: other forms of distributed +// (collective) computation across devices are not tested and may not work. func AllToAll(operand *Value, replicaGroups [][]int, splitDimension, concatDimension, splitCount int, config ...*types.CollectiveConfig) (*Value, error) { op := optypes.AllToAll fn := operand.fn @@ -218,7 +236,9 @@ func AllToAll(operand *Value, replicaGroups [][]int, splitDimension, concatDimen "split_dimension": int64(splitDimension), "concat_dimension": int64(concatDimension), "split_count": int64(splitCount), - "channel_handle": fn.Builder.getChannelHandle(cfg), + } + if cfg != nil { + stmt.Attributes["channel_handle"] = fn.Builder.getChannelHandle(cfg) } if cfg != nil && cfg.UseGlobalDeviceIDs { stmt.Attributes["use_global_device_ids"] = true @@ -251,6 +271,9 @@ func formatSourceTargetPairs(pairs [][2]int) literalStr { // - operand: The tensor from the *local* replica. // - sourceTargetPairs: A 2D array where each inner array is a `[source, target]` pair of replica IDs. // - config: Optional configuration of the channels to be used. 
+// +// Consider using Builder.WithShardy for distributed computation instead: other forms of distributed +// (collective) computation across devices are not tested and may not work. func CollectivePermute(operand *Value, sourceTargetPairs [][2]int, config ...*types.CollectiveConfig) (*Value, error) { op := optypes.CollectivePermute fn := operand.fn @@ -273,7 +296,12 @@ func CollectivePermute(operand *Value, sourceTargetPairs [][2]int, config ...*ty stmt := fn.addOp(op, outputShape, operand) stmt.Attributes = map[string]any{ "source_target_pairs": formatSourceTargetPairs(sourceTargetPairs), - "channel_handle": fn.Builder.getChannelHandle(cfg), + } + if cfg != nil { + stmt.Attributes["channel_handle"] = fn.Builder.getChannelHandle(cfg) + } + if cfg != nil && cfg.UseGlobalDeviceIDs { + stmt.Attributes["use_global_device_ids"] = true } return stmt.Outputs[0], nil } diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 044eda8..82b9ff4 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,7 +1,15 @@ -# v0.2.0: ... +# v0.2.0: (Release Candidate) Adding support for XLA Shardy - `Function.Input` and `Function.NamedInput`: (change in API) they now may return an error, if the name is duplicate. -- `AllReduce` now supports arbitrary number of inputs, to be reduced at once. +- Collective ops: + - `AllReduce` now supports arbitrary number of inputs, to be reduced at once. + - `channel_handle` attribute is not generated by default: it is not compatible with SPMD, and probably will + only be useful once `Send`/`Recv` are supported. +- Added XLA Shardy support: + - Added `shardy.DeviceMesh` and `shardy.ShardingSpec` types. 
+ - Added `Builder.WithShardy(mesh)` + - Added `Function.NamedInputWithShardingAndAttributes()` + - Added `Function.ReturnWithShardingAndAttributes()` # v0.1.0: 2025/11/06 Multi-Device support diff --git a/function.go b/function.go index 19aed57..02e138d 100644 --- a/function.go +++ b/function.go @@ -4,12 +4,15 @@ import ( "fmt" "io" "reflect" + "slices" "strconv" + "strings" "github.com/gomlx/gopjrt/dtypes" "github.com/gomlx/stablehlo/internal/optypes" "github.com/gomlx/stablehlo/shapeinference" "github.com/gomlx/stablehlo/types/shapes" + "github.com/gomlx/stablehlo/types/shardy" "github.com/pkg/errors" ) @@ -23,8 +26,8 @@ type Function struct { // Inputs to the function. Inputs []*Value - // Outputs types of the function. - Outputs []shapes.Shape + // Outputs of the function. + Outputs []*Value // Statements in the function body. Statements []*Statement @@ -82,10 +85,28 @@ func (fn *Function) newValue(shape shapes.Shape) (v *Value) { // It picks a default unique name for the input parameter, you can also // provide a name with NamedInput. func (fn *Function) Input(shape shapes.Shape) (*Value, error) { + return fn.InputWithShardingAndAttributes(shape, nil, nil) +} + +// InputWithSharding creates a new input with the given sharding specification. +func (fn *Function) InputWithSharding(shape shapes.Shape, shardingSpec *shardy.ShardingSpec) (*Value, error) { + return fn.InputWithShardingAndAttributes(shape, shardingSpec, nil) +} + +// InputWithAttributes creates a new input with the given attributes. +func (fn *Function) InputWithAttributes(shape shapes.Shape, attributes map[string]any) (*Value, error) { + return fn.InputWithShardingAndAttributes(shape, nil, attributes) +} + +// InputWithShardingAndAttributes creates a new input with the given sharding specification and attributes. 
+func (fn *Function) InputWithShardingAndAttributes(shape shapes.Shape, shardingSpec *shardy.ShardingSpec, attributes map[string]any) (*Value, error) { rootFn := fn.findRootFn() - value, err := fn.NamedInput(fmt.Sprintf("arg%d", rootFn.nextArgID), shape) + value, err := fn.NamedInputWithShardingAndAttributes(fmt.Sprintf("arg%d", rootFn.nextArgID), shape, shardingSpec, attributes) + if err != nil { + return nil, err + } rootFn.nextArgID++ - return value, err + return value, nil } // NamedInput creates a new input parameter for a function with the given name -- it @@ -98,16 +119,62 @@ func (fn *Function) Input(shape shapes.Shape) (*Value, error) { // Names are used in the StableHLO code and may be helpful for debugging, but // otherwise have no impact. func (fn *Function) NamedInput(name string, shape shapes.Shape) (*Value, error) { + return fn.NamedInputWithShardingAndAttributes(name, shape, nil, nil) +} + +// NamedInputWithSharding creates a new input parameter for a function with the given name -- it +// must be a unique input name -- and sharding specification for distributed computation. +func (fn *Function) NamedInputWithSharding(name string, shape shapes.Shape, + shardingSpec *shardy.ShardingSpec) (*Value, error) { + return fn.NamedInputWithShardingAndAttributes(name, shape, shardingSpec, nil) +} + +// NamedInputWithAttributes creates a new input parameter for a function with the given name and attributes. +func (fn *Function) NamedInputWithAttributes(name string, shape shapes.Shape, + attributes map[string]any) (*Value, error) { + return fn.NamedInputWithShardingAndAttributes(name, shape, nil, attributes) +} + +// NamedInputWithShardingAndAttributes creates a new input parameter for a function with the given name -- it +// must be a unique input name -- and sharding specification for distributed computation. +// +// The shardingSpec can be nil: the default is a replicated input across all devices. 
+// +// The name is passed through ConvertToValidName, which converts any character that is not an ASCII letter, digit, or underscore to an underscore. +// +// Names with the format "%d" and "arg%d" are reserved for the default input parameters. +// +// Names are used in the StableHLO code and may be helpful for debugging, but otherwise have no impact. +func (fn *Function) NamedInputWithShardingAndAttributes(name string, shape shapes.Shape, + shardingSpec *shardy.ShardingSpec, attributes map[string]any) (*Value, error) { value := &Value{ - fn: fn, - name: ConvertToValidName(name), - shape: shape, + fn: fn, + name: ConvertToValidName(name), + shape: shape, + Attributes: attributes, } for i, input := range fn.Inputs { if input.name == value.name { return nil, errors.Errorf("duplicate input name %q with input #%d", value.name, i) } } + if shardingSpec != nil { + if value.Attributes == nil { + value.Attributes = make(map[string]any) + } + value.Attributes["sdy.sharding"] = literalStr(shardingSpec.ToValueAttribute(value.shape)) + if slices.Index(fn.Builder.meshes, shardingSpec.Mesh) == -1 { + meshesNames := make([]string, len(fn.Builder.meshes)) + for _, mesh := range fn.Builder.meshes { + meshesNames = append(meshesNames, mesh.Name()) + } + return nil, errors.Errorf("sharding spec meshe %q doesn't match any of the stablehlo.Builder meshes (%s)", + shardingSpec.Mesh, strings.Join(meshesNames, ", ")) + } + if err := shardingSpec.ValidateShape(shape); err != nil { + return nil, err + } + } fn.Inputs = append(fn.Inputs, value) return value, nil } @@ -180,28 +247,79 @@ func (fn *Function) ConstantFromFlatAndDimensions(flat any, dimensions ...int) ( // // There can be only one return statement from a Function, and it must be the last // operation of a function. -func (fn *Function) Return(firstValue *Value, otherValues ...*Value) error { +// +// If you are doing distributed computation, you can use ReturnWithShardingAndAttributes to specify +// the sharding requirements for each of the return values. 
+func (fn *Function) Return(values ...*Value) error { + return fn.ReturnWithAttributes(values, nil) +} + +// ReturnWithShardingAndAttributes is a convenience function to call ReturnWithAttributes with the given sharding +// specifications. +// +// The shardingSpecs slice of ShardingSpecs must have the same length as the values slice. +// Each ShardingSpec can be nil, in which case the default sharding is replicated across all devices. +// If shardingSpecs is nil, this behaves just like ReturnWithAttributes. +// +// The attributes slice of maps can be set to nil if there are no attributes. +func (fn *Function) ReturnWithShardingAndAttributes(values []*Value, shardingSpecs []*shardy.ShardingSpec, + attributes []map[string]any) error { + if len(shardingSpecs) == 0 { + return fn.ReturnWithAttributes(values, attributes) + } + if len(values) != len(shardingSpecs) { + return errors.Errorf("Function.ReturnWithShardingAndAttributes requires the same number of values and sharding specs, got %d and %d", len(values), len(shardingSpecs)) + } + if len(attributes) == 0 { + attributes = make([]map[string]any, len(values)) + } + for i, shardingSpec := range shardingSpecs { + if shardingSpec != nil { + specLiteral := literalStr(shardingSpec.ToValueAttribute(values[i].shape)) + if attributes[i] == nil { + attributes[i] = map[string]any{"sdy.sharding": specLiteral} + } else { + attributes[i]["sdy.sharding"] = specLiteral + } + } + } + return fn.ReturnWithAttributes(values, attributes) +} + +// ReturnWithAttributes adds a return statement to the function with the given return values and attributes. 
+func (fn *Function) ReturnWithAttributes(values []*Value, attributes []map[string]any) error { if fn.Returned { return errors.Errorf("Function.Return already called for %q", fn.Name) } + if len(values) == 0 { + return errors.New("Function.Return requires at least one return value") + } + if len(attributes) > 0 && len(values) != len(attributes) { + return errors.Errorf( + "if attributes is defined (!=nil) Function.ReturnWithAttributes requires the same number of "+ + "values and attributes, got %d and %d", len(values), len(attributes)) + } fn.Returned = true - allValues := make([]*Value, 1, len(otherValues)+1) - allValues[0] = firstValue - allValues = append(allValues, otherValues...) - outputShapes := make([]shapes.Shape, len(allValues)) - for i, value := range allValues { + outputValues := make([]*Value, len(values)) + for i, value := range values { if value.fn != fn { return errors.New("Function.Return given values that are not owned by the function") } - outputShapes[i] = value.shape + outputValues[i] = &Value{ + fn: fn, + name: value.name, + shape: value.shape, + } + if len(attributes) > 0 { + outputValues[i].Attributes = attributes[i] + } } - fn.Outputs = outputShapes - + fn.Outputs = outputValues stmt := &Statement{ Builder: fn.Builder, Function: fn, OpType: optypes.FuncReturn, - Inputs: allValues, + Inputs: values, } fn.Statements = append(fn.Statements, stmt) return nil @@ -275,22 +393,25 @@ func (fn *Function) Write(writer io.Writer, indentation string) error { } we(input, nextIndent) w(": %s", input.shape.ToStableHLO()) + writeAttributes(writer, indentation, input.Attributes, w) } if isClosure { w(") :\n") } else if normalFunction { w(") -> ") - if len(fn.Outputs) > 1 { + encloseOutputInParenthesis := len(fn.Outputs) > 1 || (len(fn.Outputs) == 1 && len(fn.Outputs[0].Attributes) > 0) + if encloseOutputInParenthesis { w("(") } for i, output := range fn.Outputs { if i > 0 { w(", ") } - w("%s", output.ToStableHLO()) + w(output.shape.ToStableHLO()) + 
writeAttributes(writer, indentation, output.Attributes, w) } - if len(fn.Outputs) > 1 { + if encloseOutputInParenthesis { w(")") } w(" {\n") diff --git a/go.mod b/go.mod index 0f56f5e..c6a2de7 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.24.0 toolchain go1.24.6 require ( - github.com/gomlx/gopjrt v0.9.2-0.20251113071311-1488e6396f1b + github.com/gomlx/gopjrt v0.10.0-rc0 github.com/janpfeifer/must v0.2.0 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.11.1 diff --git a/go.sum b/go.sum index 3da24ba..806970d 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/gomlx/gopjrt v0.9.2-0.20251113071311-1488e6396f1b h1:2jh5hJmb6YzGWL+rKUDYUpMdcmeeA+jDkKpfmevxl8g= github.com/gomlx/gopjrt v0.9.2-0.20251113071311-1488e6396f1b/go.mod h1:c8UENVGnxIDdihEL5HinlAdgR7RxTbEPLBppiMQF1ew= +github.com/gomlx/gopjrt v0.10.0-rc0 h1:EpF+JJYl3AUvU5ToKSfsuFnSPBxkPjbor93Ziak7OGA= +github.com/gomlx/gopjrt v0.10.0-rc0/go.mod h1:c8UENVGnxIDdihEL5HinlAdgR7RxTbEPLBppiMQF1ew= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/janpfeifer/go-benchmarks v0.1.1 h1:gLLy07/JrOKSnMWeUxSnjTdhkglgmrNR2IBDnR4kRqw= diff --git a/internal/utils/utils.go b/internal/utils/utils.go new file mode 100644 index 0000000..c56831d --- /dev/null +++ b/internal/utils/utils.go @@ -0,0 +1,24 @@ +package utils + +// NormalizeIdentifier converts the name of an identifier (function name or function input parameter +// name) to a valid one: only letters, digits, and underscores are allowed. +// +// Invalid characters are replaced with underscores. +// If the name starts with a digit, it is prefixed with an underscore. 
+func NormalizeIdentifier(name string) string { + if name == "" { + return "" + } + result := make([]rune, 0, len(name)+1) + if name[0] >= '0' && name[0] <= '9' { + result = append(result, '_') + } + for _, r := range name { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' { + result = append(result, r) + } else { + result = append(result, '_') + } + } + return string(result) +} diff --git a/ops.go b/ops.go index a9066af..2021c92 100644 --- a/ops.go +++ b/ops.go @@ -216,6 +216,31 @@ type DotGeneralBuilder struct { algorithm *types.DotGeneralAlgorithm } +// Dot performs a standard matrix multiplication ("dot product") of the two rank-2 +// tensors lhs and rhs. It is a convenience shortcut to DotGeneral configured with: +//   - No batch axes: the output axes are the combination of the non-contracted +//     axes of lhs and rhs. +//   - Contracting axes: axis 1 of lhs is contracted with axis 0 of rhs, so their +//     dimensions must match. +//   - The remaining axes (axis 0 of lhs and axis 1 of rhs) are crossed. +// +// Hence, for lhs shaped [m, k] and rhs shaped [k, n], the result is shaped [m, n], +// each entry being the sum over the contracted axis of the products of the +// corresponding lhs row and rhs column. +// +// It returns an error if either input is not rank-2. +// +// See DotGeneral for the general form, with batch and arbitrary contracting axes. 
+func Dot(lhs, rhs *Value) (*Value, error) { + if lhs.Shape().Rank() != 2 || rhs.Shape().Rank() != 2 { + return nil, errors.Errorf("Dot only supports rank-2 tensors, got %d and %d", lhs.Shape().Rank(), rhs.Shape().Rank()) + } + return DotGeneral( + lhs, []int{1}, nil, + rhs, []int{0}, nil, + ).Done() +} + // DotGeneral takes as input lhs (left-hand-side) and rhs (right-hand-side) specifications // for a general vector product -- a generalized "Einsum". Each axis can be: // - Just aligned (batch axes), so the output has the same axes as the inputs. The dimensions @@ -627,7 +652,7 @@ func MultiReduce(inputs, initialValues []*Value, reductionFn *Function, axes ... outputsShapes, err := shapeinference.Reduce( valuesToShapes(inputs), valuesToShapes(initialValues), - valuesToShapes(reductionFn.Inputs), reductionFn.Outputs, + valuesToShapes(reductionFn.Inputs), valuesToShapes(reductionFn.Outputs), axes) if err != nil { return nil, err @@ -845,7 +870,7 @@ func MultiScatter(inputs []*Value, scatterIndices *Value, updates []*Value, updateWindowAxes, insertedWindowAxes, inputBatchingAxes, scatterIndicesBatchingAxes, indexedInputAxes, indexVectorAxis, - updateComputationInputShapes, updateComputationFn.Outputs) + updateComputationInputShapes, valuesToShapes(updateComputationFn.Outputs)) if err != nil { return nil, err } @@ -1241,7 +1266,7 @@ func MultiReduceWindow(inputs, initialValues []*Value, reductionFn *Function, outputsShapes, err := shapeinference.ReduceWindow( valuesToShapes(inputs), valuesToShapes(initialValues), - valuesToShapes(reductionFn.Inputs), reductionFn.Outputs, + valuesToShapes(reductionFn.Inputs), valuesToShapes(reductionFn.Outputs), windowDimensions, strides, inputDilations, windowDilations, paddings) if err != nil { diff --git a/stablehlo.go b/stablehlo.go index 2e47c0b..33f6d2f 100644 --- a/stablehlo.go +++ b/stablehlo.go @@ -13,5 +13,16 @@ // See ToStableHLO documentation and specifications in https://openxla.org/stablehlo/spec package stablehlo +import 
"github.com/gomlx/stablehlo/internal/utils" + // Generates some trivial functions (binary and unary operators) automatically. //go:generate go run ./internal/cmd/ops_generator + +// NormalizeIdentifier converts the name of an identifier (function name or function input parameter +// name, etc.) to a valid one: only letters, digits, and underscores are allowed. +// +// Invalid characters are replaced with underscores. +// If the name starts with a digit, it is prefixed with an underscore. +func NormalizeIdentifier(name string) string { + return utils.NormalizeIdentifier(name) +} diff --git a/stablehlo_test.go b/stablehlo_test.go index 375a1bb..6ed5fc6 100644 --- a/stablehlo_test.go +++ b/stablehlo_test.go @@ -6,6 +6,7 @@ import ( "github.com/gomlx/gopjrt/dtypes" "github.com/gomlx/stablehlo/types/shapes" + "github.com/gomlx/stablehlo/types/shardy" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -42,6 +43,64 @@ func TestBuilder(t *testing.T) { } }) + t.Run("Sharding", func(t *testing.T) { + b := New(t.Name()) + mesh, err := shardy.NewDeviceMesh("mesh", []int{4, 2}, []string{"data", "model"}) + require.NoError(t, err) + err = mesh.SetLogicalDeviceAssignment(7, 6, 5, 4, 3, 2, 1, 0) + require.NoError(t, err) + b.WithShardy(mesh) + fn := b.Main() + + arg0 := must(fn.NamedInputWithShardingAndAttributes( + "arg0", + shapes.Make(dtypes.F32, 16, 128), + b.NewShardingSpec().AddShardedAxis("data"), + nil, + )) + arg1 := must(fn.NamedInputWithSharding( + "arg1", + shapes.Make(dtypes.F32, 128, 256), + b.NewShardingSpec().AddShardedAxis("model"), + )) + + tanh := must(Tanh(arg0)) + dot := must(Dot(tanh, arg1)) + err = fn.ReturnWithShardingAndAttributes( + []*Value{dot}, + []*shardy.ShardingSpec{ + b.NewShardingSpec().AddShardedAxis("data"), + }, + []map[string]any{ + {"jax.result_info": "result"}, + }) + require.NoError(t, err) + + program := string(must(b.Build())) + fmt.Printf("%s program:\n%s", t.Name(), program) + want := `module 
@TestBuilder_Sharding attributes {stablehlo.num_replicas = 1, stablehlo.num_partitions = 8} { + sdy.mesh @mesh = <["data"=4, "model"=2], device_ids=[7, 6, 5, 4, 3, 2, 1, 0]> + func.func @main(%arg0: tensor<16x128xf32> { sdy.sharding = #sdy.sharding<@mesh, [{"data"}, {}]> }, %arg1: tensor<128x256xf32> { sdy.sharding = #sdy.sharding<@mesh, [{"model"}, {}]> }) -> (tensor<16x256xf32> { + jax.result_info = "result", + sdy.sharding = #sdy.sharding<@mesh, [{"data"}, {}]> + }) { + %0 = "stablehlo.tanh"(%arg0) : (tensor<16x128xf32>) -> tensor<16x128xf32> + %1 = "stablehlo.dot_general"(%0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [], + rhs_batching_dimensions = [], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [0] +>, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<16x128xf32>, tensor<128x256xf32>) -> tensor<16x256xf32> + "stablehlo.return"(%1) : (tensor<16x256xf32>) -> () + } +} +` + require.Equal(t, want, program) + }) + t.Run("with inputs", func(t *testing.T) { builder := New(t.Name()) shape := shapes.Make(dtypes.Float64) diff --git a/statement.go b/statement.go index 10573a0..66561fe 100644 --- a/statement.go +++ b/statement.go @@ -3,6 +3,7 @@ package stablehlo import ( "fmt" "io" + "maps" "math" "reflect" "slices" @@ -103,31 +104,7 @@ func (s *Statement) Write(writer io.Writer, indentation string) error { } // Write attributes: - if len(s.Attributes) > 0 { - if len(s.Attributes) == 1 { - for key, value := range s.Attributes { - literalValue := literalToStableHLO(value) - if strings.Index(literalValue, "\n") == -1 { - w(" { %s = %s }", key, literalValue) - } else { - literalValue = strings.ReplaceAll(literalValue, "\n", "\n"+nextIndentation) - w(" {\n%s%s = %s\n }", nextIndentation, key, literalValue) - } - } - } else { - // One attribute per line: - w(" {") - first := true - for key, value := range s.Attributes { - if !first { - w(",") - } - first = false - w("\n%s%s = %s", nextIndentation, key, 
literalToStableHLO(value)) - } - w("\n%s}", indentation) - } - } + writeAttributes(writer, indentation, s.Attributes, w) // Write signature: w(" : (") @@ -160,6 +137,38 @@ func (s *Statement) Write(writer io.Writer, indentation string) error { return err } +// writeAttributes writes a map of attributes to the writer. +// The w function is the one provided by the caller to handle errors. +func writeAttributes(writer io.Writer, indentation string, attributes map[string]any, w func(format string, args ...any)) { + if len(attributes) == 0 { + return + } + nextIndentation := indentation + IndentationStep + if len(attributes) == 1 { + for key, value := range attributes { + literalValue := literalToStableHLO(value) + if strings.Index(literalValue, "\n") == -1 { + w(" { %s = %s }", key, literalValue) + } else { + literalValue = strings.ReplaceAll(literalValue, "\n", "\n"+nextIndentation) + w(" {\n%s%s = %s\n }", nextIndentation, key, literalValue) + } + } + } else { + // One attribute per line: + w(" {") + keys := slices.Collect(maps.Keys(attributes)) + slices.Sort(keys) + for i, key := range keys { + if i > 0 { + w(",") + } + w("\n%s%s = %s", nextIndentation, key, literalToStableHLO(attributes[key])) + } + w("\n%s}", indentation) + } +} + // hasToStableHLO is implemented by types that can be converted to a stablehlo string. type hasToStableHLO interface { ToStableHLO() string diff --git a/tests/gopjrt/shardy_test.go b/tests/gopjrt/shardy_test.go new file mode 100644 index 0000000..fbe941a --- /dev/null +++ b/tests/gopjrt/shardy_test.go @@ -0,0 +1,114 @@ +package gopjrt + +import ( + "fmt" + "testing" + + "github.com/gomlx/gopjrt/dtypes" + "github.com/gomlx/gopjrt/pjrt" + "github.com/gomlx/stablehlo" + "github.com/gomlx/stablehlo/types/shapes" + "github.com/gomlx/stablehlo/types/shardy" + "github.com/stretchr/testify/require" +) + +func TestShardy(t *testing.T) { + iterateClientsAndTest(t, testShardy) +} + +// compileAndExecute program with PJRT. All inputs are donated. 
+func shardyCompileAndExecute(t *testing.T, client *pjrt.Client, program []byte, + deviceAssignment []int, inputs ...*pjrt.Buffer) []*pjrt.Buffer { + loadedExec, err := client.Compile(). + WithStableHLO(program). + WithShardy(len(deviceAssignment)). + WithDeviceAssignment(deviceAssignment). + Done() + require.NoErrorf(t, err, "failed to compile program: \n%s", program) + defer func() { + err := loadedExec.Destroy() + if err != nil { + t.Errorf("failed to destroy loaded exec: %+v", err) + } + }() + outputBuffers, err := loadedExec.Execute(inputs...).DonateAll().Done() + require.NoErrorf(t, err, "failed to execute program: \n%s", program) + return outputBuffers +} + +func testShardy(t *testing.T, client *pjrt.Client) { + // We will test it with 2 devices. + const numReplicas = 2 + numDevices := client.NumDevices() + if numDevices < numReplicas { + t.Skipf("Skipping test: not enough devices: %d < %d", numDevices, numReplicas) + return + } + deviceAssignment := make([]int, numReplicas) + for i := range numReplicas { + deviceAssignment[i] = i + } + + t.Run("input-data-sharding", func(t *testing.T) { + mesh := must1(shardy.NewDeviceMesh("mesh", []int{2}, []string{"data"})) + // Reverse logical device assignment: notice it doesn't affect the order of the inputs. 
+ must(mesh.SetLogicalDeviceAssignment(1, 0)) + builder := stablehlo.New(t.Name()).WithShardy(mesh) + fn := builder.Main() + x := must1(fn.NamedInputWithSharding("arg0", shapes.Make(dtypes.F32, 2, 3), + builder.NewShardingSpec().AddShardedAxis("data"))) + reductionFn := fn.Closure() + lhs := must1(reductionFn.NamedInput("lhs", shapes.Make(dtypes.F32))) + rhs := must1(reductionFn.NamedInput("rhs", shapes.Make(dtypes.F32))) + must(reductionFn.Return(must1(stablehlo.Add(lhs, rhs)))) + zero := must1(fn.ConstantFromScalar(float32(0))) + output := must1(stablehlo.Reduce(x, zero, reductionFn, 0, 1)) + must(fn.Return(output)) + program := must1(builder.Build()) + fmt.Printf("%s program:\n%s", t.Name(), program) + x0 := must1(client.BufferFromHost(). + ToDeviceNum(deviceAssignment[0]). + FromFlatDataWithDimensions([]float32{0, 1, 2}, []int{1, 3}). + Done()) + x1 := must1(client.BufferFromHost(). + ToDeviceNum(deviceAssignment[1]). + FromFlatDataWithDimensions([]float32{0, 0.1, 0.2}, []int{1, 3}). 
+ Done()) + outputs := shardyCompileAndExecute(t, client, program, deviceAssignment, x0, x1) + requireBuffersEqual(t, []FlatAndDims{ + {[]float32{3.3}, nil}, + {[]float32{3.3}, nil}, + }, outputs) + }) + + t.Run("output-data-sharding", func(t *testing.T) { + mesh := must1(shardy.NewDeviceMesh("data_mesh", []int{2}, []string{"data"})) + builder := stablehlo.New(t.Name()).WithShardy(mesh) + fn := builder.NewFunction("main") + x := must1(fn.NamedInputWithSharding("arg0", shapes.Make(dtypes.F32, 2, 3), + builder.NewShardingSpec().AddShardedAxis("data"))) + reductionFn := fn.Closure() + lhs := must1(reductionFn.NamedInput("lhs", shapes.Make(dtypes.F32))) + rhs := must1(reductionFn.NamedInput("rhs", shapes.Make(dtypes.F32))) + must(reductionFn.Return(must1(stablehlo.Add(lhs, rhs)))) + zero := must1(fn.ConstantFromScalar(float32(0))) + output := must1(stablehlo.Reduce(x, zero, reductionFn, 1)) + must(fn.Return(output)) + program := must1(builder.Build()) + fmt.Printf("%s program:\n%s", t.Name(), program) + x0 := must1(client.BufferFromHost(). + ToDeviceNum(deviceAssignment[0]). + FromFlatDataWithDimensions([]float32{0, 1, 2}, []int{1, 3}). + Done()) + x1 := must1(client.BufferFromHost(). + ToDeviceNum(deviceAssignment[1]). + FromFlatDataWithDimensions([]float32{0, 0.1, 0.2}, []int{1, 3}). + Done()) + outputs := shardyCompileAndExecute(t, client, program, deviceAssignment, x0, x1) + requireBuffersEqual(t, []FlatAndDims{ + {[]float32{3}, []int{1}}, + {[]float32{0.3}, []int{1}}, + }, outputs) + }) + +} diff --git a/types/ops.go b/types/ops.go index e21dd4b..021e9f7 100644 --- a/types/ops.go +++ b/types/ops.go @@ -222,15 +222,15 @@ const ( // CollectiveConfig provides advanced, optional configuration for collective operations. // Pass this as the last (optional) argument to collective ops. type CollectiveConfig struct { + // ChannelType specifies the communication dimension. + // Defaults to CrossReplica (0). 
+ ChannelType ChannelType + // ChannelID, if non-nil, forces a specific channel ID (the 'handle'). // If nil, a unique ID will be automatically generated. // This is **required** for MPMD (multi-program, multi-data) to manually link ops across programs. ChannelID *int - // ChannelType specifies the communication dimension. - // Defaults to CrossReplica (0). - ChannelType ChannelType - // UseGlobalDeviceIDs changes the interpretation of replica_groups // from replica IDs to global device IDs. // This only applies to AllReduce, not CollectiveBroadcast. diff --git a/types/shardy/devicemesh.go b/types/shardy/devicemesh.go new file mode 100644 index 0000000..37e08a9 --- /dev/null +++ b/types/shardy/devicemesh.go @@ -0,0 +1,292 @@ +// Package shardy provides the types needed to define a distributed computation topology. +// This is used for leveraging XLA Shardy [1] +// +// [1] https://github.com/openxla/shardy +package shardy + +import ( + "fmt" + "slices" + "strings" + + "github.com/gomlx/stablehlo/internal/utils" + "github.com/pkg/errors" +) + +// DeviceMesh defines the logical topology of a set of devices on a backend. +type DeviceMesh struct { + name string + + // axesNames are the names of the mesh axes. + axesNames []string + + // axesSizes defines the number of devices along each mesh axis. + axesSizes []int + + // nameToAxis maps axis names to their index. + nameToAxis map[string]int + + // numDevices is the total number of devices in the mesh. + numDevices int + + // logicalDeviceAssignment is the list of "logical" devices numbers in the mesh, in the order they appear in the + // mesh. + // These numbers are indices in the LogicalDeviceAssignment that will be used in the compilation of the program. + logicalDeviceAssignment []int +} + +// NewDeviceMesh creates a new logical topology of a set of devices. +// +// - name: the name of the mesh, it must be a valid StableHLO identifier (see stablehlo.NormalizeIdentifier). 
+// - axesSizes: defines the number of devices along each mesh axis, one value per axis.
+// - axesNames: the names of the mesh axes. One value per axis. They must also be valid StableHLO identifiers
+// (see stablehlo.NormalizeIdentifier).
+//
+// The default mapping of logical device numbers to the mesh is sequential, starting from 0, but it can be
+// changed with the DeviceMesh.SetLogicalDeviceAssignment() method.
+func NewDeviceMesh(name string, axesSizes []int, axesNames []string) (*DeviceMesh, error) {
+	if len(axesSizes) != len(axesNames) {
+		return nil, errors.Errorf("axesSizes and axesNames must have the same length, got %d and %d",
+			len(axesSizes), len(axesNames))
+	}
+	if len(axesSizes) == 0 {
+		return nil, errors.New("DeviceMesh axesSizes cannot be empty")
+	}
+
+	// Normalize names:
+	if name != utils.NormalizeIdentifier(name) {
+		return nil, errors.Errorf(
+			"DeviceMesh name %q is not a valid StableHLO identifier, suggestion %q -- or use "+
+				"stablehlo.NormalizeIdentifier()", name, utils.NormalizeIdentifier(name))
+	}
+	axesNames = slices.Clone(axesNames)
+	for i, axisName := range axesNames {
+		if axesNames[i] != utils.NormalizeIdentifier(axisName) {
+			return nil, errors.Errorf(
+				"DeviceMesh axis name %q at index %d is not a valid StableHLO identifier, suggestion %q -- or use "+
+					"stablehlo.NormalizeIdentifier()", axisName, i, utils.NormalizeIdentifier(axisName))
+		}
+	}
+
+	numDevices := 1
+	nameToAxis := make(map[string]int, len(axesSizes))
+	for i, name := range axesNames {
+		if name == "" {
+			return nil, errors.Errorf("DeviceMesh axis name at index %d cannot be empty", i)
+		}
+		if _, found := nameToAxis[name]; found {
+			return nil, errors.Errorf("DeviceMesh axis name %q is duplicated", name)
+		}
+		nameToAxis[name] = i
+		numDevices *= axesSizes[i]
+	}
+
+	m := &DeviceMesh{
+		name:       name,
+		axesNames:  axesNames,
+		axesSizes:  axesSizes,
+		nameToAxis: nameToAxis,
+		numDevices: numDevices,
+	}
+	return m, nil
+}
+
+func (m *DeviceMesh) Name() string {
+ return m.name +} + +// NumDevices returns the total number of devices in the mesh. +func (m *DeviceMesh) NumDevices() int { + return m.numDevices +} + +// Rank returns the number of axes in the mesh. +func (m *DeviceMesh) Rank() int { + return len(m.axesSizes) +} + +// AxesNames returns a copy of the mesh's axis names. +func (m *DeviceMesh) AxesNames() []string { + return slices.Clone(m.axesNames) +} + +// AxesSizes returns a copy of the mesh's axesSizes. +func (m *DeviceMesh) AxesSizes() []int { + shape := make([]int, len(m.axesSizes)) + copy(shape, m.axesSizes) + return shape +} + +// AxisSize returns the number of devices along the given mesh axis. +func (m *DeviceMesh) AxisSize(axisName string) (int, error) { + idx, found := m.nameToAxis[axisName] + if !found { + return 0, errors.Errorf("mesh axis %q not found", axisName) + } + return m.axesSizes[idx], nil +} + +// String implements the fmt.Stringer interface. +func (m *DeviceMesh) String() string { + var sb strings.Builder + sb.WriteString("DeviceMesh(axesSizes={") + for i, name := range m.axesNames { + if i > 0 { + sb.WriteString(", ") + } + _, _ = fmt.Fprintf(&sb, "%s: %d", name, m.axesSizes[i]) + } + sb.WriteString("})") + return sb.String() +} + +// SetLogicalDeviceAssignment sets the assignment of logical devices to the mesh. +// +// The length of devices must be equal to NumDevices(). And it should include all numbers from 0 to NumDevices()-1. +// +// It returns an error if logicalDeviceAssignment has invalid device numbers or len(devices) != NumDevices(). 
+func (m *DeviceMesh) SetLogicalDeviceAssignment(devices ...int) error { + if len(devices) == 0 { + m.logicalDeviceAssignment = nil + return nil + } + if len(devices) != m.numDevices { + return errors.Errorf("devices must have %d elements, got %d", m.numDevices, len(devices)) + } + seen := utils.MakeSet[int](m.numDevices) + for _, device := range devices { + if seen.Has(device) { + return errors.Errorf("physical device #%d is duplicated in mapping", device) + } + seen.Insert(device) + if device < 0 || device >= m.numDevices { + return errors.Errorf("devices must be between 0 and %d (NumDevices()-1), got device %d", + m.numDevices-1, device) + } + } + m.logicalDeviceAssignment = slices.Clone(devices) + return nil +} + +// LogicalDeviceAssignment returns the list of devices in the mesh, in the order they appear in the mesh. +// +// It can return nil if no assignment was set with SetLogicalDeviceAssignment() -- in which case it will +// default to a sequential assignment starting from 0. +func (m *DeviceMesh) LogicalDeviceAssignment() []int { + if m.logicalDeviceAssignment == nil { + return nil + } + return slices.Clone(m.logicalDeviceAssignment) +} + +// ComputeReplicaGroups returns the replica groups participating in some collective (distributed) operation given the +// axes along which the operation is performed. +// +// Each replica group (a []int) includes the device indices (from the LogicalDeviceAssignment) for the axes specified. +// The other axes will be split into different replica groups. 
+// +// Example: +// +// m := NewDeviceMesh([]int{2, 2}, []string{"batch", "data"}) +// batchGroups, _ := m.ComputeReplicaGroups([]string{"batch"}) // -> [][]int{{0, 2}, {1, 3}} +// dataGroups, _ := m.ComputeReplicaGroups([]string{"data"}) // -> [][]int{{0, 1}, {2, 3}} +// globalGroups, _ := m.ComputeReplicaGroups([]string{"batch", "data"}) // -> [][]int{{0, 1, 2, 3}} +func (m *DeviceMesh) ComputeReplicaGroups(axes []string) ([][]int, error) { + // Find indices of the specified axes + axisIndices := make([]int, 0, len(axes)) + axisSet := utils.MakeSet[int](len(axes)) + for _, axis := range axes { + if idx, found := m.nameToAxis[axis]; found { + if axisSet.Has(idx) { + return nil, errors.Errorf("axis %q is duplicated: each axis can only appear once", axis) + } + axisIndices = append(axisIndices, idx) + axisSet.Insert(idx) + } else { + return nil, errors.Errorf("axis %q not found in mesh", axis) + } + } + + // Create indices for each axis dimension + nonAxisIndices := make([]int, 0, len(m.axesSizes)-len(axisIndices)) + for i := range m.axesSizes { + if !slices.Contains(axisIndices, i) { + nonAxisIndices = append(nonAxisIndices, i) + } + } + + // Calculate the size of each group and number of groups + groupSize := 1 + for _, idx := range axisIndices { + groupSize *= m.axesSizes[idx] + } + numGroups := m.numDevices / groupSize + + // Initialize the result + groups := make([][]int, numGroups) + for i := range groups { + groups[i] = make([]int, groupSize) + } + + // Fill in the groups + for flatIdx := 0; flatIdx < m.numDevices; flatIdx++ { + // Convert flat index to per-axis indices + indices := make([]int, len(m.axesSizes)) + remaining := flatIdx + for i := len(m.axesSizes) - 1; i >= 0; i-- { + indices[i] = remaining % m.axesSizes[i] + remaining /= m.axesSizes[i] + } + + // Calculate group index from non-axis indices + groupIdx := 0 + multiplier := 1 + for i := len(nonAxisIndices) - 1; i >= 0; i-- { + axisIdx := nonAxisIndices[i] + groupIdx += indices[axisIdx] * 
multiplier + multiplier *= m.axesSizes[axisIdx] + } + + // Calculate position within group from axis indices + posInGroup := 0 + multiplier = 1 + for i := len(axisIndices) - 1; i >= 0; i-- { + axisIdx := axisIndices[i] + posInGroup += indices[axisIdx] * multiplier + multiplier *= m.axesSizes[axisIdx] + } + + groups[groupIdx][posInGroup] = flatIdx + } + + return groups, nil +} + +// ToStableHLO returns the StableHLO representation of the mesh, as it should be used in the module body. +// E.g.: sdy.mesh @mesh = <["data"=4, "model"=2]> +func (m *DeviceMesh) ToStableHLO() string { + var buf strings.Builder + w := func(format string, args ...any) { + buf.WriteString(fmt.Sprintf(format, args...)) + } + w("sdy.mesh @%s = <[", m.name) + for i, axisName := range m.axesNames { + if i > 0 { + w(", ") + } + w("%q=%d", axisName, m.axesSizes[i]) + } + w("]") + if len(m.logicalDeviceAssignment) > 0 { + w(", device_ids=[") + for i, logicalDeviceId := range m.logicalDeviceAssignment { + if i > 0 { + w(", ") + } + w("%d", logicalDeviceId) + } + w("]") + } + w(">") + return buf.String() +} diff --git a/types/shardy/devicemesh_test.go b/types/shardy/devicemesh_test.go new file mode 100644 index 0000000..f8dda8a --- /dev/null +++ b/types/shardy/devicemesh_test.go @@ -0,0 +1,374 @@ +package shardy_test + +import ( + "testing" + + "github.com/gomlx/stablehlo/types/shardy" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDeviceMesh(t *testing.T) { + t.Run("NewDeviceMesh_Valid", func(t *testing.T) { + tests := []struct { + name string + shape []int + axisNames []string + wantRank int + wantNum int + wantStableHLO string + }{ + { + name: "1D mesh", + shape: []int{8}, + axisNames: []string{"replica"}, + wantRank: 1, + wantNum: 8, + wantStableHLO: `sdy.mesh @mesh = <["replica"=8]>`, + }, + { + name: "2D mesh", + shape: []int{2, 4}, + axisNames: []string{"x", "y"}, + wantRank: 2, + wantNum: 8, + wantStableHLO: `sdy.mesh @mesh = <["x"=2, "y"=4]>`, + 
}, + { + name: "3D mesh", + shape: []int{2, 2, 2}, + axisNames: []string{"x", "y", "z"}, + wantRank: 3, + wantNum: 8, + wantStableHLO: `sdy.mesh @mesh = <["x"=2, "y"=2, "z"=2]>`, + }, + { + name: "single device", + shape: []int{1}, + axisNames: []string{"replica"}, + wantRank: 1, + wantNum: 1, + wantStableHLO: `sdy.mesh @mesh = <["replica"=1]>`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", tt.shape, tt.axisNames) + require.NoError(t, err) + assert.NotNil(t, mesh) + assert.Equal(t, tt.wantRank, mesh.Rank()) + assert.Equal(t, tt.wantNum, mesh.NumDevices()) + assert.Equal(t, tt.wantStableHLO, mesh.ToStableHLO()) + }) + } + }) + + t.Run("NewDeviceMesh_Errors", func(t *testing.T) { + tests := []struct { + name string + shape []int + axisNames []string + wantErr string + }{ + { + name: "mismatched lengths", + shape: []int{2, 4}, + axisNames: []string{"x"}, + wantErr: "axesSizes and axesNames must have the same length", + }, + { + name: "empty axesSizes", + shape: []int{}, + axisNames: []string{}, + wantErr: "DeviceMesh axesSizes cannot be empty", + }, + { + name: "empty axis name", + shape: []int{4}, + axisNames: []string{""}, + wantErr: "axis name at index 0 cannot be empty", + }, + { + name: "duplicate axis names", + shape: []int{2, 4}, + axisNames: []string{"x", "x"}, + wantErr: "axis name \"x\" is duplicated", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", tt.shape, tt.axisNames) + require.Error(t, err) + assert.Nil(t, mesh) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } + }) + + t.Run("AxesNames", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 4}, []string{"x", "y"}) + require.NoError(t, err) + + axisNames := mesh.AxesNames() + assert.Equal(t, []string{"x", "y"}, axisNames) + + // Verify it returns a copy + axisNames[0] = "modified" + assert.Equal(t, []string{"x", "y"}, 
mesh.AxesNames()) + }) + + t.Run("Shape", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 4}, []string{"x", "y"}) + require.NoError(t, err) + + axesSizes := mesh.AxesSizes() + assert.Equal(t, []int{2, 4}, axesSizes) + + // Verify it returns a copy + axesSizes[0] = 99 + assert.Equal(t, []int{2, 4}, mesh.AxesSizes()) + }) + + t.Run("AxisSize", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 4}, []string{"x", "y"}) + require.NoError(t, err) + + tests := []struct { + name string + axisName string + wantSize int + wantErr bool + }{ + { + name: "valid axis x", + axisName: "x", + wantSize: 2, + wantErr: false, + }, + { + name: "valid axis y", + axisName: "y", + wantSize: 4, + wantErr: false, + }, + { + name: "non-existent axis", + axisName: "z", + wantSize: 0, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + size, err := mesh.AxisSize(tt.axisName) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + } else { + require.NoError(t, err) + assert.Equal(t, tt.wantSize, size) + } + }) + } + }) + + t.Run("String", func(t *testing.T) { + tests := []struct { + name string + shape []int + axisNames []string + want string + }{ + { + name: "1D mesh", + shape: []int{8}, + axisNames: []string{"replica"}, + want: "DeviceMesh(axesSizes={replica: 8})", + }, + { + name: "2D mesh", + shape: []int{2, 4}, + axisNames: []string{"x", "y"}, + want: "DeviceMesh(axesSizes={x: 2, y: 4})", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", tt.shape, tt.axisNames) + require.NoError(t, err) + assert.Equal(t, tt.want, mesh.String()) + }) + } + }) + + t.Run("SetDeviceAssignment_Valid", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) + require.NoError(t, err) + + tests := []struct { + name string + devices []int + }{ + { + name: "sequential mapping", + 
devices: []int{0, 1, 2, 3}, + }, + { + name: "reverse mapping", + devices: []int{3, 2, 1, 0}, + }, + { + name: "custom mapping", + devices: []int{2, 1, 3, 0}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := mesh.SetLogicalDeviceAssignment(tt.devices...) + require.NoErrorf(t, err, "failed test %q", tt.name) + }) + } + }) + + t.Run("SetDeviceAssignment_Errors", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) + require.NoError(t, err) + + tests := []struct { + name string + devices []int + wantErr string + }{ + { + name: "wrong number of devices", + devices: []int{0, 1, 2}, + wantErr: "devices must have 4 elements", + }, + { + name: "duplicate device", + devices: []int{0, 1, 1, 3}, + wantErr: "physical device #1 is duplicated", + }, + { + name: "device out of range (negative)", + devices: []int{0, 1, -1, 3}, + wantErr: "devices must be between 0 and 3", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := mesh.SetLogicalDeviceAssignment(tt.devices...) 
+ require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } + }) + + t.Run("DeviceToMesh_2D", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 4}, []string{"x", "y"}) + require.NoError(t, err) + require.Equal(t, 8, mesh.NumDevices()) + }) + + t.Run("DeviceToMesh_3D", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2, 2}, []string{"x", "y", "z"}) + require.NoError(t, err) + require.Equal(t, 8, mesh.NumDevices()) + }) + + t.Run("DeviceToMesh_WithCustomMapping", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) + require.NoError(t, err) + err = mesh.SetLogicalDeviceAssignment(3, 2, 1, 0) + require.NoError(t, err) + require.Equal(t, 4, mesh.NumDevices()) + err = mesh.SetLogicalDeviceAssignment(4, 2, 1, 0) + require.Error(t, err) + }) + + t.Run("ComputeReplicaGroups", func(t *testing.T) { + t.Run("2D mesh batch groups", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2}, []string{"batch", "data"}) + require.NoError(t, err) + + // Example from comments: m.ComputeReplicaGroups([]string{"batch"}) -> [][]int{{0, 2}, {1, 3}} + groups, err := mesh.ComputeReplicaGroups([]string{"batch"}) + require.NoError(t, err) + assert.Equal(t, [][]int{{0, 2}, {1, 3}}, groups) + }) + + t.Run("2D mesh data groups", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2}, []string{"batch", "data"}) + require.NoError(t, err) + + // Example from comments: m.ComputeReplicaGroups([]string{"data"}) -> [][]int{{0, 1}, {2, 3}} + groups, err := mesh.ComputeReplicaGroups([]string{"data"}) + require.NoError(t, err) + assert.Equal(t, [][]int{{0, 1}, {2, 3}}, groups) + }) + + t.Run("2D mesh global groups", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2}, []string{"batch", "data"}) + require.NoError(t, err) + + // Example from comments: m.ComputeReplicaGroups([]string{"batch", "data"}) -> [][]int{{0, 1, 2, 3}} + 
groups, err := mesh.ComputeReplicaGroups([]string{"batch", "data"}) + require.NoError(t, err) + assert.Equal(t, [][]int{{0, 1, 2, 3}}, groups) + }) + + t.Run("1D mesh", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) + require.NoError(t, err) + + groups, err := mesh.ComputeReplicaGroups([]string{"replica"}) + require.NoError(t, err) + assert.Equal(t, [][]int{{0, 1, 2, 3}}, groups) + }) + + t.Run("3D mesh single axis", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2, 2}, []string{"x", "y", "z"}) + require.NoError(t, err) + + // Groups along x axis: should split by y and z + groups, err := mesh.ComputeReplicaGroups([]string{"x"}) + require.NoError(t, err) + assert.Equal(t, [][]int{{0, 4}, {1, 5}, {2, 6}, {3, 7}}, groups) + }) + + t.Run("3D mesh two axes", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2, 2}, []string{"x", "y", "z"}) + require.NoError(t, err) + + // Groups along x and y axes: should split by z + groups, err := mesh.ComputeReplicaGroups([]string{"x", "y"}) + require.NoError(t, err) + assert.Equal(t, [][]int{{0, 2, 4, 6}, {1, 3, 5, 7}}, groups) + }) + + t.Run("empty axes list", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2}, []string{"batch", "data"}) + require.NoError(t, err) + + // Empty axes list: each device is its own group + groups, err := mesh.ComputeReplicaGroups([]string{}) + require.NoError(t, err) + assert.Equal(t, [][]int{{0}, {1}, {2}, {3}}, groups) + }) + + t.Run("non-existent axis", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2}, []string{"batch", "data"}) + require.NoError(t, err) + + // A non-existent axis should return an error. 
+			_, err = mesh.ComputeReplicaGroups([]string{"nonexistent"})
+			require.Error(t, err)
+		})
+	})
+}
diff --git a/types/shardy/shardingspec.go b/types/shardy/shardingspec.go
new file mode 100644
index 0000000..8b9e5ff
--- /dev/null
+++ b/types/shardy/shardingspec.go
@@ -0,0 +1,243 @@
+package shardy
+
+import (
+	"fmt"
+	"sort"
+	"strings"
+
+	"github.com/gomlx/stablehlo/types/shapes"
+	"github.com/pkg/errors"
+)
+
+// ShardingSpec (also known as PartitionSpec in JAX) defines how a logical tensor is to be sharded (partitioned) across
+// a DeviceMesh. This is used by Shardy, and is based on its documentation in [1].
+//
+// The definition is per axis of the logical tensor -- and not per axis of the Mesh, a common confusion.
+// If not all axes of the Tensor are defined, the tail axes are considered simply to be replicated across the whole
+// mesh.
+//
+// Each tensor axis can be replicated or sharded across one or more mesh axes.
+//
+// Example:
+//
+//	mesh := NewDeviceMesh("my_mesh", []int{2, 2}, []string{"data", "model"})
+//
+//	// Input's "batch" axis is sharded across the "data" axis of the mesh.
+//	inputSharding := NewShardingSpec(mesh).AddShardedAxis("data")
+//
+//	// First axis is replicated, second is sharded across "model" devices
+//	variableSharding := NewShardingSpec(mesh).AddReplicated().AddShardedAxis("model")
+//
+//	// Second axis is sharded across both "data" and "model" devices.
+//	largeWeights := NewShardingSpec(mesh).AddReplicated().AddShardedAxis("data", "model")
+//
+// There are two advanced features supported but not tested (please let us know how it goes if you use them, or if
+// you find any issues):
+//
+// 1. The tensor can also be sharded across mesh "sub-axes" -- see detailed documentation in [1]
+// 2. If using ShardingSpec for hints, instead of mesh axes one can give an "open" (in StableHLO marked as "?")
+//    axis, with the semantics that XLA Shardy can choose any mesh axis (or axes) to shard the tensor. See [1].
+// +// [1] https://github.com/openxla/shardy/blob/main/docs/sharding_representation.md +type ShardingSpec struct { + Mesh *DeviceMesh + Axes []TensorAxisSpec +} + +// TensorAxisSpec specifies how a tensor axis is to be sharded (or replicated). +// See details in ShardingSpec. +// +// Usually, one would create this using ShardingSpec.AddShardedAxis or ShardingSpec.AddReplicated +type TensorAxisSpec struct { + MeshAxes []MeshAxisSpec + Opened bool // If opened to further sharding. +} + +type MeshAxisSpec struct { + AxisName string + + // PreSize, Size are only set if defining a sub-axis of the mesh. + PreSize, Size int +} + +// NewShardingSpec creates a new ShardingSpec. +func NewShardingSpec(mesh *DeviceMesh) *ShardingSpec { + return &ShardingSpec{mesh, make([]TensorAxisSpec, 0)} +} + +// AddShardedAxis adds a new sharded axis to the ShardingSpec using one or more mesh axes. +// +// It returns itself, so calls can be chained. +func (s *ShardingSpec) AddShardedAxis(meshAxesNames ...string) *ShardingSpec { + axisSpec := TensorAxisSpec{} + for _, meshAxisName := range meshAxesNames { + axisSpec.MeshAxes = append(axisSpec.MeshAxes, MeshAxisSpec{AxisName: meshAxisName}) + } + s.Axes = append(s.Axes, axisSpec) + return s +} + +// AddReplicated adds a new replicated axis to the ShardingSpec. +// +// It returns itself, so calls can be chained. +func (s *ShardingSpec) AddReplicated() *ShardingSpec { + s.Axes = append(s.Axes, TensorAxisSpec{}) + return s +} + +// Rank returns the number of axes this ShardingSpec describes. +// +// Notice this may be smaller than the rank of the tensor using it: if a tensor axis is not defined in ShardingSpec, +// it is assumed to be replicated. +func (s *ShardingSpec) Rank() int { + return len(s.Axes) +} + +// IsReplicated returns true if the tensor is fully replicated +// (i.e., not sharded along any axis and not marked as "open"). 
+func (s *ShardingSpec) IsReplicated() bool { + for _, axisSpec := range s.Axes { + if axisSpec.MeshAxes != nil || axisSpec.Opened { + return false + } + } + return true +} + +// Validate checks that the ShardingSpec is valid for the given mesh. +func (s *ShardingSpec) Validate() error { + for i, axisSpec := range s.Axes { + for j, meshAxisSpec := range axisSpec.MeshAxes { + axisName := meshAxisSpec.AxisName + if axisName == "" { + return errors.Errorf( + "ShardingSpec tensor axis %d, mesh axis #%d refers to empty mesh axis name", i, j) + } + axisIdx, ok := s.Mesh.nameToAxis[axisName] + if !ok { + return errors.Errorf( + "ShardingSpec tensor axis %d, mesh axis #%d refers to unknown mesh axis %q", + i, j, axisName) + } + meshAxisSize := s.Mesh.axesSizes[axisIdx] + + // Check sub-axis specification. + if meshAxisSpec.Size > 0 { + if meshAxisSpec.PreSize <= 0 { + return errors.Errorf("ShardingSpec tensor axis %d, mesh axis #%d %q has invalid PreSize %d", + i, j, axisName, meshAxisSpec.PreSize) + } + if meshAxisSize%(meshAxisSpec.PreSize*meshAxisSpec.Size) != 0 { + return errors.Errorf( + "ShardingSpec tensor axis %d, mesh axis #%d %q with PreSize %d and Size %d is not "+ + "compatible with mesh axis of size %d", + i, j, axisName, meshAxisSpec.PreSize, meshAxisSpec.Size, meshAxisSize) + } + } + } + } + return nil +} + +func (s *ShardingSpec) ValidateShape(shape shapes.Shape) error { + if s == nil { + // No sharding spec (nil) means fully replicated, and it's always valid for any shape. + return nil + } + err := s.Validate() + if err != nil { + return err + } + if s.Rank() > shape.Rank() { + return errors.Errorf("ShardingSpec shape rank %d is larger than tensor rank %d", s.Rank(), shape.Rank()) + } + return nil +} + +// ToStableHLO converts the ShardingSpec to its StableHLO string representation. 
// See details in:
// https://github.com/openxla/shardy/blob/main/docs/sharding_representation.md
func (s *ShardingSpec) ToStableHLO() string {
	var dimShardings []string
	// Start by assuming every mesh axis is replicated; axes actually used for
	// sharding are removed from this set in the loop below.
	replicatedAxes := make(map[string]bool)
	for _, axisName := range s.Mesh.axesNames {
		replicatedAxes[axisName] = true
	}

	for _, axisSpec := range s.Axes {
		var hloAxes []string
		for _, meshAxisSpec := range axisSpec.MeshAxes {
			delete(replicatedAxes, meshAxisSpec.AxisName)
			if meshAxisSpec.Size > 0 {
				// Sub-axis reference, rendered as name:(PreSize)Size.
				hloAxes = append(hloAxes, fmt.Sprintf("%s:(%d)%d", meshAxisSpec.AxisName, meshAxisSpec.PreSize, meshAxisSpec.Size))
			} else {
				hloAxes = append(hloAxes, meshAxisSpec.AxisName)
			}
		}
		if axisSpec.Opened {
			// An open axis is marked with a trailing "?".
			hloAxes = append(hloAxes, "?")
		}
		dimShardings = append(dimShardings, fmt.Sprintf("{%s}", strings.Join(hloAxes, ", ")))
	}

	var replicatedStrs []string
	for axisName := range replicatedAxes {
		replicatedStrs = append(replicatedStrs, axisName)
	}
	// Sorted for deterministic output (map iteration order is random).
	// NOTE(review): this sorts alphabetically, while the Shardy sharding
	// representation orders the replicated axes by their mesh-axis order —
	// confirm which ordering consumers expect (the unit tests pin alphabetical).
	sort.Strings(replicatedStrs)

	replicatedPart := ""
	if len(replicatedStrs) > 0 {
		replicatedPart = fmt.Sprintf(", replicated={%s}", strings.Join(replicatedStrs, ", "))
	}
	// NOTE(review): axis names are emitted unquoted here, whereas
	// ToValueAttribute quotes them (e.g. {"data"}) — verify both forms are
	// accepted where each output is consumed.
	return fmt.Sprintf("#sdy.sharding<@%s, [%s]%s>", s.Mesh.Name(), strings.Join(dimShardings, ", "), replicatedPart)
}

// ToValueAttribute converts the ShardingSpec to a StableHLO attribute of a value with the given shape.
//
// Notice the rank of the ShardingSpec may be smaller than the rank of the shape, in which case the extra axes are
// assumed to be replicated (empty).
//
// E.g.: "#sdy.sharding<@mesh, [{\"data\"}, {}]>"
func (s *ShardingSpec) ToValueAttribute(shape shapes.Shape) string {
	var buf strings.Builder
	// w appends a formatted fragment to the attribute string being built.
	w := func(format string, args ...any) {
		buf.WriteString(fmt.Sprintf(format, args...))
	}
	w("#sdy.sharding<@%s, [", s.Mesh.Name())
	// One "{...}" entry is emitted per tensor axis, in order.
	// (range over an int requires Go 1.22+.)
	for axisIdx := range shape.Rank() {
		if axisIdx > 0 {
			w(", ")
		}
		// Tensor axes beyond the spec's rank are implicitly replicated.
		if axisIdx >= len(s.Axes) {
			w("{}")
			continue
		}
		tensorAxisSpec := s.Axes[axisIdx]
		// No mesh axes: either replicated ("{}") or open to sharding ("{?}").
		if len(tensorAxisSpec.MeshAxes) == 0 {
			if tensorAxisSpec.Opened {
				w("{?}")
			} else {
				w("{}")
			}
			continue
		}
		w("{")
		for meshAxisIdx, meshAxisSpec := range tensorAxisSpec.MeshAxes {
			if meshAxisIdx > 0 {
				w(", ")
			}
			if meshAxisSpec.Size > 0 {
				// Sub-axis reference: "name":(PreSize)Size.
				w("\"%s\":(%d)%d", meshAxisSpec.AxisName, meshAxisSpec.PreSize, meshAxisSpec.Size)
			} else {
				w("\"%s\"", meshAxisSpec.AxisName)
			}
		}
		if tensorAxisSpec.Opened {
			w(", ?")
		}
		w("}")
	}
	w("]>")
	return buf.String()
}
diff --git a/types/shardy/shardingspec_test.go b/types/shardy/shardingspec_test.go
new file mode 100644
index 0000000..5e97424
--- /dev/null
+++ b/types/shardy/shardingspec_test.go
@@ -0,0 +1,116 @@
package shardy

import (
	"testing"

	"github.com/stretchr/testify/require"
)

// TestShardSpec_ToStableHLO checks the textual #sdy.sharding form produced by
// ShardingSpec.ToStableHLO for replicated, sharded, sub-axis and open specs.
func TestShardSpec_ToStableHLO(t *testing.T) {
	mesh, err := NewDeviceMesh("test_mesh", []int{4, 2}, []string{"z", "a"})
	require.NoError(t, err)
	testCases := []struct {
		name     string
		spec     *ShardingSpec
		expected string
	}{
		{
			name:     "Replicated",
			spec:     NewShardingSpec(mesh).AddReplicated(),
			expected: "#sdy.sharding<@test_mesh, [{}], replicated={a, z}>",
		},
		{
			name:     "Sharded",
			spec:     NewShardingSpec(mesh).AddShardedAxis("z"),
			expected: "#sdy.sharding<@test_mesh, [{z}], replicated={a}>",
		},
		{
			name:     "Sharded with multiple axes",
			spec:     NewShardingSpec(mesh).AddShardedAxis("z", "a"),
			expected: "#sdy.sharding<@test_mesh, [{z, a}]>",
		},
		{
			name: "Sharded with sub-axis",
			spec: &ShardingSpec{
				Mesh: mesh,
				Axes: []TensorAxisSpec{
					{MeshAxes: []MeshAxisSpec{{AxisName: "a", PreSize: 1, Size: 2}}},
				},
			},
			expected: "#sdy.sharding<@test_mesh, [{a:(1)2}], replicated={z}>",
		},
		{
			name:     "Opened",
			spec:     &ShardingSpec{Mesh: mesh, Axes: []TensorAxisSpec{{Opened: true}}},
			expected: "#sdy.sharding<@test_mesh, [{?}], replicated={a, z}>",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.expected, tc.spec.ToStableHLO())
		})
	}
}

// TestShardSpec_Validate checks ShardingSpec.Validate against valid specs and
// invalid ones (unknown mesh axis, bad sub-axis PreSize/Size).
func TestShardSpec_Validate(t *testing.T) {
	mesh, err := NewDeviceMesh("test_mesh", []int{2, 8}, []string{"z", "a"})
	require.NoError(t, err)
	testCases := []struct {
		name        string
		spec        *ShardingSpec
		expectError bool
	}{
		{
			name:        "Valid sharding",
			spec:        NewShardingSpec(mesh).AddShardedAxis("z"),
			expectError: false,
		},
		{
			name:        "Unknown mesh axis",
			spec:        NewShardingSpec(mesh).AddShardedAxis("x"),
			expectError: true,
		},
		{
			name: "Valid sub-axis",
			spec: &ShardingSpec{
				Mesh: mesh,
				Axes: []TensorAxisSpec{
					{MeshAxes: []MeshAxisSpec{{AxisName: "a", PreSize: 2, Size: 4}}},
				},
			},
			expectError: false,
		},
		{
			name: "Invalid sub-axis (PreSize)",
			spec: &ShardingSpec{
				Mesh: mesh,
				Axes: []TensorAxisSpec{
					{MeshAxes: []MeshAxisSpec{{AxisName: "a", PreSize: 0, Size: 4}}},
				},
			},
			expectError: true,
		},
		{
			name: "Invalid sub-axis (Size)",
			spec: &ShardingSpec{
				Mesh: mesh,
				Axes: []TensorAxisSpec{
					{MeshAxes: []MeshAxisSpec{{AxisName: "a", PreSize: 2, Size: 5}}},
				},
			},
			expectError: true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			err := tc.spec.Validate()
			if tc.expectError {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
		})
	}
}
diff --git a/value.go b/value.go
index c14dff8..5e48eee 100644
--- a/value.go
+++ b/value.go
@@ -18,9 +18,10 @@ import (
 //
 // It also carries its shape information.
type Value struct { - fn *Function - name string - shape shapes.Shape + fn *Function + name string + shape shapes.Shape + Attributes map[string]any } // Shape returns the shape of the value.