Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
4fc293d
Added shardy types for DeviceMesh and ShardSpec.
janpfeifer Nov 18, 2025
574c673
feat(shardy): Implement ShardSpec.ToStableHLO and improve validation
google-labs-jules[bot] Nov 18, 2025
68f91ef
Moved NormalizeIdentifier to utils to break dependency cycle.
janpfeifer Nov 19, 2025
0802096
Added checks that mesh and mesh axes names are valid StableHLO identi…
janpfeifer Nov 19, 2025
1259abb
Added Builder.WithShardy() method.
janpfeifer Nov 19, 2025
ff58025
Merge branch 'shardy' of github.com:gomlx/stablehlo into shardy
janpfeifer Nov 19, 2025
64d506a
Renamed ShardSpec -> ShardingSpec
janpfeifer Nov 19, 2025
663f568
Adding optional ShardingSpec to the inputs and outputs of the functions.
janpfeifer Nov 19, 2025
cc61929
Fixed tests with correct stablehlo of shardingspec.
janpfeifer Nov 19, 2025
c858e73
feat: Add attribute support to function signatures
google-labs-jules[bot] Nov 19, 2025
11f22da
Added Shardy execution support.
janpfeifer Nov 19, 2025
0f0cb8c
Working Shardy tests.
janpfeifer Nov 19, 2025
9e85a69
Updated CHANGELOG.
janpfeifer Nov 19, 2025
4d6adcb
Renamed DeviceMesh.shape to DeviceMesh.axesSizes
janpfeifer Nov 20, 2025
4a90ce4
Cosmetic.
janpfeifer Nov 20, 2025
fb0abb6
Fixed documentation.
janpfeifer Nov 20, 2025
951eab5
DeviceMesh changed to take a "logical device assignment".
janpfeifer Nov 20, 2025
23e49cd
Added support for multiple meshes in StableHLO module.
janpfeifer Nov 20, 2025
1eb4942
Fixed test to use DeviceAssignment correctly.
janpfeifer Nov 20, 2025
83c1dac
Merge branch 'shardy' of github.com:gomlx/stablehlo into shardy
janpfeifer Nov 20, 2025
46f5115
Added support for logical device assignment. Updated tests.
janpfeifer Nov 20, 2025
42c313d
Reordered code, moving all distributed code to the end of builder.go.
janpfeifer Nov 20, 2025
b84e8cd
Fixed names of parameters
janpfeifer Nov 20, 2025
95c3ac0
s/DeviceMesh.Shape/DeviceMesh.AxesSizes
janpfeifer Nov 21, 2025
68ba747
Fixed AddShardedAxis.
janpfeifer Nov 21, 2025
f091953
ReturnWithShardingAndAttributes accept empty sharding specs list.
janpfeifer Nov 22, 2025
86595fb
ReturnWithAttributes accept nil attributes.
janpfeifer Nov 22, 2025
81ce241
Collective operations: don't set up channel_handle if not explicitly …
janpfeifer Nov 23, 2025
16b9a7c
Updated CHANGELOG.
janpfeifer Nov 23, 2025
e3439a2
Fixed attributes for outputs of functions.
janpfeifer Nov 23, 2025
a9e5561
Bumping to v0.2.0-rc0
janpfeifer Nov 25, 2025
95c070b
Fixed test: added missing parenthesis for outputs with attributes.
janpfeifer Nov 25, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 98 additions & 47 deletions builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import (
"io"
"slices"

"github.com/gomlx/stablehlo/internal/utils"
"github.com/gomlx/stablehlo/types"
"github.com/gomlx/stablehlo/types/shardy"
"github.com/pkg/errors"
)

// Builder is used to construct a StableHLO program.
// Builder is used to construct a StableHLO program (or "Module")
// See details in New.
type Builder struct {
name string
Expand All @@ -22,38 +24,20 @@ type Builder struct {
// inlineUniqueID is a counter used to generate unique names for inlined functions values.
inlineUniqueID int

// NumReplicas is the number of replicas for data parallelism.
NumReplicas int
// NumPartitions is the number of partitions for model parallelism.
NumPartitions int
// meshes used for Shardy.
meshes []*shardy.DeviceMesh

// numReplicas is the number of replicas for data parallelism.
numReplicas int

// numPartitions is the number of partitions for model parallelism.
numPartitions int

// nextChannelID is the next ID to be assigned in channel handles.
// It is just a Unique ID.
nextChannelID int
}

// NormalizeIdentifier converts the name of an identifier (function name or function input parameter
// name) to a valid one: only letters, digits and underscores are allowed.
//
// Invalid characters are replaced with underscores.
// If the name starts with a digit, it is prefixed with an underscore.
// An empty name is returned unchanged.
//
// Go strings are immutable, so the normalized name is returned as a new string.
func NormalizeIdentifier(name string) string {
	// Guard against empty input: indexing name[0] below would panic.
	if name == "" {
		return name
	}
	result := make([]rune, 0, len(name)+1)
	if name[0] >= '0' && name[0] <= '9' {
		result = append(result, '_')
	}
	for _, r := range name {
		if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' {
			result = append(result, r)
		} else {
			result = append(result, '_')
		}
	}
	return string(result)
}

// New creates a new Builder object holding a computation graph in construction.
//
// From a builder you can create functions.
Expand All @@ -62,7 +46,7 @@ func NormalizeIdentifier(name string) string {
// You have to define the "main" function for your StableHLO program: you can use Builder.Main to do so, or
// Builder.NewFunction("main",...), it's the same.
//
// Once you are all set, call Builder.Build and it will return the StableHLO program as a []byte that can
// Once you are all set, call Builder.Build and it will return the StableHLO program (or "Module") as a []byte that can
// be used with PJRT.
//
// See github.com/gomlx/gopjrt for a Go API to PJRT.
Expand All @@ -72,20 +56,6 @@ func New(name string) *Builder {
}
}

// WithNumReplicas sets the number of replicas (for data parallelism).
// This is added as an attribute to the StableHLO module.
//
// It returns the Builder, so calls can be chained.
func (b *Builder) WithNumReplicas(n int) *Builder {
	b.NumReplicas = n
	return b
}

// WithNumPartitions sets the number of partitions (for model parallelism).
// This is added as an attribute to the StableHLO module.
//
// It returns the Builder, so calls can be chained.
func (b *Builder) WithNumPartitions(n int) *Builder {
	b.NumPartitions = n
	return b
}

// elementWriter represents elements of ToStableHLO that know how to write themselves.
type elementWriter interface {
Write(w io.Writer, indentation string) error
Expand Down Expand Up @@ -134,11 +104,11 @@ const IndentationStep = " "
// getModuleAttributes returns the attributes for the StableHLO module (StableHLO code) generated.
func (b *Builder) getModuleAttributes() []string {
var attributes []string
if b.NumReplicas > 0 {
attributes = append(attributes, fmt.Sprintf("stablehlo.num_replicas = %d", b.NumReplicas))
if b.numReplicas > 0 {
attributes = append(attributes, fmt.Sprintf("stablehlo.num_replicas = %d", b.numReplicas))
}
if b.NumPartitions > 0 {
attributes = append(attributes, fmt.Sprintf(" stablehlo.num_partitions = %d", b.NumPartitions))
if b.numPartitions > 0 {
attributes = append(attributes, fmt.Sprintf(" stablehlo.num_partitions = %d", b.numPartitions))
}
return attributes
}
Expand Down Expand Up @@ -177,10 +147,22 @@ func (b *Builder) Write(writer io.Writer) error {
}
w("%s", attr)
}
w(" }")
w("}")
}
w(" {\n")

// Write Shardy meshes if needed:
if len(b.meshes) > 0 {
namesUsed := utils.MakeSet[string](len(b.meshes))
for _, mesh := range b.meshes {
if namesUsed.Has(mesh.Name()) {
return errors.Errorf("duplicate mesh name %q", mesh.Name())
}
namesUsed.Insert(mesh.Name())
w("%s%s\n", IndentationStep, mesh.ToStableHLO())
}
}

// Write non-inline functions:
var count int
for _, fn := range b.functions {
Expand Down Expand Up @@ -248,3 +230,72 @@ func (b *Builder) getChannelHandle(config *types.CollectiveConfig) literalStr {

return literalStrF("#stablehlo.channel_handle<handle = %d, type = %d>", id, typ)
}

// WithNumReplicas sets the number of replicas (for data parallelism).
// This is added as an attribute to the StableHLO module.
//
// It returns the Builder, so calls can be chained.
//
// Consider using WithShardy for distributed computation instead: other forms of distributed
// (collective) computation across devices are not tested and may not work.
func (b *Builder) WithNumReplicas(n int) *Builder {
	b.numReplicas = n
	return b
}

// WithNumPartitions sets the number of partitions (for model parallelism).
// This is added as an attribute to the StableHLO module.
//
// It returns the Builder, so calls can be chained.
//
// Consider using WithShardy for distributed computation instead: other forms of distributed
// (collective) computation across devices are not tested and may not work.
func (b *Builder) WithNumPartitions(n int) *Builder {
	b.numPartitions = n
	return b
}

// WithShardy enables distributed computation across the devices selected by the given meshes.
//
// This is the recommended way to do distributed (across devices) computation: given inputs
// with sharding information, Shardy automatically distributes the computation, without you
// needing to specify any of the collective operations.
//
// Usually there is only one mesh, but the devices can also be split across different meshes.
//
// It returns the Builder, so calls can be chained.
//
// See details of XLA Shardy in [1]
//
// [1] https://github.com/openxla/shardy
func (b *Builder) WithShardy(meshes ...*shardy.DeviceMesh) *Builder {
	b.meshes = meshes
	// Shardy works as SPMD: a single replica, partitioned across as many
	// devices as the largest mesh requires.
	maxDevices := 0
	for _, mesh := range meshes {
		if n := mesh.NumDevices(); n > maxDevices {
			maxDevices = n
		}
	}
	b.WithNumReplicas(1)
	b.WithNumPartitions(maxDevices)
	return b
}

// Meshes returns the meshes configured with WithShardy.
// It returns nil if WithShardy was never called.
func (b *Builder) Meshes() []*shardy.DeviceMesh {
	return b.meshes
}

// NewShardingSpec creates a new ShardingSpec using the first mesh configured with WithShardy.
// It returns nil if no mesh was configured.
//
// This is a shortcut to NewShardingSpecByMeshIx(0).
func (b *Builder) NewShardingSpec() *shardy.ShardingSpec {
	if len(b.meshes) == 0 {
		return nil
	}
	return shardy.NewShardingSpec(b.meshes[0])
}

// NewShardingSpecByMeshIx creates a new ShardingSpec for the mesh at position meshIdx
// (in the order the meshes were given to WithShardy).
//
// It returns nil if meshIdx is out of range.
func (b *Builder) NewShardingSpecByMeshIx(meshIdx int) *shardy.ShardingSpec {
	if 0 <= meshIdx && meshIdx < len(b.meshes) {
		return shardy.NewShardingSpec(b.meshes[meshIdx])
	}
	return nil
}
42 changes: 35 additions & 7 deletions collective.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ func formatReplicaGroups(groups [][]int) literalStr {
// notice it's not the device numbers by the replica numbers (there is an indirection).
// Except if the config sets UseGlobalDeviceIDs, in which case they are interpreted as device
// numbers. E.g., `[[0, 1, 2, 3]]`.
// - config: Optional configuration of the channels to be used. This is not needed for SPMD programs.
// - config: Optional configuration of the channels to be used. This shouldn't be used for SPMD programs.
//
// Consider using Builder.WithShardy for distributed computation instead: other forms of distributed
// (collective) computation across devices are not tested and may not work.
func CollectiveBroadcast(operand *Value, replicaGroups [][]int, config ...*types.CollectiveConfig) (*Value, error) {
op := optypes.CollectiveBroadcast
fn := operand.fn
Expand Down Expand Up @@ -77,7 +80,9 @@ func CollectiveBroadcast(operand *Value, replicaGroups [][]int, config ...*types
stmt := fn.addOp(op, outputShape, operand)
stmt.Attributes = map[string]any{
"replica_groups": formatReplicaGroups(replicaGroups),
"channel_handle": fn.Builder.getChannelHandle(cfg),
}
if cfg != nil {
stmt.Attributes["channel_handle"] = fn.Builder.getChannelHandle(cfg)
}
return stmt.Outputs[0], nil
}
Expand All @@ -95,6 +100,9 @@ func CollectiveBroadcast(operand *Value, replicaGroups [][]int, config ...*types
// Except if the config sets UseGlobalDeviceIDs, in which case they are interpreted as device
// numbers. E.g., `[[0, 1, 2, 3]]`.
// - config: Optional configuration of the channels to be used. This is not needed for SPMD programs.
//
// Consider using Builder.WithShardy for distributed computation instead: other forms of distributed
// (collective) computation across devices are not tested and may not work.
func AllReduce(operands []*Value, replicaGroups [][]int, computation *Function, config ...*types.CollectiveConfig) (
[]*Value, error) {
op := optypes.AllReduce
Expand Down Expand Up @@ -122,7 +130,7 @@ func AllReduce(operands []*Value, replicaGroups [][]int, computation *Function,
outputShapes, err := shapeinference.AllReduce(
valuesToShapes(operands),
valuesToShapes(computation.Inputs),
computation.Outputs,
valuesToShapes(computation.Outputs),
replicaGroups)
if err != nil {
return nil, err
Expand All @@ -138,7 +146,9 @@ func AllReduce(operands []*Value, replicaGroups [][]int, computation *Function,
stmt := fn.addMultiOp(op, outputShapes, operands)
stmt.Attributes = map[string]any{
"replica_groups": formatReplicaGroups(replicaGroups),
"channel_handle": fn.Builder.getChannelHandle(cfg),
}
if cfg != nil {
stmt.Attributes["channel_handle"] = fn.Builder.getChannelHandle(cfg)
}
if cfg != nil && cfg.UseGlobalDeviceIDs {
stmt.Attributes["use_global_device_ids"] = true
Expand All @@ -153,6 +163,9 @@ func AllReduce(operands []*Value, replicaGroups [][]int, computation *Function,
// - replicaGroups: A 2D array defining the communicating device groups.
// - allGatherDim: The dimension along which to concatenate the operands.
// - config: Optional configuration of the channels to be used.
//
// Consider using Builder.WithShardy for distributed computation instead: other forms of distributed
// (collective) computation across devices are not tested and may not work.
func AllGather(operand *Value, replicaGroups [][]int, allGatherDim int, config ...*types.CollectiveConfig) (*Value, error) {
op := optypes.AllGather
fn := operand.fn
Expand All @@ -176,7 +189,9 @@ func AllGather(operand *Value, replicaGroups [][]int, allGatherDim int, config .
stmt.Attributes = map[string]any{
"replica_groups": formatReplicaGroups(replicaGroups),
"all_gather_dim": int64(allGatherDim),
"channel_handle": fn.Builder.getChannelHandle(cfg),
}
if cfg != nil {
stmt.Attributes["channel_handle"] = fn.Builder.getChannelHandle(cfg)
}
if cfg != nil && cfg.UseGlobalDeviceIDs {
stmt.Attributes["use_global_device_ids"] = true
Expand All @@ -193,6 +208,9 @@ func AllGather(operand *Value, replicaGroups [][]int, allGatherDim int, config .
// - concatDimension: The dimension along which to concatenate the received chunks.
// - splitCount: The number of chunks to split the operand into. This must match the size of the replica groups.
// - config: Optional configuration of the channels to be used.
//
// Consider using Builder.WithShardy for distributed computation instead: other forms of distributed
// (collective) computation across devices are not tested and may not work.
func AllToAll(operand *Value, replicaGroups [][]int, splitDimension, concatDimension, splitCount int, config ...*types.CollectiveConfig) (*Value, error) {
op := optypes.AllToAll
fn := operand.fn
Expand All @@ -218,7 +236,9 @@ func AllToAll(operand *Value, replicaGroups [][]int, splitDimension, concatDimen
"split_dimension": int64(splitDimension),
"concat_dimension": int64(concatDimension),
"split_count": int64(splitCount),
"channel_handle": fn.Builder.getChannelHandle(cfg),
}
if cfg != nil {
stmt.Attributes["channel_handle"] = fn.Builder.getChannelHandle(cfg)
}
if cfg != nil && cfg.UseGlobalDeviceIDs {
stmt.Attributes["use_global_device_ids"] = true
Expand Down Expand Up @@ -251,6 +271,9 @@ func formatSourceTargetPairs(pairs [][2]int) literalStr {
// - operand: The tensor from the *local* replica.
// - sourceTargetPairs: A 2D array where each inner array is a `[source, target]` pair of replica IDs.
// - config: Optional configuration of the channels to be used.
//
// Consider using Builder.WithShardy for distributed computation instead: other forms of distributed
// (collective) computation across devices are not tested and may not work.
func CollectivePermute(operand *Value, sourceTargetPairs [][2]int, config ...*types.CollectiveConfig) (*Value, error) {
op := optypes.CollectivePermute
fn := operand.fn
Expand All @@ -273,7 +296,12 @@ func CollectivePermute(operand *Value, sourceTargetPairs [][2]int, config ...*ty
stmt := fn.addOp(op, outputShape, operand)
stmt.Attributes = map[string]any{
"source_target_pairs": formatSourceTargetPairs(sourceTargetPairs),
"channel_handle": fn.Builder.getChannelHandle(cfg),
}
if cfg != nil {
stmt.Attributes["channel_handle"] = fn.Builder.getChannelHandle(cfg)
}
if cfg != nil && cfg.UseGlobalDeviceIDs {
stmt.Attributes["use_global_device_ids"] = true
}
return stmt.Outputs[0], nil
}
12 changes: 10 additions & 2 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
# v0.2.0: ...
# v0.2.0: (Release Candidate) Adding support for XLA Shardy

- `Function.Input` and `Function.NamedInput`: (change in API) they now may return an error, if the name is duplicate.
- `AllReduce` now supports arbitrary number of inputs, to be reduced at once.
- Collective ops:
- `AllReduce` now supports arbitrary number of inputs, to be reduced at once.
- `channel_handle` attribute is no longer generated by default: it is not compatible with SPMD, and will probably
only be useful once `Send`/`Recv` are supported.
- Added XLA Shardy support:
- Added `shardy.DeviceMesh` and `shardy.ShardingSpec` types.
- Added `Builder.WithShardy(mesh)`
- Added `Function.NamedInputWithShardingAndAttributes()`
- Added `Function.ReturnWithShardingAndAttributes()`

# v0.1.0: 2025/11/06 Multi-Device support

Expand Down
Loading