From 4fc293ddc4925b43a38066bc2a8ec82a4d800d2c Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Tue, 18 Nov 2025 09:41:54 +0100 Subject: [PATCH 01/30] Added shardy types for DeviceMesh and ShardSpec. --- types/shardy/devicemesh.go | 271 ++++++++++++++++++ types/shardy/devicemesh_test.go | 468 ++++++++++++++++++++++++++++++++ types/shardy/shardspec.go | 117 ++++++++ 3 files changed, 856 insertions(+) create mode 100644 types/shardy/devicemesh.go create mode 100644 types/shardy/devicemesh_test.go create mode 100644 types/shardy/shardspec.go diff --git a/types/shardy/devicemesh.go b/types/shardy/devicemesh.go new file mode 100644 index 0000000..7b1a53b --- /dev/null +++ b/types/shardy/devicemesh.go @@ -0,0 +1,271 @@ +// Package shardy provides the types needed to define a distributed computation topology. +// This is used for leveraging XLA Shardy [1] +// +// [1] https://github.com/openxla/shardy +package shardy + +import ( + "fmt" + "slices" + "strings" + + "github.com/gomlx/stablehlo/internal/utils" + "github.com/pkg/errors" +) + +// DeviceMesh defines the logical topology of a set of devices on a backend. +type DeviceMesh struct { + name string + + // axesNames are the names of the mesh axes. + axesNames []string + + // shape defines the number of devices along each mesh axis. + shape []int + + // nameToAxis maps axis names to their index. + nameToAxis map[string]int + + // numDevices is the total number of devices in the mesh. + numDevices int + + // deviceAssignment is the list of devices numbers in the mesh, in the order they appear in the mesh. + deviceAssignment []int + + // physicalDeviceMapping is the mapping of concrete devices to the flat index in the mesh. + physicalDeviceMapping map[int]int +} + +// NewDeviceMesh creates a new logical topology of a set of devices. +// +// - shape: defines the number of devices along each mesh axis, one value per axis. +// - axesNames: the names of the mesh axes. One value per axis. 
+// +// The default mapping of concrete devices numbers to the mesh is sequential, starting from 0, but it can be +// changed with the DeviceMesh.SetDeviceAssignment() method. +// +// For non-symmetric devices, where connection speed among the devices matter, a custom mapping can be provided +// with the DeviceMesh.WithDeviceMapping() method. +func NewDeviceMesh(name string, shape []int, axisNames []string) (*DeviceMesh, error) { + if len(shape) != len(axisNames) { + return nil, errors.Errorf("shape and axesNames must have the same length, got %d and %d", + len(shape), len(axisNames)) + } + if len(shape) == 0 { + return nil, errors.New("DeviceMesh shape cannot be empty") + } + + numDevices := 1 + nameToAxis := make(map[string]int, len(shape)) + for i, name := range axisNames { + if name == "" { + return nil, errors.Errorf("DeviceMesh axis name at index %d cannot be empty", i) + } + if _, found := nameToAxis[name]; found { + return nil, errors.Errorf("DeviceMesh axis name %q is duplicated", name) + } + nameToAxis[name] = i + numDevices *= shape[i] + } + + m := &DeviceMesh{ + name: name, + axesNames: axisNames, + shape: shape, + nameToAxis: nameToAxis, + numDevices: numDevices, + deviceAssignment: make([]int, numDevices), + } + for i := range m.deviceAssignment { + m.deviceAssignment[i] = i + } + m.buildPhysicalDeviceMapping() + return m, nil +} + +func (m *DeviceMesh) Name() string { + return m.name +} + +func (m *DeviceMesh) buildPhysicalDeviceMapping() { + m.physicalDeviceMapping = make(map[int]int, m.numDevices) + for i, device := range m.deviceAssignment { + m.physicalDeviceMapping[device] = i + } +} + +// NumDevices returns the total number of devices in the mesh. +func (m *DeviceMesh) NumDevices() int { + return m.numDevices +} + +// Rank returns the number of axes in the mesh. +func (m *DeviceMesh) Rank() int { + return len(m.shape) +} + +// AxisNames returns a copy of the mesh's axis names. 
+func (m *DeviceMesh) AxisNames() []string { + return slices.Clone(m.axesNames) +} + +// Shape returns a copy of the mesh's shape. +func (m *DeviceMesh) Shape() []int { + shape := make([]int, len(m.shape)) + copy(shape, m.shape) + return shape +} + +// AxisSize returns the number of devices along the given mesh axis. +func (m *DeviceMesh) AxisSize(axisName string) (int, error) { + idx, found := m.nameToAxis[axisName] + if !found { + return 0, errors.Errorf("mesh axis %q not found", axisName) + } + return m.shape[idx], nil +} + +// String implements the fmt.Stringer interface. +func (m *DeviceMesh) String() string { + var sb strings.Builder + sb.WriteString("DeviceMesh(shape={") + for i, name := range m.axesNames { + if i > 0 { + sb.WriteString(", ") + } + _, _ = fmt.Fprintf(&sb, "%s: %d", name, m.shape[i]) + } + sb.WriteString("})") + return sb.String() +} + +// SetDeviceAssignment sets the assignment of concrete devices to the mesh. +// +// It returns an error if deviceAssignment has invalid device numbers or len(devices) != NumDevices(). +func (m *DeviceMesh) SetDeviceAssignment(devices ...int) error { + if len(devices) != m.numDevices { + return errors.Errorf("devices must have %d elements, got %d", m.numDevices, len(devices)) + } + seen := utils.MakeSet[int](m.numDevices) + for _, device := range devices { + if seen.Has(device) { + return errors.Errorf("physical device #%d is duplicated in mapping", device) + } + seen.Insert(device) + if device < 0 { + return errors.Errorf("devices must be positive, got device %d", device) + } + } + copy(m.deviceAssignment, devices) + m.buildPhysicalDeviceMapping() + if len(m.physicalDeviceMapping) != m.numDevices { + return errors.Errorf("provided devicesIn: physicalDeviceMapping has %d elements, expected %d", len(m.physicalDeviceMapping), m.numDevices) + } + return nil +} + +// DeviceAssignment returns the list of devices in the mesh, in the order they appear in the mesh. 
+func (m *DeviceMesh) DeviceAssignment() []int { + return slices.Clone(m.deviceAssignment) +} + +// DeviceToMesh return the indices (flat and per-axis) assigned to the given physicalDevice. +func (m *DeviceMesh) DeviceToMesh(physicalDevice int) (flatIdx int, axisIndices []int, err error) { + var ok bool + flatIdx, ok = m.physicalDeviceMapping[physicalDevice] + if !ok { + return 0, nil, errors.Errorf("physical device %d is not part of the mesh", physicalDevice) + } + + // Convert flat index to per-axis indices + axisIndices = make([]int, len(m.shape)) + remaining := flatIdx + for i := len(m.shape) - 1; i >= 0; i-- { + axisIndices[i] = remaining % m.shape[i] + remaining /= m.shape[i] + } + return flatIdx, axisIndices, nil +} + +// ComputeReplicaGroups returns the replica groups participating in some collective (distributed) operation given the +// axes along which the operation is performed. +// +// Each replica group (a []int) includes the device indices (from the DeviceAssignment) for the axes specified. +// The other axes will be split into different replica groups. 
+// +// Example: +// +// m := NewDeviceMesh([]int{2, 2}, []string{"batch", "data"}) +// batchGroups, _ := m.ComputeReplicaGroups([]string{"batch"}) // -> [][]int{{0, 2}, {1, 3}} +// dataGroups, _ := m.ComputeReplicaGroups([]string{"data"}) // -> [][]int{{0, 1}, {2, 3}} +// globalGroups, _ := m.ComputeReplicaGroups([]string{"batch", "data"}) // -> [][]int{{0, 1, 2, 3}} +func (m *DeviceMesh) ComputeReplicaGroups(axes []string) ([][]int, error) { + // Find indices of the specified axes + axisIndices := make([]int, 0, len(axes)) + axisSet := utils.MakeSet[int](len(axes)) + for _, axis := range axes { + if idx, found := m.nameToAxis[axis]; found { + if axisSet.Has(idx) { + return nil, errors.Errorf("axis %q is duplicated: each axis can only appear once", axis) + } + axisIndices = append(axisIndices, idx) + axisSet.Insert(idx) + } else { + return nil, errors.Errorf("axis %q not found in mesh", axis) + } + } + + // Create indices for each axis dimension + nonAxisIndices := make([]int, 0, len(m.shape)-len(axisIndices)) + for i := range m.shape { + if !slices.Contains(axisIndices, i) { + nonAxisIndices = append(nonAxisIndices, i) + } + } + + // Calculate the size of each group and number of groups + groupSize := 1 + for _, idx := range axisIndices { + groupSize *= m.shape[idx] + } + numGroups := m.numDevices / groupSize + + // Initialize the result + groups := make([][]int, numGroups) + for i := range groups { + groups[i] = make([]int, groupSize) + } + + // Fill in the groups + for flatIdx := 0; flatIdx < m.numDevices; flatIdx++ { + // Convert flat index to per-axis indices + indices := make([]int, len(m.shape)) + remaining := flatIdx + for i := len(m.shape) - 1; i >= 0; i-- { + indices[i] = remaining % m.shape[i] + remaining /= m.shape[i] + } + + // Calculate group index from non-axis indices + groupIdx := 0 + multiplier := 1 + for i := len(nonAxisIndices) - 1; i >= 0; i-- { + axisIdx := nonAxisIndices[i] + groupIdx += indices[axisIdx] * multiplier + multiplier *= 
m.shape[axisIdx] + } + + // Calculate position within group from axis indices + posInGroup := 0 + multiplier = 1 + for i := len(axisIndices) - 1; i >= 0; i-- { + axisIdx := axisIndices[i] + posInGroup += indices[axisIdx] * multiplier + multiplier *= m.shape[axisIdx] + } + + groups[groupIdx][posInGroup] = flatIdx + } + + return groups, nil +} diff --git a/types/shardy/devicemesh_test.go b/types/shardy/devicemesh_test.go new file mode 100644 index 0000000..e09be45 --- /dev/null +++ b/types/shardy/devicemesh_test.go @@ -0,0 +1,468 @@ +package shardy_test + +import ( + "testing" + + "github.com/gomlx/stablehlo/types/shardy" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDeviceMesh(t *testing.T) { + t.Run("NewDeviceMesh_Valid", func(t *testing.T) { + tests := []struct { + name string + shape []int + axisNames []string + wantRank int + wantNum int + }{ + { + name: "1D mesh", + shape: []int{8}, + axisNames: []string{"replica"}, + wantRank: 1, + wantNum: 8, + }, + { + name: "2D mesh", + shape: []int{2, 4}, + axisNames: []string{"x", "y"}, + wantRank: 2, + wantNum: 8, + }, + { + name: "3D mesh", + shape: []int{2, 2, 2}, + axisNames: []string{"x", "y", "z"}, + wantRank: 3, + wantNum: 8, + }, + { + name: "single device", + shape: []int{1}, + axisNames: []string{"replica"}, + wantRank: 1, + wantNum: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", tt.shape, tt.axisNames) + require.NoError(t, err) + assert.NotNil(t, mesh) + assert.Equal(t, tt.wantRank, mesh.Rank()) + assert.Equal(t, tt.wantNum, mesh.NumDevices()) + }) + } + }) + + t.Run("NewDeviceMesh_Errors", func(t *testing.T) { + tests := []struct { + name string + shape []int + axisNames []string + wantErr string + }{ + { + name: "mismatched lengths", + shape: []int{2, 4}, + axisNames: []string{"x"}, + wantErr: "shape and axesNames must have the same length", + }, + { + name: "empty shape", + shape: 
[]int{}, + axisNames: []string{}, + wantErr: "DeviceMesh shape cannot be empty", + }, + { + name: "empty axis name", + shape: []int{4}, + axisNames: []string{""}, + wantErr: "axis name at index 0 cannot be empty", + }, + { + name: "duplicate axis names", + shape: []int{2, 4}, + axisNames: []string{"x", "x"}, + wantErr: "axis name \"x\" is duplicated", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", tt.shape, tt.axisNames) + require.Error(t, err) + assert.Nil(t, mesh) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } + }) + + t.Run("AxisNames", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 4}, []string{"x", "y"}) + require.NoError(t, err) + + axisNames := mesh.AxisNames() + assert.Equal(t, []string{"x", "y"}, axisNames) + + // Verify it returns a copy + axisNames[0] = "modified" + assert.Equal(t, []string{"x", "y"}, mesh.AxisNames()) + }) + + t.Run("Shape", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 4}, []string{"x", "y"}) + require.NoError(t, err) + + shape := mesh.Shape() + assert.Equal(t, []int{2, 4}, shape) + + // Verify it returns a copy + shape[0] = 99 + assert.Equal(t, []int{2, 4}, mesh.Shape()) + }) + + t.Run("AxisSize", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 4}, []string{"x", "y"}) + require.NoError(t, err) + + tests := []struct { + name string + axisName string + wantSize int + wantErr bool + }{ + { + name: "valid axis x", + axisName: "x", + wantSize: 2, + wantErr: false, + }, + { + name: "valid axis y", + axisName: "y", + wantSize: 4, + wantErr: false, + }, + { + name: "non-existent axis", + axisName: "z", + wantSize: 0, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + size, err := mesh.AxisSize(tt.axisName) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + } else { + require.NoError(t, err) + 
assert.Equal(t, tt.wantSize, size) + } + }) + } + }) + + t.Run("String", func(t *testing.T) { + tests := []struct { + name string + shape []int + axisNames []string + want string + }{ + { + name: "1D mesh", + shape: []int{8}, + axisNames: []string{"replica"}, + want: "DeviceMesh(shape={replica: 8})", + }, + { + name: "2D mesh", + shape: []int{2, 4}, + axisNames: []string{"x", "y"}, + want: "DeviceMesh(shape={x: 2, y: 4})", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", tt.shape, tt.axisNames) + require.NoError(t, err) + assert.Equal(t, tt.want, mesh.String()) + }) + } + }) + + t.Run("SetDeviceAssignment_Valid", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) + require.NoError(t, err) + + tests := []struct { + name string + devices []int + }{ + { + name: "sequential mapping", + devices: []int{0, 1, 2, 3}, + }, + { + name: "reverse mapping", + devices: []int{3, 2, 1, 0}, + }, + { + name: "custom mapping", + devices: []int{2, 5, 1, 7}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := mesh.SetDeviceAssignment(tt.devices...) 
+ require.NoError(t, err) + + // Verify mapping is applied correctly + for i, device := range tt.devices { + flatIdx, axisIndices, err := mesh.DeviceToMesh(device) + require.NoError(t, err) + assert.Equal(t, i, flatIdx) + assert.Equal(t, []int{i}, axisIndices) + } + }) + } + }) + + t.Run("SetDeviceAssignment_Errors", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) + require.NoError(t, err) + + tests := []struct { + name string + devices []int + wantErr string + }{ + { + name: "wrong number of devices", + devices: []int{0, 1, 2}, + wantErr: "devices must have 4 elements", + }, + { + name: "duplicate device", + devices: []int{0, 1, 1, 3}, + wantErr: "physical device #1 is duplicated", + }, + { + name: "device out of range (negative)", + devices: []int{0, 1, -1, 3}, + wantErr: "devices must be positive", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := mesh.SetDeviceAssignment(tt.devices...) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } + }) + + t.Run("DeviceToMesh_1D", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) + require.NoError(t, err) + + for i := 0; i < 4; i++ { + flatIdx, axisIndices, err := mesh.DeviceToMesh(int(i)) + require.NoError(t, err) + assert.Equal(t, i, flatIdx) + assert.Equal(t, []int{i}, axisIndices) + } + }) + + t.Run("DeviceToMesh_2D", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 4}, []string{"x", "y"}) + require.NoError(t, err) + + tests := []struct { + device int + wantFlat int + wantIndices []int + }{ + {device: 0, wantFlat: 0, wantIndices: []int{0, 0}}, + {device: 1, wantFlat: 1, wantIndices: []int{0, 1}}, + {device: 2, wantFlat: 2, wantIndices: []int{0, 2}}, + {device: 3, wantFlat: 3, wantIndices: []int{0, 3}}, + {device: 4, wantFlat: 4, wantIndices: []int{1, 0}}, + {device: 5, wantFlat: 5, wantIndices: []int{1, 1}}, + {device: 6, wantFlat: 6, 
wantIndices: []int{1, 2}}, + {device: 7, wantFlat: 7, wantIndices: []int{1, 3}}, + } + + for _, tt := range tests { + t.Run(string(rune(tt.device)), func(t *testing.T) { + flatIdx, axisIndices, err := mesh.DeviceToMesh(tt.device) + require.NoError(t, err) + assert.Equal(t, tt.wantFlat, flatIdx) + assert.Equal(t, tt.wantIndices, axisIndices) + }) + } + }) + + t.Run("DeviceToMesh_3D", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2, 2}, []string{"x", "y", "z"}) + require.NoError(t, err) + + tests := []struct { + device int + wantFlat int + wantIndices []int + }{ + {device: 0, wantFlat: 0, wantIndices: []int{0, 0, 0}}, + {device: 1, wantFlat: 1, wantIndices: []int{0, 0, 1}}, + {device: 2, wantFlat: 2, wantIndices: []int{0, 1, 0}}, + {device: 3, wantFlat: 3, wantIndices: []int{0, 1, 1}}, + {device: 4, wantFlat: 4, wantIndices: []int{1, 0, 0}}, + {device: 5, wantFlat: 5, wantIndices: []int{1, 0, 1}}, + {device: 6, wantFlat: 6, wantIndices: []int{1, 1, 0}}, + {device: 7, wantFlat: 7, wantIndices: []int{1, 1, 1}}, + } + + for _, tt := range tests { + t.Run(string(rune(tt.device)), func(t *testing.T) { + flatIdx, axisIndices, err := mesh.DeviceToMesh(tt.device) + require.NoError(t, err) + assert.Equal(t, tt.wantFlat, flatIdx) + assert.Equal(t, tt.wantIndices, axisIndices) + }) + } + }) + + t.Run("DeviceToMesh_WithCustomMapping", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) + require.NoError(t, err) + + // Set custom mapping: devices [7, 5, 3, 1] + err = mesh.SetDeviceAssignment(7, 5, 3, 1) + require.NoError(t, err) + + tests := []struct { + device int + wantFlat int + wantIndices []int + }{ + {device: 7, wantFlat: 0, wantIndices: []int{0}}, + {device: 5, wantFlat: 1, wantIndices: []int{1}}, + {device: 3, wantFlat: 2, wantIndices: []int{2}}, + {device: 1, wantFlat: 3, wantIndices: []int{3}}, + } + + for _, tt := range tests { + t.Run(string(rune(tt.device)), func(t *testing.T) { + flatIdx, 
axisIndices, err := mesh.DeviceToMesh(tt.device) + require.NoError(t, err) + assert.Equal(t, tt.wantFlat, flatIdx) + assert.Equal(t, tt.wantIndices, axisIndices) + }) + } + + // Devices not in the mesh should error + _, _, err = mesh.DeviceToMesh(0) + require.Error(t, err) + assert.Contains(t, err.Error(), "not part of the mesh") + }) + + t.Run("DeviceToMesh_NotInMesh", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) + require.NoError(t, err) + + // Device 5 is not in the mesh (only 0-3 are used) + _, _, err = mesh.DeviceToMesh(5) + require.Error(t, err) + assert.Contains(t, err.Error(), "physical device 5 is not part of the mesh") + }) + + t.Run("ComputeReplicaGroups", func(t *testing.T) { + t.Run("2D mesh batch groups", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2}, []string{"batch", "data"}) + require.NoError(t, err) + + // Example from comments: m.ComputeReplicaGroups([]string{"batch"}) -> [][]int{{0, 2}, {1, 3}} + groups, err := mesh.ComputeReplicaGroups([]string{"batch"}) + require.NoError(t, err) + assert.Equal(t, [][]int{{0, 2}, {1, 3}}, groups) + }) + + t.Run("2D mesh data groups", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2}, []string{"batch", "data"}) + require.NoError(t, err) + + // Example from comments: m.ComputeReplicaGroups([]string{"data"}) -> [][]int{{0, 1}, {2, 3}} + groups, err := mesh.ComputeReplicaGroups([]string{"data"}) + require.NoError(t, err) + assert.Equal(t, [][]int{{0, 1}, {2, 3}}, groups) + }) + + t.Run("2D mesh global groups", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2}, []string{"batch", "data"}) + require.NoError(t, err) + + // Example from comments: m.ComputeReplicaGroups([]string{"batch", "data"}) -> [][]int{{0, 1, 2, 3}} + groups, err := mesh.ComputeReplicaGroups([]string{"batch", "data"}) + require.NoError(t, err) + assert.Equal(t, [][]int{{0, 1, 2, 3}}, groups) + }) + + t.Run("1D 
mesh", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) + require.NoError(t, err) + + groups, err := mesh.ComputeReplicaGroups([]string{"replica"}) + require.NoError(t, err) + assert.Equal(t, [][]int{{0, 1, 2, 3}}, groups) + }) + + t.Run("3D mesh single axis", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2, 2}, []string{"x", "y", "z"}) + require.NoError(t, err) + + // Groups along x axis: should split by y and z + groups, err := mesh.ComputeReplicaGroups([]string{"x"}) + require.NoError(t, err) + assert.Equal(t, [][]int{{0, 4}, {1, 5}, {2, 6}, {3, 7}}, groups) + }) + + t.Run("3D mesh two axes", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2, 2}, []string{"x", "y", "z"}) + require.NoError(t, err) + + // Groups along x and y axes: should split by z + groups, err := mesh.ComputeReplicaGroups([]string{"x", "y"}) + require.NoError(t, err) + assert.Equal(t, [][]int{{0, 2, 4, 6}, {1, 3, 5, 7}}, groups) + }) + + t.Run("empty axes list", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2}, []string{"batch", "data"}) + require.NoError(t, err) + + // Empty axes list: each device is its own group + groups, err := mesh.ComputeReplicaGroups([]string{}) + require.NoError(t, err) + assert.Equal(t, [][]int{{0}, {1}, {2}, {3}}, groups) + }) + + t.Run("non-existent axis", func(t *testing.T) { + mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2}, []string{"batch", "data"}) + require.NoError(t, err) + + // A non-existent axis should return an error. 
+ _, err = mesh.ComputeReplicaGroups([]string{"nonexistent"}) + require.Error(t, err) + }) + }) +} diff --git a/types/shardy/shardspec.go b/types/shardy/shardspec.go new file mode 100644 index 0000000..1ae7cc6 --- /dev/null +++ b/types/shardy/shardspec.go @@ -0,0 +1,117 @@ +package shardy + +import ( + "github.com/pkg/errors" +) + +// ShardSpec (also known as PartitionSpec in JAX) defines how a logical tensor is to be sharded (partitioned) across +// a DeviceMesh. This is used by Shardy, and is based on its documentation in [1]. +// +// The definition is per axis of the logical tensor -- and not per axis of the Mesh, common confusion. +// If not all axes of the Tensor are defined, the tail axes are considered simply to be replicated across the whole +// mesh. +// +// Each tensor axis can be replicated or sharded across one or more mesh axes. +// +// Example: +// +// mesh := NewDeviceMesh("my_mesh", []int{2, 2}, []string{"data", "model"}) +// +// // Input's "batch" axis is sharded across the "data" axis of the mesh. +// inputSharding := MakeShardSpec(mesh.Name()).AddShardedAxis("data") +// +// // First axis is replicated, second is shared across "model" devices +// variableSharding := MakeShardSpec(mesh.Name()).AddReplicated().AddShardedAxis("model") +// +// // Second axis is sharded across both "data" and "model" devices. +// largeWeights := MakeShardSpec(mesh.Name()).AddReplicated().AddShardedAxis("data", "model") +// +// There are two advanced features supported but not tested (pls if you need let us know how it goes, or if you find +// any issues): +// +// 1. The tensor can also be sharded across mesh "sub-axes" -- seed detailed documentation in [1] +// 2. If using ShardSpec for hints, instead of mesh axes one can give an "open" (in StableHLO marked as "?") +// axis, with the semantics that XLA Shardy can choose any mesh axis (or axes) to shard the tensor. See [1]. 
+// +// [1] https://github.com/openxla/shardy/blob/main/docs/sharding_representation.md +type ShardSpec struct { + Mesh *DeviceMesh + Axes []TensorAxisSpec +} + +// TensorAxisSpec specifies how a tensor axis is to be sharded (or replicated). +// See details in ShardSpec. +// +// Usually, one would create this using ShardSpec.AddAxis or ShardSpec.AddReplicated +type TensorAxisSpec struct { + MeshAxes []MeshAxisSpec + Opened bool // If opened to further sharding. +} + +type MeshAxisSpec struct { + AxisName string + + // PreSize, Size are only set if defining a sub-axis of the mesh. + PreSize, Size int +} + +// NewShardSpec creates a new ShardSpec. +func NewShardSpec(mesh *DeviceMesh) *ShardSpec { + return &ShardSpec{mesh, make([]TensorAxisSpec, 0)} +} + +// AddShardedAxis adds a new sharded axis to the ShardSpec using one or more mesh axes. +// +// It returns itself, so calls can be chained. +func (s *ShardSpec) AddShardedAxis(meshAxisName string, moreMeshAxesNames ...string) *ShardSpec { + axisSpec := TensorAxisSpec{MeshAxes: []MeshAxisSpec{{AxisName: meshAxisName}}} + for _, meshAxisName := range moreMeshAxesNames { + axisSpec.MeshAxes = append(axisSpec.MeshAxes, MeshAxisSpec{AxisName: meshAxisName}) + } + s.Axes = append(s.Axes, axisSpec) + return s +} + +// AddReplicated adds a new replicated axis to the ShardSpec. +// +// It returns itself, so calls can be chained. +func (s *ShardSpec) AddReplicated() *ShardSpec { + s.Axes = append(s.Axes, TensorAxisSpec{}) + return s +} + +// Rank returns the number of axes this ShardSpec describes. +// +// Notice this may be smaller than the rank of the tensor using it: if a tensor axis is not defined in ShardSpec, +// it is assumed to be replicated. +func (s *ShardSpec) Rank() int { + return len(s.Axes) +} + +// IsReplicated returns true if the tensor is fully replicated +// (i.e., not sharded along any axis and not marked as "open"). 
+func (s *ShardSpec) IsReplicated() bool { + for _, axisSpec := range s.Axes { + if axisSpec.MeshAxes != nil || axisSpec.Opened { + return false + } + } + return true +} + +// Validate checks that the ShardSpec is valid for the given mesh. +func (s *ShardSpec) Validate() error { + for i, axisSpec := range s.Axes { + for _, meshAxisSpec := range axisSpec.MeshAxes { + axisName := meshAxisSpec.AxisName + if axisName == "" { + return errors.Errorf("ShardSpec axis %d refers to empty mesh axis name", i) + } + if _, ok := s.Mesh.nameToAxis[axisName]; !ok { + return errors.Errorf("ShardSpec axis #%d refers to unknown mesh axis %q", + i, axisName) + } + } + } + return nil +} From 574c673f50170e254563551a8839f94dd587d7b6 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 18 Nov 2025 09:34:00 +0000 Subject: [PATCH 02/30] feat(shardy): Implement ShardSpec.ToStableHLO and improve validation This change introduces two main improvements to the `shardy` package: 1. **`ShardSpec.ToStableHLO()` implementation:** A new method, `ToStableHLO()`, has been added to the `ShardSpec` struct. This method converts the sharding specification into the string format required by StableHLO, as detailed in the OpenXLA Shardy documentation. It correctly handles sharded, replicated, open, and sub-axis dimensions, and it ensures the output is deterministic by sorting the replicated axis names. 2. **Enhanced `ShardSpec.Validate()`:** The `Validate()` method has been updated to include more robust checks for sharding specifications that use sub-axes. It now verifies that the `PreSize` and `Size` of a sub-axis are mathematically compatible with the total size of the corresponding mesh axis, preventing invalid sharding configurations. 
A comprehensive suite of unit tests has been added in `shardspec_test.go` to validate both the new `ToStableHLO` method and the enhanced validation logic, ensuring the correctness and reliability of the implementation. --- types/shardy/shardspec.go | 67 +++++++++++++++++-- types/shardy/shardspec_test.go | 116 +++++++++++++++++++++++++++++++++ 2 files changed, 178 insertions(+), 5 deletions(-) create mode 100644 types/shardy/shardspec_test.go diff --git a/types/shardy/shardspec.go b/types/shardy/shardspec.go index 1ae7cc6..ecff1ac 100644 --- a/types/shardy/shardspec.go +++ b/types/shardy/shardspec.go @@ -1,6 +1,10 @@ package shardy import ( + "fmt" + "sort" + "strings" + "github.com/pkg/errors" ) @@ -102,16 +106,69 @@ func (s *ShardSpec) IsReplicated() bool { // Validate checks that the ShardSpec is valid for the given mesh. func (s *ShardSpec) Validate() error { for i, axisSpec := range s.Axes { - for _, meshAxisSpec := range axisSpec.MeshAxes { + for j, meshAxisSpec := range axisSpec.MeshAxes { axisName := meshAxisSpec.AxisName if axisName == "" { - return errors.Errorf("ShardSpec axis %d refers to empty mesh axis name", i) + return errors.Errorf("ShardSpec tensor axis %d, mesh axis #%d refers to empty mesh axis name", i, j) } - if _, ok := s.Mesh.nameToAxis[axisName]; !ok { - return errors.Errorf("ShardSpec axis #%d refers to unknown mesh axis %q", - i, axisName) + axisIdx, ok := s.Mesh.nameToAxis[axisName] + if !ok { + return errors.Errorf("ShardSpec tensor axis %d, mesh axis #%d refers to unknown mesh axis %q", + i, j, axisName) + } + meshAxisSize := s.Mesh.shape[axisIdx] + + // Check sub-axis specification. 
+ if meshAxisSpec.Size > 0 { + if meshAxisSpec.PreSize <= 0 { + return errors.Errorf("ShardSpec tensor axis %d, mesh axis #%d %q has invalid PreSize %d", + i, j, axisName, meshAxisSpec.PreSize) + } + if meshAxisSize%(meshAxisSpec.PreSize*meshAxisSpec.Size) != 0 { + return errors.Errorf("ShardSpec tensor axis %d, mesh axis #%d %q with PreSize %d and Size %d is not compatible with mesh axis of size %d", + i, j, axisName, meshAxisSpec.PreSize, meshAxisSpec.Size, meshAxisSize) + } } } } return nil } + +// ToStableHLO converts the ShardSpec to its StableHLO string representation. +// See details in: +// https://github.com/openxla/shardy/blob/main/docs/sharding_representation.md +func (s *ShardSpec) ToStableHLO() string { + var dimShardings []string + replicatedAxes := make(map[string]bool) + for _, axisName := range s.Mesh.axesNames { + replicatedAxes[axisName] = true + } + + for _, axisSpec := range s.Axes { + var hloAxes []string + for _, meshAxisSpec := range axisSpec.MeshAxes { + delete(replicatedAxes, meshAxisSpec.AxisName) + if meshAxisSpec.Size > 0 { + hloAxes = append(hloAxes, fmt.Sprintf("%s:(%d)%d", meshAxisSpec.AxisName, meshAxisSpec.PreSize, meshAxisSpec.Size)) + } else { + hloAxes = append(hloAxes, meshAxisSpec.AxisName) + } + } + if axisSpec.Opened { + hloAxes = append(hloAxes, "?") + } + dimShardings = append(dimShardings, fmt.Sprintf("{%s}", strings.Join(hloAxes, ", "))) + } + + var replicatedStrs []string + for axisName := range replicatedAxes { + replicatedStrs = append(replicatedStrs, axisName) + } + sort.Strings(replicatedStrs) + + replicatedPart := "" + if len(replicatedStrs) > 0 { + replicatedPart = fmt.Sprintf(", replicated={%s}", strings.Join(replicatedStrs, ", ")) + } + return fmt.Sprintf("sharding<@%s, [%s]%s>", s.Mesh.Name(), strings.Join(dimShardings, ", "), replicatedPart) +} diff --git a/types/shardy/shardspec_test.go b/types/shardy/shardspec_test.go new file mode 100644 index 0000000..df81791 --- /dev/null +++ 
b/types/shardy/shardspec_test.go @@ -0,0 +1,116 @@ +package shardy + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestShardSpec_ToStableHLO(t *testing.T) { + mesh, err := NewDeviceMesh("test_mesh", []int{4, 2}, []string{"z", "a"}) + require.NoError(t, err) + testCases := []struct { + name string + spec *ShardSpec + expected string + }{ + { + name: "Replicated", + spec: NewShardSpec(mesh).AddReplicated(), + expected: "sharding<@test_mesh, [{}], replicated={a, z}>", + }, + { + name: "Sharded", + spec: NewShardSpec(mesh).AddShardedAxis("z"), + expected: "sharding<@test_mesh, [{z}], replicated={a}>", + }, + { + name: "Sharded with multiple axes", + spec: NewShardSpec(mesh).AddShardedAxis("z", "a"), + expected: "sharding<@test_mesh, [{z, a}]>", + }, + { + name: "Sharded with sub-axis", + spec: &ShardSpec{ + Mesh: mesh, + Axes: []TensorAxisSpec{ + {MeshAxes: []MeshAxisSpec{{AxisName: "a", PreSize: 1, Size: 2}}}, + }, + }, + expected: "sharding<@test_mesh, [{a:(1)2}], replicated={z}>", + }, + { + name: "Opened", + spec: &ShardSpec{Mesh: mesh, Axes: []TensorAxisSpec{{Opened: true}}}, + expected: "sharding<@test_mesh, [{?}], replicated={a, z}>", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.expected, tc.spec.ToStableHLO()) + }) + } +} + +func TestShardSpec_Validate(t *testing.T) { + mesh, err := NewDeviceMesh("test_mesh", []int{2, 8}, []string{"z", "a"}) + require.NoError(t, err) + testCases := []struct { + name string + spec *ShardSpec + expectError bool + }{ + { + name: "Valid sharding", + spec: NewShardSpec(mesh).AddShardedAxis("z"), + expectError: false, + }, + { + name: "Unknown mesh axis", + spec: NewShardSpec(mesh).AddShardedAxis("x"), + expectError: true, + }, + { + name: "Valid sub-axis", + spec: &ShardSpec{ + Mesh: mesh, + Axes: []TensorAxisSpec{ + {MeshAxes: []MeshAxisSpec{{AxisName: "a", PreSize: 2, Size: 4}}}, + }, + }, + expectError: false, + }, + { + name: "Invalid 
sub-axis (PreSize)", + spec: &ShardSpec{ + Mesh: mesh, + Axes: []TensorAxisSpec{ + {MeshAxes: []MeshAxisSpec{{AxisName: "a", PreSize: 0, Size: 4}}}, + }, + }, + expectError: true, + }, + { + name: "Invalid sub-axis (Size)", + spec: &ShardSpec{ + Mesh: mesh, + Axes: []TensorAxisSpec{ + {MeshAxes: []MeshAxisSpec{{AxisName: "a", PreSize: 2, Size: 5}}}, + }, + }, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.spec.Validate() + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} From 68f91ef94e28fc61d48c0d24251f36a01897df12 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Wed, 19 Nov 2025 08:00:11 +0100 Subject: [PATCH 03/30] Moved NormalizeIdentifier to utils to break dependency cycle. Added an alias in stablehlo in case end-users want to use it. --- internal/utils/utils.go | 24 ++++++++++++++++++++++++ stablehlo.go | 11 +++++++++++ 2 files changed, 35 insertions(+) create mode 100644 internal/utils/utils.go diff --git a/internal/utils/utils.go b/internal/utils/utils.go new file mode 100644 index 0000000..c56831d --- /dev/null +++ b/internal/utils/utils.go @@ -0,0 +1,24 @@ +package utils + +// NormalizeIdentifier converts the name of an identifier (function name or function input parameter +// name) to a valid one: only letters, digits, and underscores are allowed. +// +// Invalid characters are replaced with underscores. +// If the name starts with a digit, it is prefixed with an underscore. 
+func NormalizeIdentifier(name string) string { + if name == "" { + return "" + } + result := make([]rune, 0, len(name)+1) + if name[0] >= '0' && name[0] <= '9' { + result = append(result, '_') + } + for _, r := range name { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' { + result = append(result, r) + } else { + result = append(result, '_') + } + } + return string(result) +} diff --git a/stablehlo.go b/stablehlo.go index 2e47c0b..33f6d2f 100644 --- a/stablehlo.go +++ b/stablehlo.go @@ -13,5 +13,16 @@ // See ToStableHLO documentation and specifications in https://openxla.org/stablehlo/spec package stablehlo +import "github.com/gomlx/stablehlo/internal/utils" + // Generates some trivial functions (binary and unary operators) automatically. //go:generate go run ./internal/cmd/ops_generator + +// NormalizeIdentifier converts the name of an identifier (function name or function input parameter +// name, etc.) to a valid one: only letters, digits, and underscores are allowed. +// +// Invalid characters are replaced with underscores. +// If the name starts with a digit, it is prefixed with an underscore. +func NormalizeIdentifier(name string) string { + return utils.NormalizeIdentifier(name) +} From 0802096266f2e0dc9b89e2323ab5657e298e80f1 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Wed, 19 Nov 2025 08:00:45 +0100 Subject: [PATCH 04/30] Added checks that mesh and mesh axes names are valid StableHLO identifiers. Added DeviceMesh.ToStableHLO(). --- types/shardy/devicemesh.go | 41 ++++++++++++++++++++++-- types/shardy/devicemesh_test.go | 56 ++++++++++++++++++--------------- 2 files changed, 69 insertions(+), 28 deletions(-) diff --git a/types/shardy/devicemesh.go b/types/shardy/devicemesh.go index 7b1a53b..96660ad 100644 --- a/types/shardy/devicemesh.go +++ b/types/shardy/devicemesh.go @@ -38,13 +38,15 @@ type DeviceMesh struct { // NewDeviceMesh creates a new logical topology of a set of devices. 
// -// - shape: defines the number of devices along each mesh axis, one value per axis. -// - axesNames: the names of the mesh axes. One value per axis. +// - name: the name of the mesh, it must be a valid StableHLO identifier (see stablehlo.NormalizeIdentifier). +// - shape: defines the number of devices along each mesh axis, one value per axis. +// - axesNames: the names of the mesh axes. One value per axis. They must also be valid StableHLO identifiers +// (see stablehlo.NormalizeName). // // The default mapping of concrete devices numbers to the mesh is sequential, starting from 0, but it can be // changed with the DeviceMesh.SetDeviceAssignment() method. // -// For non-symmetric devices, where connection speed among the devices matter, a custom mapping can be provided +// For non-symmetric devices, where the connection speed among the devices matters, a custom mapping can be provided // with the DeviceMesh.WithDeviceMapping() method. func NewDeviceMesh(name string, shape []int, axisNames []string) (*DeviceMesh, error) { if len(shape) != len(axisNames) { @@ -55,6 +57,21 @@ func NewDeviceMesh(name string, shape []int, axisNames []string) (*DeviceMesh, e return nil, errors.New("DeviceMesh shape cannot be empty") } + // Normalize names: + if name != utils.NormalizeIdentifier(name) { + return nil, errors.Errorf( + "DeviceMesh name %q is not a valid StableHLO identifier, suggestion %q -- or use "+ + "stablehlo.NormalizeIdentifier()", name, utils.NormalizeIdentifier(name)) + } + axisNames = slices.Clone(axisNames) + for i, axisName := range axisNames { + if axisNames[i] != utils.NormalizeIdentifier(axisName) { + return nil, errors.Errorf( + "DeviceMesh axis name %q at index %d is not a valid StableHLO identifier, suggestion %q -- or use "+ + "stablehlo.NormalizeIdentifier()", axisName, i, utils.NormalizeIdentifier(axisName)) + } + } + numDevices := 1 nameToAxis := make(map[string]int, len(shape)) for i, name := range axisNames { @@ -269,3 +286,21 @@ func (m 
*DeviceMesh) ComputeReplicaGroups(axes []string) ([][]int, error) { return groups, nil } + +// ToStableHLO returns the StableHLO representation of the mesh, as it should be used in the module body. +// E.g.: sdy.mesh @mesh = <["data"=4, "model"=2]> +func (m *DeviceMesh) ToStableHLO() string { + var buf strings.Builder + w := func(format string, args ...any) { + buf.WriteString(fmt.Sprintf(format, args...)) + } + w("sdy.mesh @%s = <[", m.name) + for i, axisName := range m.axesNames { + if i > 0 { + w(", ") + } + w("%q=%d", axisName, m.shape[i]) + } + w("]>") + return buf.String() +} diff --git a/types/shardy/devicemesh_test.go b/types/shardy/devicemesh_test.go index e09be45..26574c3 100644 --- a/types/shardy/devicemesh_test.go +++ b/types/shardy/devicemesh_test.go @@ -11,39 +11,44 @@ import ( func TestDeviceMesh(t *testing.T) { t.Run("NewDeviceMesh_Valid", func(t *testing.T) { tests := []struct { - name string - shape []int - axisNames []string - wantRank int - wantNum int + name string + shape []int + axisNames []string + wantRank int + wantNum int + wantStableHLO string }{ { - name: "1D mesh", - shape: []int{8}, - axisNames: []string{"replica"}, - wantRank: 1, - wantNum: 8, + name: "1D mesh", + shape: []int{8}, + axisNames: []string{"replica"}, + wantRank: 1, + wantNum: 8, + wantStableHLO: `sdy.mesh @mesh = <["replica"=8]>`, }, { - name: "2D mesh", - shape: []int{2, 4}, - axisNames: []string{"x", "y"}, - wantRank: 2, - wantNum: 8, + name: "2D mesh", + shape: []int{2, 4}, + axisNames: []string{"x", "y"}, + wantRank: 2, + wantNum: 8, + wantStableHLO: `sdy.mesh @mesh = <["x"=2, "y"=4]>`, }, { - name: "3D mesh", - shape: []int{2, 2, 2}, - axisNames: []string{"x", "y", "z"}, - wantRank: 3, - wantNum: 8, + name: "3D mesh", + shape: []int{2, 2, 2}, + axisNames: []string{"x", "y", "z"}, + wantRank: 3, + wantNum: 8, + wantStableHLO: `sdy.mesh @mesh = <["x"=2, "y"=2, "z"=2]>`, }, { - name: "single device", - shape: []int{1}, - axisNames: []string{"replica"}, - wantRank: 1, 
- wantNum: 1, + name: "single device", + shape: []int{1}, + axisNames: []string{"replica"}, + wantRank: 1, + wantNum: 1, + wantStableHLO: `sdy.mesh @mesh = <["replica"=1]>`, }, } @@ -54,6 +59,7 @@ func TestDeviceMesh(t *testing.T) { assert.NotNil(t, mesh) assert.Equal(t, tt.wantRank, mesh.Rank()) assert.Equal(t, tt.wantNum, mesh.NumDevices()) + assert.Equal(t, tt.wantStableHLO, mesh.ToStableHLO()) }) } }) From 1259abb204bea631b23a03006d647efef99467ff Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Wed, 19 Nov 2025 08:01:23 +0100 Subject: [PATCH 05/30] Added Builder.WithShardy() method. Added comments discouraging using other distributed (collective) mechanisms other than Shardy. --- builder.go | 78 ++++++++++++++++++++++++++++----------------------- collective.go | 15 ++++++++++ 2 files changed, 58 insertions(+), 35 deletions(-) diff --git a/builder.go b/builder.go index faf5276..08c2d4d 100644 --- a/builder.go +++ b/builder.go @@ -7,10 +7,11 @@ import ( "slices" "github.com/gomlx/stablehlo/types" + "github.com/gomlx/stablehlo/types/shardy" "github.com/pkg/errors" ) -// Builder is used to construct a StableHLO program. +// Builder is used to construct a StableHLO program (or "Module") // See details in New. type Builder struct { name string @@ -22,38 +23,19 @@ type Builder struct { // inlineUniqueID is a counter used to generate unique names for inlined functions values. inlineUniqueID int - // NumReplicas is the number of replicas for data parallelism. - NumReplicas int - // NumPartitions is the number of partitions for model parallelism. - NumPartitions int + // Mesh used for Shardy. + mesh *shardy.DeviceMesh + + // numReplicas is the number of replicas for data parallelism. + numReplicas int + // numPartitions is the number of partitions for model parallelism. + numPartitions int // nextChannelID is the next ID to be assigned in channel handles. // It is just a Unique ID. 
nextChannelID int } -// NormalizeIdentifier converts the name of an identifier (function name or function input parameter -// name) to a valid one: only letters, digits and underscores are allowed. -// -// Invalid characters are replaced with underscores. -// If the name starts with a digit, it is prefixed with an underscore. -// -// The name is normalized in place. -func NormalizeIdentifier(name string) string { - result := make([]rune, 0, len(name)+1) - if name[0] >= '0' && name[0] <= '9' { - result = append(result, '_') - } - for _, r := range name { - if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' { - result = append(result, r) - } else { - result = append(result, '_') - } - } - return string(result) -} - // New creates a new Builder object holding a computation graph in construction. // // From a builder you can create functions. @@ -62,7 +44,7 @@ func NormalizeIdentifier(name string) string { // You have to define the "main" function for your StableHLO program: you can use Builder.Main to do so, or // Builder.NewFunction("main",...), it's the same. // -// Once you are all set, call Builder.Build and it will return the StableHLO program as a []byte that can +// Once you are all set, call Builder.Build and it will return the StableHLO program (or "Module") as a []byte that can // be used with PJRT. // // See github.com/gomlx/gopjrt for a Go API to PJRT. @@ -74,15 +56,36 @@ func New(name string) *Builder { // WithNumReplicas sets the number of replicas (for data parallelism). // This is added as an attribute to the StableHLO module. +// +// Consider using WithShardy for distributed computation instead: other forms of distributed +// (collective) computation across devices are not tested and may not work. func (b *Builder) WithNumReplicas(n int) *Builder { - b.NumReplicas = n + b.numReplicas = n return b } // WithNumPartitions sets the number of partitions (for model parallelism). 
// This is added as an attribute to the StableHLO module. +// +// Consider using WithShardy for distributed computation instead: other forms of distributed +// (collective) computation across devices are not tested and may not work. func (b *Builder) WithNumPartitions(n int) *Builder { - b.NumPartitions = n + b.numPartitions = n + return b +} + +// WithShardy enables distributed computation across the devices selected by the mesh. +// This is the recommended way to do distributed (across devices) computation, and given the inputs +// with sharded information, Shardy will automatically distribute the computation, without you needing +// to specify any of the collective operations. +// +// See details of XLA Shardy in [1] +// +// [1] https://github.com/openxla/shardy +func (b *Builder) WithShardy(mesh *shardy.DeviceMesh) *Builder { + b.WithNumReplicas(1) + b.WithNumPartitions(mesh.NumDevices()) + return b } @@ -134,11 +137,11 @@ const IndentationStep = " " // getModuleAttributes returns the attributes for the StableHLO module (StableHLO code) generated. 
func (b *Builder) getModuleAttributes() []string { var attributes []string - if b.NumReplicas > 0 { - attributes = append(attributes, fmt.Sprintf("stablehlo.num_replicas = %d", b.NumReplicas)) + if b.numReplicas > 0 { + attributes = append(attributes, fmt.Sprintf("stablehlo.num_replicas = %d", b.numReplicas)) } - if b.NumPartitions > 0 { - attributes = append(attributes, fmt.Sprintf(" stablehlo.num_partitions = %d", b.NumPartitions)) + if b.numPartitions > 0 { + attributes = append(attributes, fmt.Sprintf(" stablehlo.num_partitions = %d", b.numPartitions)) } return attributes } @@ -177,10 +180,15 @@ func (b *Builder) Write(writer io.Writer) error { } w("%s", attr) } - w(" }") + w("}") } w(" {\n") + // Write Shardy mesh if needed: + if b.mesh != nil { + w("%s%s\n", IndentationStep, b.mesh.ToStableHLO()) + } + // Write non-inline functions: var count int for _, fn := range b.functions { diff --git a/collective.go b/collective.go index 93a23c2..116eb8d 100644 --- a/collective.go +++ b/collective.go @@ -50,6 +50,9 @@ func formatReplicaGroups(groups [][]int) literalStr { // Except if the config sets UseGlobalDeviceIDs, in which case they are interpreted as device // numbers. E.g., `[[0, 1, 2, 3]]`. // - config: Optional configuration of the channels to be used. This is not needed for SPMD programs. +// +// Consider using Builder.WithShardy for distributed computation instead: other forms of distributed +// (collective) computation across devices are not tested and may not work. func CollectiveBroadcast(operand *Value, replicaGroups [][]int, config ...*types.CollectiveConfig) (*Value, error) { op := optypes.CollectiveBroadcast fn := operand.fn @@ -95,6 +98,9 @@ func CollectiveBroadcast(operand *Value, replicaGroups [][]int, config ...*types // Except if the config sets UseGlobalDeviceIDs, in which case they are interpreted as device // numbers. E.g., `[[0, 1, 2, 3]]`. // - config: Optional configuration of the channels to be used. This is not needed for SPMD programs. 
+// +// Consider using Builder.WithShardy for distributed computation instead: other forms of distributed +// (collective) computation across devices are not tested and may not work. func AllReduce(operands []*Value, replicaGroups [][]int, computation *Function, config ...*types.CollectiveConfig) ( []*Value, error) { op := optypes.AllReduce @@ -153,6 +159,9 @@ func AllReduce(operands []*Value, replicaGroups [][]int, computation *Function, // - replicaGroups: A 2D array defining the communicating device groups. // - allGatherDim: The dimension along which to concatenate the operands. // - config: Optional configuration of the channels to be used. +// +// Consider using Builder.WithShardy for distributed computation instead: other forms of distributed +// (collective) computation across devices are not tested and may not work. func AllGather(operand *Value, replicaGroups [][]int, allGatherDim int, config ...*types.CollectiveConfig) (*Value, error) { op := optypes.AllGather fn := operand.fn @@ -193,6 +202,9 @@ func AllGather(operand *Value, replicaGroups [][]int, allGatherDim int, config . // - concatDimension: The dimension along which to concatenate the received chunks. // - splitCount: The number of chunks to split the operand into. This must match the size of the replica groups. // - config: Optional configuration of the channels to be used. +// +// Consider using Builder.WithShardy for distributed computation instead: other forms of distributed +// (collective) computation across devices are not tested and may not work. func AllToAll(operand *Value, replicaGroups [][]int, splitDimension, concatDimension, splitCount int, config ...*types.CollectiveConfig) (*Value, error) { op := optypes.AllToAll fn := operand.fn @@ -251,6 +263,9 @@ func formatSourceTargetPairs(pairs [][2]int) literalStr { // - operand: The tensor from the *local* replica. // - sourceTargetPairs: A 2D array where each inner array is a `[source, target]` pair of replica IDs. 
// - config: Optional configuration of the channels to be used. +// +// Consider using Builder.WithShardy for distributed computation instead: other forms of distributed +// (collective) computation across devices are not tested and may not work. func CollectivePermute(operand *Value, sourceTargetPairs [][2]int, config ...*types.CollectiveConfig) (*Value, error) { op := optypes.CollectivePermute fn := operand.fn From 64d506a687262a94422351fc915f5ee3336092ca Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Wed, 19 Nov 2025 10:03:54 +0100 Subject: [PATCH 06/30] Renamed ShardSpec -> ShardingSpec --- .../shardy/{shardspec.go => shardingspec.go} | 66 ++++++++++++------- ...shardspec_test.go => shardingspec_test.go} | 24 +++---- 2 files changed, 53 insertions(+), 37 deletions(-) rename types/shardy/{shardspec.go => shardingspec.go} (65%) rename types/shardy/{shardspec_test.go => shardingspec_test.go} (80%) diff --git a/types/shardy/shardspec.go b/types/shardy/shardingspec.go similarity index 65% rename from types/shardy/shardspec.go rename to types/shardy/shardingspec.go index ecff1ac..be6cf2d 100644 --- a/types/shardy/shardspec.go +++ b/types/shardy/shardingspec.go @@ -5,10 +5,11 @@ import ( "sort" "strings" + "github.com/gomlx/stablehlo/types/shapes" "github.com/pkg/errors" ) -// ShardSpec (also known as PartitionSpec in JAX) defines how a logical tensor is to be sharded (partitioned) across +// ShardingSpec (also known as PartitionSpec in JAX) defines how a logical tensor is to be sharded (partitioned) across // a DeviceMesh. This is used by Shardy, and is based on its documentation in [1]. // // The definition is per axis of the logical tensor -- and not per axis of the Mesh, common confusion. @@ -34,19 +35,19 @@ import ( // any issues): // // 1. The tensor can also be sharded across mesh "sub-axes" -- seed detailed documentation in [1] -// 2. If using ShardSpec for hints, instead of mesh axes one can give an "open" (in StableHLO marked as "?") +// 2. 
If using ShardingSpec for hints, instead of mesh axes one can give an "open" (in StableHLO marked as "?") // axis, with the semantics that XLA Shardy can choose any mesh axis (or axes) to shard the tensor. See [1]. // // [1] https://github.com/openxla/shardy/blob/main/docs/sharding_representation.md -type ShardSpec struct { +type ShardingSpec struct { Mesh *DeviceMesh Axes []TensorAxisSpec } // TensorAxisSpec specifies how a tensor axis is to be sharded (or replicated). -// See details in ShardSpec. +// See details in ShardingSpec. // -// Usually, one would create this using ShardSpec.AddAxis or ShardSpec.AddReplicated +// Usually, one would create this using ShardingSpec.AddAxis or ShardingSpec.AddReplicated type TensorAxisSpec struct { MeshAxes []MeshAxisSpec Opened bool // If opened to further sharding. @@ -59,15 +60,15 @@ type MeshAxisSpec struct { PreSize, Size int } -// NewShardSpec creates a new ShardSpec. -func NewShardSpec(mesh *DeviceMesh) *ShardSpec { - return &ShardSpec{mesh, make([]TensorAxisSpec, 0)} +// NewShardingSpec creates a new ShardingSpec. +func NewShardingSpec(mesh *DeviceMesh) *ShardingSpec { + return &ShardingSpec{mesh, make([]TensorAxisSpec, 0)} } -// AddShardedAxis adds a new sharded axis to the ShardSpec using one or more mesh axes. +// AddShardedAxis adds a new sharded axis to the ShardingSpec using one or more mesh axes. // // It returns itself, so calls can be chained. 
-func (s *ShardSpec) AddShardedAxis(meshAxisName string, moreMeshAxesNames ...string) *ShardSpec { +func (s *ShardingSpec) AddShardedAxis(meshAxisName string, moreMeshAxesNames ...string) *ShardingSpec { axisSpec := TensorAxisSpec{MeshAxes: []MeshAxisSpec{{AxisName: meshAxisName}}} for _, meshAxisName := range moreMeshAxesNames { axisSpec.MeshAxes = append(axisSpec.MeshAxes, MeshAxisSpec{AxisName: meshAxisName}) @@ -76,25 +77,25 @@ func (s *ShardSpec) AddShardedAxis(meshAxisName string, moreMeshAxesNames ...str return s } -// AddReplicated adds a new replicated axis to the ShardSpec. +// AddReplicated adds a new replicated axis to the ShardingSpec. // // It returns itself, so calls can be chained. -func (s *ShardSpec) AddReplicated() *ShardSpec { +func (s *ShardingSpec) AddReplicated() *ShardingSpec { s.Axes = append(s.Axes, TensorAxisSpec{}) return s } -// Rank returns the number of axes this ShardSpec describes. +// Rank returns the number of axes this ShardingSpec describes. // -// Notice this may be smaller than the rank of the tensor using it: if a tensor axis is not defined in ShardSpec, +// Notice this may be smaller than the rank of the tensor using it: if a tensor axis is not defined in ShardingSpec, // it is assumed to be replicated. -func (s *ShardSpec) Rank() int { +func (s *ShardingSpec) Rank() int { return len(s.Axes) } // IsReplicated returns true if the tensor is fully replicated // (i.e., not sharded along any axis and not marked as "open"). -func (s *ShardSpec) IsReplicated() bool { +func (s *ShardingSpec) IsReplicated() bool { for _, axisSpec := range s.Axes { if axisSpec.MeshAxes != nil || axisSpec.Opened { return false @@ -103,17 +104,17 @@ func (s *ShardSpec) IsReplicated() bool { return true } -// Validate checks that the ShardSpec is valid for the given mesh. -func (s *ShardSpec) Validate() error { +// Validate checks that the ShardingSpec is valid for the given mesh. 
+func (s *ShardingSpec) Validate() error { for i, axisSpec := range s.Axes { for j, meshAxisSpec := range axisSpec.MeshAxes { axisName := meshAxisSpec.AxisName if axisName == "" { - return errors.Errorf("ShardSpec tensor axis %d, mesh axis #%d refers to empty mesh axis name", i, j) + return errors.Errorf("ShardingSpec tensor axis %d, mesh axis #%d refers to empty mesh axis name", i, j) } axisIdx, ok := s.Mesh.nameToAxis[axisName] if !ok { - return errors.Errorf("ShardSpec tensor axis %d, mesh axis #%d refers to unknown mesh axis %q", + return errors.Errorf("ShardingSpec tensor axis %d, mesh axis #%d refers to unknown mesh axis %q", i, j, axisName) } meshAxisSize := s.Mesh.shape[axisIdx] @@ -121,11 +122,11 @@ func (s *ShardSpec) Validate() error { // Check sub-axis specification. if meshAxisSpec.Size > 0 { if meshAxisSpec.PreSize <= 0 { - return errors.Errorf("ShardSpec tensor axis %d, mesh axis #%d %q has invalid PreSize %d", + return errors.Errorf("ShardingSpec tensor axis %d, mesh axis #%d %q has invalid PreSize %d", i, j, axisName, meshAxisSpec.PreSize) } if meshAxisSize%(meshAxisSpec.PreSize*meshAxisSpec.Size) != 0 { - return errors.Errorf("ShardSpec tensor axis %d, mesh axis #%d %q with PreSize %d and Size %d is not compatible with mesh axis of size %d", + return errors.Errorf("ShardingSpec tensor axis %d, mesh axis #%d %q with PreSize %d and Size %d is not compatible with mesh axis of size %d", i, j, axisName, meshAxisSpec.PreSize, meshAxisSpec.Size, meshAxisSize) } } @@ -134,10 +135,25 @@ func (s *ShardSpec) Validate() error { return nil } -// ToStableHLO converts the ShardSpec to its StableHLO string representation. +func (s *ShardingSpec) ValidateShape(shape shapes.Shape) error { + if s == nil { + // No sharding spec (nil) means fully replicated, and it's always valid for any shape. 
+ return nil + } + err := s.Validate() + if err != nil { + return err + } + if s.Rank() > shape.Rank() { + return errors.Errorf("ShardingSpec shape rank %d is largers than tensor rank %d", s.Rank(), shape.Rank()) + } + return nil +} + +// ToStableHLO converts the ShardingSpec to its StableHLO string representation. // See details in: // https://github.com/openxla/shardy/blob/main/docs/sharding_representation.md -func (s *ShardSpec) ToStableHLO() string { +func (s *ShardingSpec) ToStableHLO() string { var dimShardings []string replicatedAxes := make(map[string]bool) for _, axisName := range s.Mesh.axesNames { @@ -170,5 +186,5 @@ func (s *ShardSpec) ToStableHLO() string { if len(replicatedStrs) > 0 { replicatedPart = fmt.Sprintf(", replicated={%s}", strings.Join(replicatedStrs, ", ")) } - return fmt.Sprintf("sharding<@%s, [%s]%s>", s.Mesh.Name(), strings.Join(dimShardings, ", "), replicatedPart) + return fmt.Sprintf("#sdy.sharding<@%s, [%s]%s>", s.Mesh.Name(), strings.Join(dimShardings, ", "), replicatedPart) } diff --git a/types/shardy/shardspec_test.go b/types/shardy/shardingspec_test.go similarity index 80% rename from types/shardy/shardspec_test.go rename to types/shardy/shardingspec_test.go index df81791..f7007ee 100644 --- a/types/shardy/shardspec_test.go +++ b/types/shardy/shardingspec_test.go @@ -11,27 +11,27 @@ func TestShardSpec_ToStableHLO(t *testing.T) { require.NoError(t, err) testCases := []struct { name string - spec *ShardSpec + spec *ShardingSpec expected string }{ { name: "Replicated", - spec: NewShardSpec(mesh).AddReplicated(), + spec: NewShardingSpec(mesh).AddReplicated(), expected: "sharding<@test_mesh, [{}], replicated={a, z}>", }, { name: "Sharded", - spec: NewShardSpec(mesh).AddShardedAxis("z"), + spec: NewShardingSpec(mesh).AddShardedAxis("z"), expected: "sharding<@test_mesh, [{z}], replicated={a}>", }, { name: "Sharded with multiple axes", - spec: NewShardSpec(mesh).AddShardedAxis("z", "a"), + spec: NewShardingSpec(mesh).AddShardedAxis("z", 
"a"), expected: "sharding<@test_mesh, [{z, a}]>", }, { name: "Sharded with sub-axis", - spec: &ShardSpec{ + spec: &ShardingSpec{ Mesh: mesh, Axes: []TensorAxisSpec{ {MeshAxes: []MeshAxisSpec{{AxisName: "a", PreSize: 1, Size: 2}}}, @@ -41,7 +41,7 @@ func TestShardSpec_ToStableHLO(t *testing.T) { }, { name: "Opened", - spec: &ShardSpec{Mesh: mesh, Axes: []TensorAxisSpec{{Opened: true}}}, + spec: &ShardingSpec{Mesh: mesh, Axes: []TensorAxisSpec{{Opened: true}}}, expected: "sharding<@test_mesh, [{?}], replicated={a, z}>", }, } @@ -58,22 +58,22 @@ func TestShardSpec_Validate(t *testing.T) { require.NoError(t, err) testCases := []struct { name string - spec *ShardSpec + spec *ShardingSpec expectError bool }{ { name: "Valid sharding", - spec: NewShardSpec(mesh).AddShardedAxis("z"), + spec: NewShardingSpec(mesh).AddShardedAxis("z"), expectError: false, }, { name: "Unknown mesh axis", - spec: NewShardSpec(mesh).AddShardedAxis("x"), + spec: NewShardingSpec(mesh).AddShardedAxis("x"), expectError: true, }, { name: "Valid sub-axis", - spec: &ShardSpec{ + spec: &ShardingSpec{ Mesh: mesh, Axes: []TensorAxisSpec{ {MeshAxes: []MeshAxisSpec{{AxisName: "a", PreSize: 2, Size: 4}}}, @@ -83,7 +83,7 @@ func TestShardSpec_Validate(t *testing.T) { }, { name: "Invalid sub-axis (PreSize)", - spec: &ShardSpec{ + spec: &ShardingSpec{ Mesh: mesh, Axes: []TensorAxisSpec{ {MeshAxes: []MeshAxisSpec{{AxisName: "a", PreSize: 0, Size: 4}}}, @@ -93,7 +93,7 @@ func TestShardSpec_Validate(t *testing.T) { }, { name: "Invalid sub-axis (Size)", - spec: &ShardSpec{ + spec: &ShardingSpec{ Mesh: mesh, Axes: []TensorAxisSpec{ {MeshAxes: []MeshAxisSpec{{AxisName: "a", PreSize: 2, Size: 5}}}, From 663f568c7a64e988189dc4675cf00370585ffa13 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Wed, 19 Nov 2025 10:04:02 +0100 Subject: [PATCH 07/30] Adding optional ShardingSpec to the inputs and outputs of the functions. 
--- builder.go | 10 ++++++ function.go | 98 ++++++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 99 insertions(+), 9 deletions(-) diff --git a/builder.go b/builder.go index 08c2d4d..052ca5d 100644 --- a/builder.go +++ b/builder.go @@ -89,6 +89,16 @@ func (b *Builder) WithShardy(mesh *shardy.DeviceMesh) *Builder { return b } +// Mesh returns the mesh configured with WithShardy. +func (b *Builder) Mesh() *shardy.DeviceMesh { + return b.mesh +} + +// NewShardingSpec creates a new ShardingSpec using the mesh configured with WithShardy. +func (b *Builder) NewShardingSpec() *shardy.ShardingSpec { + return shardy.NewShardingSpec(b.mesh) +} + // elementWriter represents elements of ToStableHLO that know how to write themselves. type elementWriter interface { Write(w io.Writer, indentation string) error diff --git a/function.go b/function.go index 19aed57..b49f364 100644 --- a/function.go +++ b/function.go @@ -10,6 +10,7 @@ import ( "github.com/gomlx/stablehlo/internal/optypes" "github.com/gomlx/stablehlo/shapeinference" "github.com/gomlx/stablehlo/types/shapes" + "github.com/gomlx/stablehlo/types/shardy" "github.com/pkg/errors" ) @@ -23,9 +24,15 @@ type Function struct { // Inputs to the function. Inputs []*Value + // InputsShardingSpecs are the sharding specs for the inputs. Optional. + InputsShardingSpecs []*shardy.ShardingSpec + // Outputs types of the function. Outputs []shapes.Shape + // OutputsShardingSpecs are the sharding specs for the inputs. Optional. + OutputsShardingSpecs []*shardy.ShardingSpec + // Statements in the function body. Statements []*Statement @@ -82,10 +89,17 @@ func (fn *Function) newValue(shape shapes.Shape) (v *Value) { // It picks a default unique name for the input parameter, you can also // provide a name with NamedInput. 
func (fn *Function) Input(shape shapes.Shape) (*Value, error) { + return fn.InputWithSharding(shape, nil) +} + +func (fn *Function) InputWithSharding(shape shapes.Shape, shardingSpec *shardy.ShardingSpec) (*Value, error) { rootFn := fn.findRootFn() - value, err := fn.NamedInput(fmt.Sprintf("arg%d", rootFn.nextArgID), shape) + value, err := fn.NamedInputWithSharding(fmt.Sprintf("arg%d", rootFn.nextArgID), shape, shardingSpec) + if err != nil { + return nil, err + } rootFn.nextArgID++ - return value, err + return value, nil } // NamedInput creates a new input parameter for a function with the given name -- it @@ -98,6 +112,21 @@ func (fn *Function) Input(shape shapes.Shape) (*Value, error) { // Names are used in the StableHLO code and may be helpful for debugging, but // otherwise have no impact. func (fn *Function) NamedInput(name string, shape shapes.Shape) (*Value, error) { + return fn.NamedInputWithSharding(name, shape, nil) +} + +// NamedInputWithSharding creates a new input parameter for a function with the given name -- it +// must be a unique input name -- and sharding specification for distributed computation. +// +// The shardingSpec can be nil: the default is a replicated input across all devices. +// +// The name is passed through ConvertToValidName, which converts any non-digit or ASCII letter to an underscore. +// +// Names with the format "%d" and "arg%d" are reserved for the default input parameters. +// +// Names are used in the StableHLO code and may be helpful for debugging, but otherwise have no impact. 
+func (fn *Function) NamedInputWithSharding(name string, shape shapes.Shape, + shardingSpec *shardy.ShardingSpec) (*Value, error) { value := &Value{ fn: fn, name: ConvertToValidName(name), @@ -108,7 +137,17 @@ func (fn *Function) NamedInput(name string, shape shapes.Shape) (*Value, error) return nil, errors.Errorf("duplicate input name %q with input #%d", value.name, i) } } + if shardingSpec != nil { + if shardingSpec.Mesh != fn.Builder.mesh { + return nil, errors.Errorf("sharding spec mesh %s doesn't match the stablehlo.Builder mesh %s", + shardingSpec.Mesh, fn.Builder.mesh) + } + if err := shardingSpec.ValidateShape(shape); err != nil { + return nil, err + } + } fn.Inputs = append(fn.Inputs, value) + fn.InputsShardingSpecs = append(fn.InputsShardingSpecs, shardingSpec) return value, nil } @@ -180,33 +219,74 @@ func (fn *Function) ConstantFromFlatAndDimensions(flat any, dimensions ...int) ( // // There can be only one return statement from a Function, and it must be the last // operation of a function. -func (fn *Function) Return(firstValue *Value, otherValues ...*Value) error { +// +// If you are doing distributed computation, you can use WithReturnShardingSpecs to specify +// the sharding requirements for each of the return values. +func (fn *Function) Return(values ...*Value) error { if fn.Returned { return errors.Errorf("Function.Return already called for %q", fn.Name) } + if len(values) == 0 { + return errors.New("Function.Return requires at least one return value") + } fn.Returned = true - allValues := make([]*Value, 1, len(otherValues)+1) - allValues[0] = firstValue - allValues = append(allValues, otherValues...) 
- outputShapes := make([]shapes.Shape, len(allValues)) - for i, value := range allValues { + outputShapes := make([]shapes.Shape, len(values)) + for i, value := range values { if value.fn != fn { return errors.New("Function.Return given values that are not owned by the function") } outputShapes[i] = value.shape } fn.Outputs = outputShapes + fn.OutputsShardingSpecs = make([]*shardy.ShardingSpec, len(values)) // All default to nil. stmt := &Statement{ Builder: fn.Builder, Function: fn, OpType: optypes.FuncReturn, - Inputs: allValues, + Inputs: values, } fn.Statements = append(fn.Statements, stmt) return nil } +// WithReturnShardingSpecs specify the sharding requirements of the return values. +// It should be used after Return is called. +// +// You have to provide one spec per output used in Return. But nil spec values are valid, +// the default being a replicated input across all devices. +// +// The specs must use the same mesh as the stablehlo.Builder. Mixing meshes will cause an error. +// See Builder.NewShardingSpec. 
+func (fn *Function) WithReturnShardingSpecs(specs ...*shardy.ShardingSpec) error { + if !fn.Returned { + return errors.Errorf( + "Function.WithReturnShardingSpecs called for %q, but no Return hasn't been called yet", fn.Name) + } + if len(specs) != len(fn.Outputs) { + return errors.Errorf( + "Function.WithReturnShardingSpecs called for %q, but the number of sharding specs (%d) doesn't match "+ + "the number of return values (%d)", fn.Name, len(specs), len(fn.Outputs)) + } + for i, spec := range specs { + if spec == nil { + continue + } + if spec.Mesh != fn.Builder.mesh { + return errors.Errorf( + "Function.WithReturnShardingSpecs called for %q, but the sharding spec #%d uses a different mesh "+ + "(%s) than the function's builder (%s)", + fn.Name, i, spec.Mesh, fn.Builder.mesh) + } + err := spec.ValidateShape(fn.Outputs[i]) + if err != nil { + return errors.Wrapf(err, "Function.WithReturnShardingSpecs called for %q, but the sharding spec #%d is invalid", fn.Name, i) + } + } + fn.OutputsShardingSpecs = specs + return nil +} + // Iota creates a constant of the given shape with increasing numbers (starting from 0) // on the given axis. So Iota([2,2], 1) returns [[0 1][0 1]], while Iota([2,2], 0) // returns [[0 0][1 1]]. From cc61929931268af662cbabac7a6f22c824b26d61 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Wed, 19 Nov 2025 10:58:42 +0100 Subject: [PATCH 08/30] Fixed tests with correct stablehlo of shardingspec. 
--- types/shardy/shardingspec_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/types/shardy/shardingspec_test.go b/types/shardy/shardingspec_test.go index f7007ee..5e97424 100644 --- a/types/shardy/shardingspec_test.go +++ b/types/shardy/shardingspec_test.go @@ -17,17 +17,17 @@ func TestShardSpec_ToStableHLO(t *testing.T) { { name: "Replicated", spec: NewShardingSpec(mesh).AddReplicated(), - expected: "sharding<@test_mesh, [{}], replicated={a, z}>", + expected: "#sdy.sharding<@test_mesh, [{}], replicated={a, z}>", }, { name: "Sharded", spec: NewShardingSpec(mesh).AddShardedAxis("z"), - expected: "sharding<@test_mesh, [{z}], replicated={a}>", + expected: "#sdy.sharding<@test_mesh, [{z}], replicated={a}>", }, { name: "Sharded with multiple axes", spec: NewShardingSpec(mesh).AddShardedAxis("z", "a"), - expected: "sharding<@test_mesh, [{z, a}]>", + expected: "#sdy.sharding<@test_mesh, [{z, a}]>", }, { name: "Sharded with sub-axis", @@ -37,12 +37,12 @@ func TestShardSpec_ToStableHLO(t *testing.T) { {MeshAxes: []MeshAxisSpec{{AxisName: "a", PreSize: 1, Size: 2}}}, }, }, - expected: "sharding<@test_mesh, [{a:(1)2}], replicated={z}>", + expected: "#sdy.sharding<@test_mesh, [{a:(1)2}], replicated={z}>", }, { name: "Opened", spec: &ShardingSpec{Mesh: mesh, Axes: []TensorAxisSpec{{Opened: true}}}, - expected: "sharding<@test_mesh, [{?}], replicated={a, z}>", + expected: "#sdy.sharding<@test_mesh, [{?}], replicated={a, z}>", }, } From c858e73dd6f5328bc1f0c9c378cad9672a0d41f9 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Wed, 19 Nov 2025 10:28:55 +0000 Subject: [PATCH 09/30] feat: Add attribute support to function signatures This change introduces the ability to add attributes to function inputs and outputs in the `stablehlo.Builder`. This is essential for generating StableHLO modules with XLA Shardy annotations. 
Key changes: - The `Value` struct now includes an `Attributes` map. - `Function.Write` has been updated to serialize these attributes for both inputs and outputs. - The attribute writing logic has been refactored into a reusable `writeAttributes` function. - A new test case has been added to validate the generation of sharded StableHLO modules. --- builder.go | 2 +- collective.go | 2 +- function.go | 132 +++++++++++++++++++++++++--------------------- ops.go | 31 +++++++++-- stablehlo_test.go | 51 ++++++++++++++++++ statement.go | 58 +++++++++++--------- value.go | 7 +-- 7 files changed, 191 insertions(+), 92 deletions(-) diff --git a/builder.go b/builder.go index 052ca5d..9f9369d 100644 --- a/builder.go +++ b/builder.go @@ -83,9 +83,9 @@ func (b *Builder) WithNumPartitions(n int) *Builder { // // [1] https://github.com/openxla/shardy func (b *Builder) WithShardy(mesh *shardy.DeviceMesh) *Builder { + b.mesh = mesh b.WithNumReplicas(1) b.WithNumPartitions(mesh.NumDevices()) - return b } diff --git a/collective.go b/collective.go index 116eb8d..123a7ed 100644 --- a/collective.go +++ b/collective.go @@ -128,7 +128,7 @@ func AllReduce(operands []*Value, replicaGroups [][]int, computation *Function, outputShapes, err := shapeinference.AllReduce( valuesToShapes(operands), valuesToShapes(computation.Inputs), - computation.Outputs, + valuesToShapes(computation.Outputs), replicaGroups) if err != nil { return nil, err diff --git a/function.go b/function.go index b49f364..7c78c74 100644 --- a/function.go +++ b/function.go @@ -24,14 +24,8 @@ type Function struct { // Inputs to the function. Inputs []*Value - // InputsShardingSpecs are the sharding specs for the inputs. Optional. - InputsShardingSpecs []*shardy.ShardingSpec - - // Outputs types of the function. - Outputs []shapes.Shape - - // OutputsShardingSpecs are the sharding specs for the inputs. Optional. - OutputsShardingSpecs []*shardy.ShardingSpec + // Outputs of the function. 
+ Outputs []*Value // Statements in the function body. Statements []*Statement @@ -89,12 +83,23 @@ func (fn *Function) newValue(shape shapes.Shape) (v *Value) { // It picks a default unique name for the input parameter, you can also // provide a name with NamedInput. func (fn *Function) Input(shape shapes.Shape) (*Value, error) { - return fn.InputWithSharding(shape, nil) + return fn.InputWithShardingAndAttributes(shape, nil, nil) } +// InputWithSharding creates a new input with the given sharding specification. func (fn *Function) InputWithSharding(shape shapes.Shape, shardingSpec *shardy.ShardingSpec) (*Value, error) { + return fn.InputWithShardingAndAttributes(shape, shardingSpec, nil) +} + +// InputWithAttributes creates a new input with the given attributes. +func (fn *Function) InputWithAttributes(shape shapes.Shape, attributes map[string]any) (*Value, error) { + return fn.InputWithShardingAndAttributes(shape, nil, attributes) +} + +// InputWithShardingAndAttributes creates a new input with the given sharding specification and attributes. +func (fn *Function) InputWithShardingAndAttributes(shape shapes.Shape, shardingSpec *shardy.ShardingSpec, attributes map[string]any) (*Value, error) { rootFn := fn.findRootFn() - value, err := fn.NamedInputWithSharding(fmt.Sprintf("arg%d", rootFn.nextArgID), shape, shardingSpec) + value, err := fn.NamedInputWithShardingAndAttributes(fmt.Sprintf("arg%d", rootFn.nextArgID), shape, shardingSpec, attributes) if err != nil { return nil, err } @@ -112,11 +117,24 @@ func (fn *Function) InputWithSharding(shape shapes.Shape, shardingSpec *shardy.S // Names are used in the StableHLO code and may be helpful for debugging, but // otherwise have no impact. 
func (fn *Function) NamedInput(name string, shape shapes.Shape) (*Value, error) { - return fn.NamedInputWithSharding(name, shape, nil) + return fn.NamedInputWithShardingAndAttributes(name, shape, nil, nil) } // NamedInputWithSharding creates a new input parameter for a function with the given name -- it // must be a unique input name -- and sharding specification for distributed computation. +func (fn *Function) NamedInputWithSharding(name string, shape shapes.Shape, + shardingSpec *shardy.ShardingSpec) (*Value, error) { + return fn.NamedInputWithShardingAndAttributes(name, shape, shardingSpec, nil) +} + +// NamedInputWithAttributes creates a new input parameter for a function with the given name and attributes. +func (fn *Function) NamedInputWithAttributes(name string, shape shapes.Shape, + attributes map[string]any) (*Value, error) { + return fn.NamedInputWithShardingAndAttributes(name, shape, nil, attributes) +} + +// NamedInputWithShardingAndAttributes creates a new input parameter for a function with the given name -- it +// must be a unique input name -- and sharding specification for distributed computation. // // The shardingSpec can be nil: the default is a replicated input across all devices. // @@ -125,12 +143,13 @@ func (fn *Function) NamedInput(name string, shape shapes.Shape) (*Value, error) // Names with the format "%d" and "arg%d" are reserved for the default input parameters. // // Names are used in the StableHLO code and may be helpful for debugging, but otherwise have no impact. 
-func (fn *Function) NamedInputWithSharding(name string, shape shapes.Shape, - shardingSpec *shardy.ShardingSpec) (*Value, error) { +func (fn *Function) NamedInputWithShardingAndAttributes(name string, shape shapes.Shape, + shardingSpec *shardy.ShardingSpec, attributes map[string]any) (*Value, error) { value := &Value{ - fn: fn, - name: ConvertToValidName(name), - shape: shape, + fn: fn, + name: ConvertToValidName(name), + shape: shape, + Attributes: attributes, } for i, input := range fn.Inputs { if input.name == value.name { @@ -138,6 +157,10 @@ func (fn *Function) NamedInputWithSharding(name string, shape shapes.Shape, } } if shardingSpec != nil { + if value.Attributes == nil { + value.Attributes = make(map[string]any) + } + value.Attributes["sdy.sharding"] = shardingSpec if shardingSpec.Mesh != fn.Builder.mesh { return nil, errors.Errorf("sharding spec mesh %s doesn't match the stablehlo.Builder mesh %s", shardingSpec.Mesh, fn.Builder.mesh) @@ -147,7 +170,6 @@ func (fn *Function) NamedInputWithSharding(name string, shape shapes.Shape, } } fn.Inputs = append(fn.Inputs, value) - fn.InputsShardingSpecs = append(fn.InputsShardingSpecs, shardingSpec) return value, nil } @@ -223,22 +245,49 @@ func (fn *Function) ConstantFromFlatAndDimensions(flat any, dimensions ...int) ( // If you are doing distributed computation, you can use WithReturnShardingSpecs to specify // the sharding requirements for each of the return values. func (fn *Function) Return(values ...*Value) error { + attributes := make([]map[string]any, len(values)) + return fn.ReturnWithAttributes(values, attributes) +} + +// ReturnWithSharding is a convenience function to call ReturnWithAttributes with the given sharding specifications. 
+func (fn *Function) ReturnWithSharding(values []*Value, shardingSpecs []*shardy.ShardingSpec) error { + if len(values) != len(shardingSpecs) { + return errors.Errorf("Function.ReturnWithSharding requires the same number of values and sharding specs, got %d and %d", len(values), len(shardingSpecs)) + } + attributes := make([]map[string]any, len(values)) + for i, shardingSpec := range shardingSpecs { + if shardingSpec != nil { + attributes[i] = map[string]any{"sdy.sharding": shardingSpec} + } + } + return fn.ReturnWithAttributes(values, attributes) +} + +// ReturnWithAttributes adds a return statement to the function with the given return values and attributes. +func (fn *Function) ReturnWithAttributes(values []*Value, attributes []map[string]any) error { if fn.Returned { return errors.Errorf("Function.Return already called for %q", fn.Name) } if len(values) == 0 { return errors.New("Function.Return requires at least one return value") } + if len(values) != len(attributes) { + return errors.Errorf("Function.ReturnWithAttributes requires the same number of values and attributes, got %d and %d", len(values), len(attributes)) + } fn.Returned = true - outputShapes := make([]shapes.Shape, len(values)) + outputValues := make([]*Value, len(values)) for i, value := range values { if value.fn != fn { return errors.New("Function.Return given values that are not owned by the function") } - outputShapes[i] = value.shape + outputValues[i] = &Value{ + fn: fn, + name: value.name, + shape: value.shape, + Attributes: attributes[i], + } } - fn.Outputs = outputShapes - fn.OutputsShardingSpecs = make([]*shardy.ShardingSpec, len(values)) // All default to nil. + fn.Outputs = outputValues stmt := &Statement{ Builder: fn.Builder, @@ -250,43 +299,6 @@ func (fn *Function) Return(values ...*Value) error { return nil } -// WithReturnShardingSpecs specify the sharding requirements of the return values. -// It should be used after Return is called. 
-// -// You have to provide one spec per output used in Return. But nil spec values are valid, -// the default being a replicated input across all devices. -// -// The specs must use the same mesh as the stablehlo.Builder. Mixing meshes will cause an error. -// See Builder.NewShardingSpec. -func (fn *Function) WithReturnShardingSpecs(specs ...*shardy.ShardingSpec) error { - if !fn.Returned { - return errors.Errorf( - "Function.WithReturnShardingSpecs called for %q, but no Return hasn't been called yet", fn.Name) - } - if len(specs) != len(fn.Outputs) { - return errors.Errorf( - "Function.WithReturnShardingSpecs called for %q, but the number of sharding specs (%d) doesn't match "+ - "the number of return values (%d)", fn.Name, len(specs), len(fn.Outputs)) - } - for i, spec := range specs { - if spec == nil { - continue - } - if spec.Mesh != fn.Builder.mesh { - return errors.Errorf( - "Function.WithReturnShardingSpecs called for %q, but the sharding spec #%d uses a different mesh "+ - "(%s) than the function's builder (%s)", - fn.Name, i, spec.Mesh, fn.Builder.mesh) - } - err := spec.ValidateShape(fn.Outputs[i]) - if err != nil { - return errors.Wrapf(err, "Function.WithReturnShardingSpecs called for %q, but the sharding spec #%d is invalid", fn.Name, i) - } - } - fn.OutputsShardingSpecs = specs - return nil -} - // Iota creates a constant of the given shape with increasing numbers (starting from 0) // on the given axis. So Iota([2,2], 1) returns [[0 1][0 1]], while Iota([2,2], 0) // returns [[0 0][1 1]]. 
@@ -355,6 +367,7 @@ func (fn *Function) Write(writer io.Writer, indentation string) error { } we(input, nextIndent) w(": %s", input.shape.ToStableHLO()) + writeAttributes(writer, indentation, input.Attributes, w) } if isClosure { @@ -368,7 +381,8 @@ if i > 0 { w(", ") } - w("%s", output.ToStableHLO()) + w(output.shape.ToStableHLO()) + writeAttributes(writer, indentation, output.Attributes, w) } if len(fn.Outputs) > 1 { w(")") diff --git a/ops.go b/ops.go index a9066af..2021c92 100644 --- a/ops.go +++ b/ops.go @@ -216,6 +216,31 @@ type DotGeneralBuilder struct { algorithm *types.DotGeneralAlgorithm } +// Dot performs a standard matrix multiplication ("dot product") of the two +// rank-2 tensors lhs and rhs, contracting the last axis of lhs with the +// first axis of rhs. +// +// Both lhs and rhs must be rank-2; an error is returned otherwise. +// The contracted dimensions (lhs axis 1 and rhs axis 0) must match, and +// the result has shape [lhs.dim(0), rhs.dim(1)]. +// +// Dot is a convenience wrapper around DotGeneral with no batch axes and +// the default configuration: it contracts axis 1 of lhs against axis 0 +// of rhs and immediately calls DotGeneralBuilder.Done. +// +// For batched products, multiple contracting axes, or a custom precision +// configuration, use DotGeneral directly, which returns a DotGeneralBuilder +// that can be further configured.
+func Dot(lhs, rhs *Value) (*Value, error) { + if lhs.Shape().Rank() != 2 || rhs.Shape().Rank() != 2 { + return nil, errors.Errorf("Dot only supports rank-2 tensors, got %d and %d", lhs.Shape().Rank(), rhs.Shape().Rank()) + } + return DotGeneral( + lhs, []int{1}, nil, + rhs, []int{0}, nil, + ).Done() +} + // DotGeneral takes as input lhs (left-hand-side) and rhs (right-hand-side) specifications // for a general vector product -- a generalized "Einsum". Each axis can be: // - Just aligned (batch axes), so the output has the same axes as the inputs. The dimensions @@ -627,7 +652,7 @@ func MultiReduce(inputs, initialValues []*Value, reductionFn *Function, axes ... outputsShapes, err := shapeinference.Reduce( valuesToShapes(inputs), valuesToShapes(initialValues), - valuesToShapes(reductionFn.Inputs), reductionFn.Outputs, + valuesToShapes(reductionFn.Inputs), valuesToShapes(reductionFn.Outputs), axes) if err != nil { return nil, err @@ -845,7 +870,7 @@ func MultiScatter(inputs []*Value, scatterIndices *Value, updates []*Value, updateWindowAxes, insertedWindowAxes, inputBatchingAxes, scatterIndicesBatchingAxes, indexedInputAxes, indexVectorAxis, - updateComputationInputShapes, updateComputationFn.Outputs) + updateComputationInputShapes, valuesToShapes(updateComputationFn.Outputs)) if err != nil { return nil, err } @@ -1241,7 +1266,7 @@ func MultiReduceWindow(inputs, initialValues []*Value, reductionFn *Function, outputsShapes, err := shapeinference.ReduceWindow( valuesToShapes(inputs), valuesToShapes(initialValues), - valuesToShapes(reductionFn.Inputs), reductionFn.Outputs, + valuesToShapes(reductionFn.Inputs), valuesToShapes(reductionFn.Outputs), windowDimensions, strides, inputDilations, windowDilations, paddings) if err != nil { diff --git a/stablehlo_test.go b/stablehlo_test.go index 375a1bb..56851b0 100644 --- a/stablehlo_test.go +++ b/stablehlo_test.go @@ -6,6 +6,7 @@ import ( "github.com/gomlx/gopjrt/dtypes" "github.com/gomlx/stablehlo/types/shapes" + 
"github.com/gomlx/stablehlo/types/shardy" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -42,6 +43,56 @@ func TestBuilder(t *testing.T) { } }) + t.Run("Sharding", func(t *testing.T) { + b := New(t.Name()) + mesh, err := shardy.NewDeviceMesh("mesh", []int{1, 1}, []string{"data", "model"}) + require.NoError(t, err) + b.WithShardy(mesh) + fn := b.Main() + + arg0 := must(fn.NamedInputWithShardingAndAttributes( + "arg0", + shapes.Make(dtypes.F32, 16, 128), + shardy.NewShardingSpec(mesh).AddShardedAxis("data"), + nil, + )) + arg1 := must(fn.NamedInputWithSharding( + "arg1", + shapes.Make(dtypes.F32, 128, 256), + shardy.NewShardingSpec(mesh).AddShardedAxis("model"), + )) + + tanh := must(Tanh(arg0)) + dot := must(Dot(tanh, arg1)) + require.NoError(t, fn.ReturnWithAttributes( + []*Value{dot}, + []map[string]any{{"jax.result_info": "result"}})) + + program := string(must(b.Build())) + fmt.Printf("%s program:\n%s", t.Name(), program) + want := `module @TestBuilder_Sharding attributes {stablehlo.num_replicas = 1, stablehlo.num_partitions = 1} { + sdy.mesh @mesh = <["data"=1, "model"=1]> + func.func @main(%arg0: tensor<16x128xf32> { sdy.sharding = #sdy.sharding<@mesh, [{data}], replicated={model}> }, %arg1: tensor<128x256xf32> { sdy.sharding = #sdy.sharding<@mesh, [{model}], replicated={data}> }) -> tensor<16x256xf32> { jax.result_info = "result" } { + %0 = "stablehlo.tanh"(%arg0) : (tensor<16x128xf32>) -> tensor<16x128xf32> + %1 = "stablehlo.dot_general"(%0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [], + rhs_batching_dimensions = [], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [0] +>, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<16x128xf32>, tensor<128x256xf32>) -> tensor<16x256xf32> + "stablehlo.return"(%1) : (tensor<16x256xf32>) -> () + } +} +` + if program != want { + fmt.Printf(" Failed. 
Wanted the following program:\n%s", want) + t.Fatal("programs don't match") + } + }) + t.Run("with inputs", func(t *testing.T) { builder := New(t.Name()) shape := shapes.Make(dtypes.Float64) diff --git a/statement.go b/statement.go index 10573a0..2d75ccb 100644 --- a/statement.go +++ b/statement.go @@ -103,31 +103,7 @@ func (s *Statement) Write(writer io.Writer, indentation string) error { } // Write attributes: - if len(s.Attributes) > 0 { - if len(s.Attributes) == 1 { - for key, value := range s.Attributes { - literalValue := literalToStableHLO(value) - if strings.Index(literalValue, "\n") == -1 { - w(" { %s = %s }", key, literalValue) - } else { - literalValue = strings.ReplaceAll(literalValue, "\n", "\n"+nextIndentation) - w(" {\n%s%s = %s\n }", nextIndentation, key, literalValue) - } - } - } else { - // One attribute per line: - w(" {") - first := true - for key, value := range s.Attributes { - if !first { - w(",") - } - first = false - w("\n%s%s = %s", nextIndentation, key, literalToStableHLO(value)) - } - w("\n%s}", indentation) - } - } + writeAttributes(writer, indentation, s.Attributes, w) // Write signature: w(" : (") @@ -160,6 +136,38 @@ func (s *Statement) Write(writer io.Writer, indentation string) error { return err } +// writeAttributes writes a map of attributes to the writer. +// The w function is the one provided by the caller to handle errors. 
+func writeAttributes(writer io.Writer, indentation string, attributes map[string]any, w func(format string, args ...any)) { + if len(attributes) == 0 { + return + } + nextIndentation := indentation + IndentationStep + if len(attributes) == 1 { + for key, value := range attributes { + literalValue := literalToStableHLO(value) + if strings.Index(literalValue, "\n") == -1 { + w(" { %s = %s }", key, literalValue) + } else { + literalValue = strings.ReplaceAll(literalValue, "\n", "\n"+nextIndentation) + w(" {\n%s%s = %s\n }", nextIndentation, key, literalValue) + } + } + } else { + // One attribute per line: + w(" {") + first := true + for key, value := range attributes { + if !first { + w(",") + } + first = false + w("\n%s%s = %s", nextIndentation, key, literalToStableHLO(value)) + } + w("\n%s}", indentation) + } +} + // hasToStableHLO is implemented by types that can be converted to a stablehlo string. type hasToStableHLO interface { ToStableHLO() string diff --git a/value.go b/value.go index c14dff8..5e48eee 100644 --- a/value.go +++ b/value.go @@ -18,9 +18,10 @@ import ( // // It also carries its shape information. type Value struct { - fn *Function - name string - shape shapes.Shape + fn *Function + name string + shape shapes.Shape + Attributes map[string]any } // Shape returns the shape of the value. From 11f22da2fbfa3e355f7da2e85484e3100c8216c1 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Wed, 19 Nov 2025 16:12:51 +0100 Subject: [PATCH 10/30] Added Shardy execution support. 
--- function.go | 26 ++++++++--- stablehlo_test.go | 30 ++++++++----- statement.go | 11 ++--- tests/gopjrt/shardy_test.go | 85 ++++++++++++++++++++++++++++++++++++ types/shardy/shardingspec.go | 49 +++++++++++++++++++++ 5 files changed, 178 insertions(+), 23 deletions(-) create mode 100644 tests/gopjrt/shardy_test.go diff --git a/function.go b/function.go index 7c78c74..3801df2 100644 --- a/function.go +++ b/function.go @@ -160,7 +160,7 @@ func (fn *Function) NamedInputWithShardingAndAttributes(name string, shape shape if value.Attributes == nil { value.Attributes = make(map[string]any) } - value.Attributes["sdy.sharding"] = shardingSpec + value.Attributes["sdy.sharding"] = literalStr(shardingSpec.ToValueAttribute(value.shape)) if shardingSpec.Mesh != fn.Builder.mesh { return nil, errors.Errorf("sharding spec mesh %s doesn't match the stablehlo.Builder mesh %s", shardingSpec.Mesh, fn.Builder.mesh) @@ -249,15 +249,29 @@ func (fn *Function) Return(values ...*Value) error { return fn.ReturnWithAttributes(values, attributes) } -// ReturnWithSharding is a convenience function to call ReturnWithAttributes with the given sharding specifications. -func (fn *Function) ReturnWithSharding(values []*Value, shardingSpecs []*shardy.ShardingSpec) error { +// ReturnWithShardingAndAttributes is a convenience function to call ReturnWithAttributes with the given sharding +// specifications. +// +// The shardingSpecs slice of ShardingSpecs must have the same length as the values slice. +// Each ShardingSpec can be nil, in which case the default sharding is replicated across all devices. +// +// The attributes slice of maps can be set to nil, if there are no attributes. 
+func (fn *Function) ReturnWithShardingAndAttributes(values []*Value, shardingSpecs []*shardy.ShardingSpec, + attributes []map[string]any) error { if len(values) != len(shardingSpecs) { - return errors.Errorf("Function.ReturnWithSharding requires the same number of values and sharding specs, got %d and %d", len(values), len(shardingSpecs)) + return errors.Errorf("Function.ReturnWithShardingAndAttributes requires the same number of values and sharding specs, got %d and %d", len(values), len(shardingSpecs)) + } + if len(attributes) == 0 { + attributes = make([]map[string]any, len(values)) } - attributes := make([]map[string]any, len(values)) for i, shardingSpec := range shardingSpecs { if shardingSpec != nil { - attributes[i] = map[string]any{"sdy.sharding": shardingSpec} + specLiteral := literalStr(shardingSpec.ToValueAttribute(values[i].shape)) + if attributes[i] == nil { + attributes[i] = map[string]any{"sdy.sharding": specLiteral} + } else { + attributes[i]["sdy.sharding"] = specLiteral + } } } return fn.ReturnWithAttributes(values, attributes) diff --git a/stablehlo_test.go b/stablehlo_test.go index 56851b0..5def2c7 100644 --- a/stablehlo_test.go +++ b/stablehlo_test.go @@ -45,7 +45,7 @@ func TestBuilder(t *testing.T) { t.Run("Sharding", func(t *testing.T) { b := New(t.Name()) - mesh, err := shardy.NewDeviceMesh("mesh", []int{1, 1}, []string{"data", "model"}) + mesh, err := shardy.NewDeviceMesh("mesh", []int{4, 2}, []string{"data", "model"}) require.NoError(t, err) b.WithShardy(mesh) fn := b.Main() @@ -53,26 +53,35 @@ func TestBuilder(t *testing.T) { arg0 := must(fn.NamedInputWithShardingAndAttributes( "arg0", shapes.Make(dtypes.F32, 16, 128), - shardy.NewShardingSpec(mesh).AddShardedAxis("data"), + b.NewShardingSpec().AddShardedAxis("data"), nil, )) arg1 := must(fn.NamedInputWithSharding( "arg1", shapes.Make(dtypes.F32, 128, 256), - shardy.NewShardingSpec(mesh).AddShardedAxis("model"), + b.NewShardingSpec().AddShardedAxis("model"), )) tanh := must(Tanh(arg0)) 
dot := must(Dot(tanh, arg1)) - require.NoError(t, fn.ReturnWithAttributes( + err = fn.ReturnWithShardingAndAttributes( []*Value{dot}, - []map[string]any{{"jax.result_info": "result"}})) + []*shardy.ShardingSpec{ + b.NewShardingSpec().AddShardedAxis("data"), + }, + []map[string]any{ + {"jax.result_info": "result"}, + }) + require.NoError(t, err) program := string(must(b.Build())) fmt.Printf("%s program:\n%s", t.Name(), program) - want := `module @TestBuilder_Sharding attributes {stablehlo.num_replicas = 1, stablehlo.num_partitions = 1} { - sdy.mesh @mesh = <["data"=1, "model"=1]> - func.func @main(%arg0: tensor<16x128xf32> { sdy.sharding = #sdy.sharding<@mesh, [{data}], replicated={model}> }, %arg1: tensor<128x256xf32> { sdy.sharding = #sdy.sharding<@mesh, [{model}], replicated={data}> }) -> tensor<16x256xf32> { jax.result_info = "result" } { + want := `module @TestBuilder_Sharding attributes {stablehlo.num_replicas = 1, stablehlo.num_partitions = 8} { + sdy.mesh @mesh = <["data"=4, "model"=2]> + func.func @main(%arg0: tensor<16x128xf32> { sdy.sharding = #sdy.sharding<@mesh, [{"data"}, {}]> }, %arg1: tensor<128x256xf32> { sdy.sharding = #sdy.sharding<@mesh, [{"model"}, {}]> }) -> tensor<16x256xf32> { + jax.result_info = "result", + sdy.sharding = #sdy.sharding<@mesh, [{"data"}, {}]> + } { %0 = "stablehlo.tanh"(%arg0) : (tensor<16x128xf32>) -> tensor<16x128xf32> %1 = "stablehlo.dot_general"(%0, %arg1) { dot_dimension_numbers = #stablehlo.dot< @@ -87,10 +96,7 @@ func TestBuilder(t *testing.T) { } } ` - if program != want { - fmt.Printf(" Failed. 
Wanted the following program:\n%s", want) - t.Fatal("programs don't match") - } + require.Equal(t, want, program) }) t.Run("with inputs", func(t *testing.T) { diff --git a/statement.go b/statement.go index 2d75ccb..66561fe 100644 --- a/statement.go +++ b/statement.go @@ -3,6 +3,7 @@ package stablehlo import ( "fmt" "io" + "maps" "math" "reflect" "slices" @@ -156,13 +157,13 @@ func writeAttributes(writer io.Writer, indentation string, attributes map[string } else { // One attribute per line: w(" {") - first := true - for key, value := range attributes { - if !first { + keys := slices.Collect(maps.Keys(attributes)) + slices.Sort(keys) + for i, key := range keys { + if i > 0 { w(",") } - first = false - w("\n%s%s = %s", nextIndentation, key, literalToStableHLO(value)) + w("\n%s%s = %s", nextIndentation, key, literalToStableHLO(attributes[key])) } w("\n%s}", indentation) } diff --git a/tests/gopjrt/shardy_test.go b/tests/gopjrt/shardy_test.go new file mode 100644 index 0000000..65353d9 --- /dev/null +++ b/tests/gopjrt/shardy_test.go @@ -0,0 +1,85 @@ +package gopjrt + +import ( + "fmt" + "testing" + + "github.com/gomlx/gopjrt/dtypes" + "github.com/gomlx/gopjrt/pjrt" + "github.com/gomlx/stablehlo" + "github.com/gomlx/stablehlo/types/shapes" + "github.com/gomlx/stablehlo/types/shardy" + "github.com/stretchr/testify/require" +) + +func TestShardy(t *testing.T) { + iterateClientsAndTest(t, testShardy) +} + +// compileAndExecute program with PJRT. All inputs are donated. +func shardyCompileAndExecute(t *testing.T, client *pjrt.Client, program []byte, + mesh *shardy.DeviceMesh, inputs ...*pjrt.Buffer) []*pjrt.Buffer { + loadedExec, err := client.Compile(). + WithStableHLO(program). + WithShardy(mesh.NumDevices()). + WithDeviceAssignment(mesh.DeviceAssignment()). 
+ Done() + require.NoErrorf(t, err, "failed to compile program: \n%s", program) + defer func() { + err := loadedExec.Destroy() + if err != nil { + t.Errorf("failed to destroy loaded exec: %+v", err) + } + }() + outputBuffers, err := loadedExec.Execute(inputs...).DonateAll().Done() + require.NoErrorf(t, err, "failed to execute program: \n%s", program) + return outputBuffers +} + +func testShardy(t *testing.T, client *pjrt.Client) { + // We will test it with 2 devices. + const numReplicas = 2 + numDevices := client.NumDevices() + if numDevices < numReplicas { + t.Skipf("Skipping test: not enough devices: %d < %d", numDevices, numReplicas) + return + } + + t.Run("input-data-sharding", func(t *testing.T) { + mesh := must1(shardy.NewDeviceMesh("data_mesh", []int{2}, []string{"data"})) + builder := stablehlo.New(t.Name()).WithShardy(mesh) + fn := builder.NewFunction("main") + x := must1(fn.NamedInputWithSharding("arg0", shapes.Make(dtypes.F32, 2, 3), + builder.NewShardingSpec().AddShardedAxis("data"))) + reductionFn := fn.Closure() + lhs := must1(reductionFn.NamedInput("lhs", shapes.Make(dtypes.F32))) + rhs := must1(reductionFn.NamedInput("rhs", shapes.Make(dtypes.F32))) + must(reductionFn.Return(must1(stablehlo.Add(lhs, rhs)))) + zero := must1(fn.ConstantFromScalar(float32(0))) + output := must1(stablehlo.Reduce(x, zero, reductionFn, 0, 1)) + must(fn.Return(output)) + program := must1(builder.Build()) + fmt.Printf("%s program:\n%s", t.Name(), program) + program = []byte(`module @TestShardy_input_data_sharding attributes {mhlo.num_replicas = 2:i32, mhlo.num_partitions = 1:i32} { + sdy.mesh @data_mesh = <["data"=2]> + func.func @main(%arg0: tensor<2x3xf32> { sdy.sharding = #sdy.sharding<@data_mesh, [{"data"}, {}]> }) -> tensor { + %1 = "stablehlo.constant"() { value = dense<0.0> : tensor } : () -> tensor + %2 = "stablehlo.reduce"(%arg0, %1) ({ + ^reductionFn(%lhs: tensor, %rhs: tensor) : + %0 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor + 
"stablehlo.return"(%0) : (tensor) -> () + }) { dimensions = array } : (tensor<2x3xf32>, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () + } +}`) + x0 := must1(client.BufferFromHost().ToDeviceNum(0).FromFlatDataWithDimensions( + []float32{0, 1, 2}, []int{1, 3}).Done()) + x1 := must1(client.BufferFromHost().ToDeviceNum(1).FromFlatDataWithDimensions( + []float32{0, 0.1, 0.2}, []int{1, 3}).Done()) + outputs := shardyCompileAndExecute(t, client, program, mesh, x0, x1) + requireBuffersEqual(t, []FlatAndDims{ + {[]float32{3.3}, nil}, + {[]float32{3.3}, nil}, + }, outputs) + }) +} diff --git a/types/shardy/shardingspec.go b/types/shardy/shardingspec.go index be6cf2d..8e90bf4 100644 --- a/types/shardy/shardingspec.go +++ b/types/shardy/shardingspec.go @@ -188,3 +188,52 @@ func (s *ShardingSpec) ToStableHLO() string { } return fmt.Sprintf("#sdy.sharding<@%s, [%s]%s>", s.Mesh.Name(), strings.Join(dimShardings, ", "), replicatedPart) } + +// ToValueAttribute converts the ShardingSpec to a StableHLO attribute of value with the given shape. +// +// Notice the rank of the ShardingSpec may be smaller than the rank of shape, in which case the extra axes are +// assumed to be replicated (empty). 
+// +// E.g.: "#sdy.sharding<@mesh, [{\"data\"}, {}]>" +func (s *ShardingSpec) ToValueAttribute(shape shapes.Shape) string { + var buf strings.Builder + w := func(format string, args ...any) { + buf.WriteString(fmt.Sprintf(format, args...)) + } + w("#sdy.sharding<@%s, [", s.Mesh.Name()) + for axisIdx := range shape.Rank() { + if axisIdx > 0 { + w(", ") + } + if axisIdx >= len(s.Axes) { + w("{}") + continue + } + tensorAxisSpec := s.Axes[axisIdx] + if len(tensorAxisSpec.MeshAxes) == 0 { + if tensorAxisSpec.Opened { + w("{?}") + } else { + w("{}") + } + continue + } + w("{") + for meshAxisIdx, meshAxisSpec := range tensorAxisSpec.MeshAxes { + if meshAxisIdx > 0 { + w(", ") + } + if meshAxisSpec.Size > 0 { + w("\"%s\":(%d)%d", meshAxisSpec.AxisName, meshAxisSpec.PreSize, meshAxisSpec.Size) + } else { + w("\"%s\"", meshAxisSpec.AxisName) + } + } + if tensorAxisSpec.Opened { + w(", ?") + } + w("}") + } + w("]>") + return buf.String() +} From 0f0cb8cf222580e9e9a43a6dfce8d7eda2d62ea1 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Wed, 19 Nov 2025 16:49:08 +0100 Subject: [PATCH 11/30] Working Shardy tests. 
--- tests/gopjrt/shardy_test.go | 39 +++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/tests/gopjrt/shardy_test.go b/tests/gopjrt/shardy_test.go index 65353d9..37490af 100644 --- a/tests/gopjrt/shardy_test.go +++ b/tests/gopjrt/shardy_test.go @@ -60,18 +60,6 @@ func testShardy(t *testing.T, client *pjrt.Client) { must(fn.Return(output)) program := must1(builder.Build()) fmt.Printf("%s program:\n%s", t.Name(), program) - program = []byte(`module @TestShardy_input_data_sharding attributes {mhlo.num_replicas = 2:i32, mhlo.num_partitions = 1:i32} { - sdy.mesh @data_mesh = <["data"=2]> - func.func @main(%arg0: tensor<2x3xf32> { sdy.sharding = #sdy.sharding<@data_mesh, [{"data"}, {}]> }) -> tensor { - %1 = "stablehlo.constant"() { value = dense<0.0> : tensor } : () -> tensor - %2 = "stablehlo.reduce"(%arg0, %1) ({ - ^reductionFn(%lhs: tensor, %rhs: tensor) : - %0 = "stablehlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor - "stablehlo.return"(%0) : (tensor) -> () - }) { dimensions = array } : (tensor<2x3xf32>, tensor) -> tensor - "stablehlo.return"(%2) : (tensor) -> () - } -}`) x0 := must1(client.BufferFromHost().ToDeviceNum(0).FromFlatDataWithDimensions( []float32{0, 1, 2}, []int{1, 3}).Done()) x1 := must1(client.BufferFromHost().ToDeviceNum(1).FromFlatDataWithDimensions( @@ -82,4 +70,31 @@ func testShardy(t *testing.T, client *pjrt.Client) { {[]float32{3.3}, nil}, }, outputs) }) + + t.Run("output-data-sharding", func(t *testing.T) { + mesh := must1(shardy.NewDeviceMesh("data_mesh", []int{2}, []string{"data"})) + builder := stablehlo.New(t.Name()).WithShardy(mesh) + fn := builder.NewFunction("main") + x := must1(fn.NamedInputWithSharding("arg0", shapes.Make(dtypes.F32, 2, 3), + builder.NewShardingSpec().AddShardedAxis("data"))) + reductionFn := fn.Closure() + lhs := must1(reductionFn.NamedInput("lhs", shapes.Make(dtypes.F32))) + rhs := must1(reductionFn.NamedInput("rhs", shapes.Make(dtypes.F32))) + 
must(reductionFn.Return(must1(stablehlo.Add(lhs, rhs)))) + zero := must1(fn.ConstantFromScalar(float32(0))) + output := must1(stablehlo.Reduce(x, zero, reductionFn, 1)) + must(fn.Return(output)) + program := must1(builder.Build()) + fmt.Printf("%s program:\n%s", t.Name(), program) + x0 := must1(client.BufferFromHost().ToDeviceNum(0).FromFlatDataWithDimensions( + []float32{0, 1, 2}, []int{1, 3}).Done()) + x1 := must1(client.BufferFromHost().ToDeviceNum(1).FromFlatDataWithDimensions( + []float32{0, 0.1, 0.2}, []int{1, 3}).Done()) + outputs := shardyCompileAndExecute(t, client, program, mesh, x0, x1) + requireBuffersEqual(t, []FlatAndDims{ + {[]float32{3}, []int{1}}, + {[]float32{0.3}, []int{1}}, + }, outputs) + }) + } From 9e85a69d7cfb3a251ae3ee75cbbb932606192bdd Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Wed, 19 Nov 2025 16:52:33 +0100 Subject: [PATCH 12/30] Updated CHANGELOG. Improved test. --- docs/CHANGELOG.md | 5 +++++ tests/gopjrt/shardy_test.go | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 044eda8..a5ea621 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -2,6 +2,11 @@ - `Function.Input` and `Function.NamedInput`: (change in API) they now may return an error, if the name is duplicate. - `AllReduce` now supports arbitrary number of inputs, to be reduced at once. +- Added XLA Shardy support: + - Added `shardy.DeviceMesh` and `shardy.ShardingSpec` types. 
+ - Added `Builder.WithShardy(mesh)` + - Added `Function.NamedInputWithShardingAndAttributes()` + - Added `Function.ReturnWithShardingAndAttributes()` # v0.1.0: 2025/11/06 Multi-Device support diff --git a/tests/gopjrt/shardy_test.go b/tests/gopjrt/shardy_test.go index 37490af..e79dbad 100644 --- a/tests/gopjrt/shardy_test.go +++ b/tests/gopjrt/shardy_test.go @@ -48,7 +48,7 @@ func testShardy(t *testing.T, client *pjrt.Client) { t.Run("input-data-sharding", func(t *testing.T) { mesh := must1(shardy.NewDeviceMesh("data_mesh", []int{2}, []string{"data"})) builder := stablehlo.New(t.Name()).WithShardy(mesh) - fn := builder.NewFunction("main") + fn := builder.Main() x := must1(fn.NamedInputWithSharding("arg0", shapes.Make(dtypes.F32, 2, 3), builder.NewShardingSpec().AddShardedAxis("data"))) reductionFn := fn.Closure() From 4d6adcb4645848ccee604258bc97ecf5bd2bb5e3 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Thu, 20 Nov 2025 06:50:26 +0100 Subject: [PATCH 13/30] Renamed DeviceMesh.shape to DeviceMesh.axesSizes --- types/shardy/devicemesh.go | 54 ++++++++++++++++----------------- types/shardy/devicemesh_test.go | 10 +++--- types/shardy/shardingspec.go | 8 ++--- 3 files changed, 36 insertions(+), 36 deletions(-) diff --git a/types/shardy/devicemesh.go b/types/shardy/devicemesh.go index 96660ad..a1f589d 100644 --- a/types/shardy/devicemesh.go +++ b/types/shardy/devicemesh.go @@ -20,8 +20,8 @@ type DeviceMesh struct { // axesNames are the names of the mesh axes. axesNames []string - // shape defines the number of devices along each mesh axis. - shape []int + // axesSizes defines the number of devices along each mesh axis. + axesSizes []int // nameToAxis maps axis names to their index. nameToAxis map[string]int @@ -39,7 +39,7 @@ type DeviceMesh struct { // NewDeviceMesh creates a new logical topology of a set of devices. // // - name: the name of the mesh, it must be a valid StableHLO identifier (see stablehlo.NormalizeIdentifier). 
-// - shape: defines the number of devices along each mesh axis, one value per axis. +// - axesSizes: defines the number of devices along each mesh axis, one value per axis. // - axesNames: the names of the mesh axes. One value per axis. They must also be valid StableHLO identifiers // (see stablehlo.NormalizeName). // @@ -50,11 +50,11 @@ type DeviceMesh struct { // with the DeviceMesh.WithDeviceMapping() method. func NewDeviceMesh(name string, shape []int, axisNames []string) (*DeviceMesh, error) { if len(shape) != len(axisNames) { - return nil, errors.Errorf("shape and axesNames must have the same length, got %d and %d", + return nil, errors.Errorf("axesSizes and axesNames must have the same length, got %d and %d", len(shape), len(axisNames)) } if len(shape) == 0 { - return nil, errors.New("DeviceMesh shape cannot be empty") + return nil, errors.New("DeviceMesh axesSizes cannot be empty") } // Normalize names: @@ -88,7 +88,7 @@ func NewDeviceMesh(name string, shape []int, axisNames []string) (*DeviceMesh, e m := &DeviceMesh{ name: name, axesNames: axisNames, - shape: shape, + axesSizes: shape, nameToAxis: nameToAxis, numDevices: numDevices, deviceAssignment: make([]int, numDevices), @@ -118,7 +118,7 @@ func (m *DeviceMesh) NumDevices() int { // Rank returns the number of axes in the mesh. func (m *DeviceMesh) Rank() int { - return len(m.shape) + return len(m.axesSizes) } // AxisNames returns a copy of the mesh's axis names. @@ -126,10 +126,10 @@ func (m *DeviceMesh) AxisNames() []string { return slices.Clone(m.axesNames) } -// Shape returns a copy of the mesh's shape. +// Shape returns a copy of the mesh's axesSizes. 
func (m *DeviceMesh) Shape() []int { - shape := make([]int, len(m.shape)) - copy(shape, m.shape) + shape := make([]int, len(m.axesSizes)) + copy(shape, m.axesSizes) return shape } @@ -139,18 +139,18 @@ func (m *DeviceMesh) AxisSize(axisName string) (int, error) { if !found { return 0, errors.Errorf("mesh axis %q not found", axisName) } - return m.shape[idx], nil + return m.axesSizes[idx], nil } // String implements the fmt.Stringer interface. func (m *DeviceMesh) String() string { var sb strings.Builder - sb.WriteString("DeviceMesh(shape={") + sb.WriteString("DeviceMesh(axesSizes={") for i, name := range m.axesNames { if i > 0 { sb.WriteString(", ") } - _, _ = fmt.Fprintf(&sb, "%s: %d", name, m.shape[i]) + _, _ = fmt.Fprintf(&sb, "%s: %d", name, m.axesSizes[i]) } sb.WriteString("})") return sb.String() @@ -195,11 +195,11 @@ func (m *DeviceMesh) DeviceToMesh(physicalDevice int) (flatIdx int, axisIndices } // Convert flat index to per-axis indices - axisIndices = make([]int, len(m.shape)) + axisIndices = make([]int, len(m.axesSizes)) remaining := flatIdx - for i := len(m.shape) - 1; i >= 0; i-- { - axisIndices[i] = remaining % m.shape[i] - remaining /= m.shape[i] + for i := len(m.axesSizes) - 1; i >= 0; i-- { + axisIndices[i] = remaining % m.axesSizes[i] + remaining /= m.axesSizes[i] } return flatIdx, axisIndices, nil } @@ -233,8 +233,8 @@ func (m *DeviceMesh) ComputeReplicaGroups(axes []string) ([][]int, error) { } // Create indices for each axis dimension - nonAxisIndices := make([]int, 0, len(m.shape)-len(axisIndices)) - for i := range m.shape { + nonAxisIndices := make([]int, 0, len(m.axesSizes)-len(axisIndices)) + for i := range m.axesSizes { if !slices.Contains(axisIndices, i) { nonAxisIndices = append(nonAxisIndices, i) } @@ -243,7 +243,7 @@ func (m *DeviceMesh) ComputeReplicaGroups(axes []string) ([][]int, error) { // Calculate the size of each group and number of groups groupSize := 1 for _, idx := range axisIndices { - groupSize *= m.shape[idx] + groupSize 
*= m.axesSizes[idx] } numGroups := m.numDevices / groupSize @@ -256,11 +256,11 @@ func (m *DeviceMesh) ComputeReplicaGroups(axes []string) ([][]int, error) { // Fill in the groups for flatIdx := 0; flatIdx < m.numDevices; flatIdx++ { // Convert flat index to per-axis indices - indices := make([]int, len(m.shape)) + indices := make([]int, len(m.axesSizes)) remaining := flatIdx - for i := len(m.shape) - 1; i >= 0; i-- { - indices[i] = remaining % m.shape[i] - remaining /= m.shape[i] + for i := len(m.axesSizes) - 1; i >= 0; i-- { + indices[i] = remaining % m.axesSizes[i] + remaining /= m.axesSizes[i] } // Calculate group index from non-axis indices @@ -269,7 +269,7 @@ func (m *DeviceMesh) ComputeReplicaGroups(axes []string) ([][]int, error) { for i := len(nonAxisIndices) - 1; i >= 0; i-- { axisIdx := nonAxisIndices[i] groupIdx += indices[axisIdx] * multiplier - multiplier *= m.shape[axisIdx] + multiplier *= m.axesSizes[axisIdx] } // Calculate position within group from axis indices @@ -278,7 +278,7 @@ func (m *DeviceMesh) ComputeReplicaGroups(axes []string) ([][]int, error) { for i := len(axisIndices) - 1; i >= 0; i-- { axisIdx := axisIndices[i] posInGroup += indices[axisIdx] * multiplier - multiplier *= m.shape[axisIdx] + multiplier *= m.axesSizes[axisIdx] } groups[groupIdx][posInGroup] = flatIdx @@ -299,7 +299,7 @@ func (m *DeviceMesh) ToStableHLO() string { if i > 0 { w(", ") } - w("%q=%d", axisName, m.shape[i]) + w("%q=%d", axisName, m.axesSizes[i]) } w("]>") return buf.String() diff --git a/types/shardy/devicemesh_test.go b/types/shardy/devicemesh_test.go index 26574c3..634cd20 100644 --- a/types/shardy/devicemesh_test.go +++ b/types/shardy/devicemesh_test.go @@ -75,13 +75,13 @@ func TestDeviceMesh(t *testing.T) { name: "mismatched lengths", shape: []int{2, 4}, axisNames: []string{"x"}, - wantErr: "shape and axesNames must have the same length", + wantErr: "axesSizes and axesNames must have the same length", }, { - name: "empty shape", + name: "empty axesSizes", 
shape: []int{}, axisNames: []string{}, - wantErr: "DeviceMesh shape cannot be empty", + wantErr: "DeviceMesh axesSizes cannot be empty", }, { name: "empty axis name", @@ -186,13 +186,13 @@ func TestDeviceMesh(t *testing.T) { name: "1D mesh", shape: []int{8}, axisNames: []string{"replica"}, - want: "DeviceMesh(shape={replica: 8})", + want: "DeviceMesh(axesSizes={replica: 8})", }, { name: "2D mesh", shape: []int{2, 4}, axisNames: []string{"x", "y"}, - want: "DeviceMesh(shape={x: 2, y: 4})", + want: "DeviceMesh(axesSizes={x: 2, y: 4})", }, } diff --git a/types/shardy/shardingspec.go b/types/shardy/shardingspec.go index 8e90bf4..c5b4751 100644 --- a/types/shardy/shardingspec.go +++ b/types/shardy/shardingspec.go @@ -117,7 +117,7 @@ func (s *ShardingSpec) Validate() error { return errors.Errorf("ShardingSpec tensor axis %d, mesh axis #%d refers to unknown mesh axis %q", i, j, axisName) } - meshAxisSize := s.Mesh.shape[axisIdx] + meshAxisSize := s.Mesh.axesSizes[axisIdx] // Check sub-axis specification. if meshAxisSpec.Size > 0 { @@ -145,7 +145,7 @@ func (s *ShardingSpec) ValidateShape(shape shapes.Shape) error { return err } if s.Rank() > shape.Rank() { - return errors.Errorf("ShardingSpec shape rank %d is largers than tensor rank %d", s.Rank(), shape.Rank()) + return errors.Errorf("ShardingSpec shape rank %d is larger than tensor rank %d", s.Rank(), shape.Rank()) } return nil } @@ -189,9 +189,9 @@ func (s *ShardingSpec) ToStableHLO() string { return fmt.Sprintf("#sdy.sharding<@%s, [%s]%s>", s.Mesh.Name(), strings.Join(dimShardings, ", "), replicatedPart) } -// ToValueAttribute converts the ShardingSpec to a StableHLO attribute of value with the given shape. +// ToValueAttribute converts the ShardingSpec to a StableHLO attribute of a value with the given shape. 
// -// Notice the rank of the ShardingSpec may be smaller than the rank of shape, in which case the extra axes are +// Notice the rank of the ShardingSpec may be smaller than the rank of the shape, in which case the extra axes are // assumed to be replicated (empty). // // E.g.: "#sdy.sharding<@mesh, [{\"data\"}, {}]>" From 4a90ce43c8d15f020e0de46b25b898244f2ca64d Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Thu, 20 Nov 2025 07:00:43 +0100 Subject: [PATCH 14/30] Cosmetic. --- types/shardy/shardingspec.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/types/shardy/shardingspec.go b/types/shardy/shardingspec.go index c5b4751..44b2f3c 100644 --- a/types/shardy/shardingspec.go +++ b/types/shardy/shardingspec.go @@ -110,11 +110,13 @@ func (s *ShardingSpec) Validate() error { for j, meshAxisSpec := range axisSpec.MeshAxes { axisName := meshAxisSpec.AxisName if axisName == "" { - return errors.Errorf("ShardingSpec tensor axis %d, mesh axis #%d refers to empty mesh axis name", i, j) + return errors.Errorf( + "ShardingSpec tensor axis %d, mesh axis #%d refers to empty mesh axis name", i, j) } axisIdx, ok := s.Mesh.nameToAxis[axisName] if !ok { - return errors.Errorf("ShardingSpec tensor axis %d, mesh axis #%d refers to unknown mesh axis %q", + return errors.Errorf( + "ShardingSpec tensor axis %d, mesh axis #%d refers to unknown mesh axis %q", i, j, axisName) } meshAxisSize := s.Mesh.axesSizes[axisIdx] @@ -126,7 +128,9 @@ func (s *ShardingSpec) Validate() error { i, j, axisName, meshAxisSpec.PreSize) } if meshAxisSize%(meshAxisSpec.PreSize*meshAxisSpec.Size) != 0 { - return errors.Errorf("ShardingSpec tensor axis %d, mesh axis #%d %q with PreSize %d and Size %d is not compatible with mesh axis of size %d", + return errors.Errorf( + "ShardingSpec tensor axis %d, mesh axis #%d %q with PreSize %d and Size %d is not "+ + "compatible with mesh axis of size %d", i, j, axisName, meshAxisSpec.PreSize, meshAxisSpec.Size, meshAxisSize) } } From 
fb0abb619a02b080c7295ac94094860697431348 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Thu, 20 Nov 2025 09:01:53 +0100 Subject: [PATCH 15/30] Fixed documentation. --- types/shardy/shardingspec.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/types/shardy/shardingspec.go b/types/shardy/shardingspec.go index 44b2f3c..4f2b17f 100644 --- a/types/shardy/shardingspec.go +++ b/types/shardy/shardingspec.go @@ -12,7 +12,7 @@ import ( // ShardingSpec (also known as PartitionSpec in JAX) defines how a logical tensor is to be sharded (partitioned) across // a DeviceMesh. This is used by Shardy, and is based on its documentation in [1]. // -// The definition is per axis of the logical tensor -- and not per axis of the Mesh, common confusion. +// The definition is per axis of the logical tensor -- and not per axis of the Mesh, a common confusion. // If not all axes of the Tensor are defined, the tail axes are considered simply to be replicated across the whole // mesh. // @@ -23,13 +23,13 @@ import ( // mesh := NewDeviceMesh("my_mesh", []int{2, 2}, []string{"data", "model"}) // // // Input's "batch" axis is sharded across the "data" axis of the mesh. -// inputSharding := MakeShardSpec(mesh.Name()).AddShardedAxis("data") +// inputSharding := NewShardingSpec(mesh).AddShardedAxis("data") // // // First axis is replicated, second is shared across "model" devices -// variableSharding := MakeShardSpec(mesh.Name()).AddReplicated().AddShardedAxis("model") +// variableSharding := NewShardingSpec(mesh).AddReplicated().AddShardedAxis("model") // // // Second axis is sharded across both "data" and "model" devices. 
-// largeWeights := MakeShardSpec(mesh.Name()).AddReplicated().AddShardedAxis("data", "model") +// largeWeights := NewShardingSpec(mesh).AddReplicated().AddShardedAxis("data", "model") // // There are two advanced features supported but not tested (pls if you need let us know how it goes, or if you find // any issues): @@ -47,7 +47,7 @@ type ShardingSpec struct { // TensorAxisSpec specifies how a tensor axis is to be sharded (or replicated). // See details in ShardingSpec. // -// Usually, one would create this using ShardingSpec.AddAxis or ShardingSpec.AddReplicated +// Usually, one would create this using ShardingSpec.AddShardedAxis or ShardingSpec.AddReplicated type TensorAxisSpec struct { MeshAxes []MeshAxisSpec Opened bool // If opened to further sharding. From 951eab5fce8194bcf72505d37cfc2bbe24002479 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Thu, 20 Nov 2025 09:17:20 +0100 Subject: [PATCH 16/30] DeviceMesh changed to take a "logical device assignment". --- tests/gopjrt/shardy_test.go | 2 +- types/shardy/devicemesh.go | 85 +++++++--------------- types/shardy/devicemesh_test.go | 120 +++----------------------------- 3 files changed, 38 insertions(+), 169 deletions(-) diff --git a/tests/gopjrt/shardy_test.go b/tests/gopjrt/shardy_test.go index e79dbad..9262722 100644 --- a/tests/gopjrt/shardy_test.go +++ b/tests/gopjrt/shardy_test.go @@ -22,7 +22,7 @@ func shardyCompileAndExecute(t *testing.T, client *pjrt.Client, program []byte, loadedExec, err := client.Compile(). WithStableHLO(program). WithShardy(mesh.NumDevices()). - WithDeviceAssignment(mesh.DeviceAssignment()). + WithDeviceAssignment(mesh.LogicalDeviceAssignment()). 
Done() require.NoErrorf(t, err, "failed to compile program: \n%s", program) defer func() { diff --git a/types/shardy/devicemesh.go b/types/shardy/devicemesh.go index a1f589d..8978ce1 100644 --- a/types/shardy/devicemesh.go +++ b/types/shardy/devicemesh.go @@ -29,11 +29,10 @@ type DeviceMesh struct { // numDevices is the total number of devices in the mesh. numDevices int - // deviceAssignment is the list of devices numbers in the mesh, in the order they appear in the mesh. - deviceAssignment []int - - // physicalDeviceMapping is the mapping of concrete devices to the flat index in the mesh. - physicalDeviceMapping map[int]int + // logicalDeviceAssignment is the list of "logical" devices numbers in the mesh, in the order they appear in the + // mesh. + // These numbers are indices in the LogicalDeviceAssignment that will be used in the compilation of the program. + logicalDeviceAssignment []int } // NewDeviceMesh creates a new logical topology of a set of devices. @@ -43,11 +42,8 @@ type DeviceMesh struct { // - axesNames: the names of the mesh axes. One value per axis. They must also be valid StableHLO identifiers // (see stablehlo.NormalizeName). // -// The default mapping of concrete devices numbers to the mesh is sequential, starting from 0, but it can be -// changed with the DeviceMesh.SetDeviceAssignment() method. -// -// For non-symmetric devices, where the connection speed among the devices matters, a custom mapping can be provided -// with the DeviceMesh.WithDeviceMapping() method. +// The default mapping of logical devices numbers to the mesh is sequential, starting from 0, but it can be +// changed with the DeviceMesh.SetLogicalDeviceAssignment() method. 
func NewDeviceMesh(name string, shape []int, axisNames []string) (*DeviceMesh, error) { if len(shape) != len(axisNames) { return nil, errors.Errorf("axesSizes and axesNames must have the same length, got %d and %d", @@ -86,17 +82,16 @@ func NewDeviceMesh(name string, shape []int, axisNames []string) (*DeviceMesh, e } m := &DeviceMesh{ - name: name, - axesNames: axisNames, - axesSizes: shape, - nameToAxis: nameToAxis, - numDevices: numDevices, - deviceAssignment: make([]int, numDevices), + name: name, + axesNames: axisNames, + axesSizes: shape, + nameToAxis: nameToAxis, + numDevices: numDevices, + logicalDeviceAssignment: make([]int, numDevices), } - for i := range m.deviceAssignment { - m.deviceAssignment[i] = i + for i := range m.logicalDeviceAssignment { + m.logicalDeviceAssignment[i] = i } - m.buildPhysicalDeviceMapping() return m, nil } @@ -104,13 +99,6 @@ func (m *DeviceMesh) Name() string { return m.name } -func (m *DeviceMesh) buildPhysicalDeviceMapping() { - m.physicalDeviceMapping = make(map[int]int, m.numDevices) - for i, device := range m.deviceAssignment { - m.physicalDeviceMapping[device] = i - } -} - // NumDevices returns the total number of devices in the mesh. func (m *DeviceMesh) NumDevices() int { return m.numDevices @@ -156,10 +144,12 @@ func (m *DeviceMesh) String() string { return sb.String() } -// SetDeviceAssignment sets the assignment of concrete devices to the mesh. +// SetLogicalDeviceAssignment sets the assignment of logical devices to the mesh. +// +// The length of devices must be equal to NumDevices(). And it should include all numbers from 0 to NumDevices()-1. // -// It returns an error if deviceAssignment has invalid device numbers or len(devices) != NumDevices(). -func (m *DeviceMesh) SetDeviceAssignment(devices ...int) error { +// It returns an error if logicalDeviceAssignment has invalid device numbers or len(devices) != NumDevices(). 
+func (m *DeviceMesh) SetLogicalDeviceAssignment(devices ...int) error { if len(devices) != m.numDevices { return errors.Errorf("devices must have %d elements, got %d", m.numDevices, len(devices)) } @@ -169,45 +159,24 @@ func (m *DeviceMesh) SetDeviceAssignment(devices ...int) error { return errors.Errorf("physical device #%d is duplicated in mapping", device) } seen.Insert(device) - if device < 0 { - return errors.Errorf("devices must be positive, got device %d", device) + if device < 0 || device >= m.numDevices { + return errors.Errorf("devices must be between 0 and %d (NumDevices()-1), got device %d", + m.numDevices-1, device) } } - copy(m.deviceAssignment, devices) - m.buildPhysicalDeviceMapping() - if len(m.physicalDeviceMapping) != m.numDevices { - return errors.Errorf("provided devicesIn: physicalDeviceMapping has %d elements, expected %d", len(m.physicalDeviceMapping), m.numDevices) - } + copy(m.logicalDeviceAssignment, devices) return nil } -// DeviceAssignment returns the list of devices in the mesh, in the order they appear in the mesh. -func (m *DeviceMesh) DeviceAssignment() []int { - return slices.Clone(m.deviceAssignment) -} - -// DeviceToMesh return the indices (flat and per-axis) assigned to the given physicalDevice. -func (m *DeviceMesh) DeviceToMesh(physicalDevice int) (flatIdx int, axisIndices []int, err error) { - var ok bool - flatIdx, ok = m.physicalDeviceMapping[physicalDevice] - if !ok { - return 0, nil, errors.Errorf("physical device %d is not part of the mesh", physicalDevice) - } - - // Convert flat index to per-axis indices - axisIndices = make([]int, len(m.axesSizes)) - remaining := flatIdx - for i := len(m.axesSizes) - 1; i >= 0; i-- { - axisIndices[i] = remaining % m.axesSizes[i] - remaining /= m.axesSizes[i] - } - return flatIdx, axisIndices, nil +// LogicalDeviceAssignment returns the list of devices in the mesh, in the order they appear in the mesh. 
+func (m *DeviceMesh) LogicalDeviceAssignment() []int { + return slices.Clone(m.logicalDeviceAssignment) } // ComputeReplicaGroups returns the replica groups participating in some collective (distributed) operation given the // axes along which the operation is performed. // -// Each replica group (a []int) includes the device indices (from the DeviceAssignment) for the axes specified. +// Each replica group (a []int) includes the device indices (from the LogicalDeviceAssignment) for the axes specified. // The other axes will be split into different replica groups. // // Example: diff --git a/types/shardy/devicemesh_test.go b/types/shardy/devicemesh_test.go index 634cd20..463252a 100644 --- a/types/shardy/devicemesh_test.go +++ b/types/shardy/devicemesh_test.go @@ -223,22 +223,14 @@ func TestDeviceMesh(t *testing.T) { }, { name: "custom mapping", - devices: []int{2, 5, 1, 7}, + devices: []int{2, 1, 3, 0}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := mesh.SetDeviceAssignment(tt.devices...) - require.NoError(t, err) - - // Verify mapping is applied correctly - for i, device := range tt.devices { - flatIdx, axisIndices, err := mesh.DeviceToMesh(device) - require.NoError(t, err) - assert.Equal(t, i, flatIdx) - assert.Equal(t, []int{i}, axisIndices) - } + err := mesh.SetLogicalDeviceAssignment(tt.devices...) + require.NoErrorf(t, err, "failed test %q", tt.name) }) } }) @@ -265,131 +257,39 @@ func TestDeviceMesh(t *testing.T) { { name: "device out of range (negative)", devices: []int{0, 1, -1, 3}, - wantErr: "devices must be positive", + wantErr: "devices must be between 0 and 3", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := mesh.SetDeviceAssignment(tt.devices...) + err := mesh.SetLogicalDeviceAssignment(tt.devices...) 
require.Error(t, err) assert.Contains(t, err.Error(), tt.wantErr) }) } }) - t.Run("DeviceToMesh_1D", func(t *testing.T) { - mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) - require.NoError(t, err) - - for i := 0; i < 4; i++ { - flatIdx, axisIndices, err := mesh.DeviceToMesh(int(i)) - require.NoError(t, err) - assert.Equal(t, i, flatIdx) - assert.Equal(t, []int{i}, axisIndices) - } - }) - t.Run("DeviceToMesh_2D", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 4}, []string{"x", "y"}) require.NoError(t, err) - - tests := []struct { - device int - wantFlat int - wantIndices []int - }{ - {device: 0, wantFlat: 0, wantIndices: []int{0, 0}}, - {device: 1, wantFlat: 1, wantIndices: []int{0, 1}}, - {device: 2, wantFlat: 2, wantIndices: []int{0, 2}}, - {device: 3, wantFlat: 3, wantIndices: []int{0, 3}}, - {device: 4, wantFlat: 4, wantIndices: []int{1, 0}}, - {device: 5, wantFlat: 5, wantIndices: []int{1, 1}}, - {device: 6, wantFlat: 6, wantIndices: []int{1, 2}}, - {device: 7, wantFlat: 7, wantIndices: []int{1, 3}}, - } - - for _, tt := range tests { - t.Run(string(rune(tt.device)), func(t *testing.T) { - flatIdx, axisIndices, err := mesh.DeviceToMesh(tt.device) - require.NoError(t, err) - assert.Equal(t, tt.wantFlat, flatIdx) - assert.Equal(t, tt.wantIndices, axisIndices) - }) - } + require.Equal(t, 8, mesh.NumDevices()) }) t.Run("DeviceToMesh_3D", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2, 2}, []string{"x", "y", "z"}) require.NoError(t, err) - - tests := []struct { - device int - wantFlat int - wantIndices []int - }{ - {device: 0, wantFlat: 0, wantIndices: []int{0, 0, 0}}, - {device: 1, wantFlat: 1, wantIndices: []int{0, 0, 1}}, - {device: 2, wantFlat: 2, wantIndices: []int{0, 1, 0}}, - {device: 3, wantFlat: 3, wantIndices: []int{0, 1, 1}}, - {device: 4, wantFlat: 4, wantIndices: []int{1, 0, 0}}, - {device: 5, wantFlat: 5, wantIndices: []int{1, 0, 1}}, - {device: 6, wantFlat: 6, wantIndices: 
[]int{1, 1, 0}}, - {device: 7, wantFlat: 7, wantIndices: []int{1, 1, 1}}, - } - - for _, tt := range tests { - t.Run(string(rune(tt.device)), func(t *testing.T) { - flatIdx, axisIndices, err := mesh.DeviceToMesh(tt.device) - require.NoError(t, err) - assert.Equal(t, tt.wantFlat, flatIdx) - assert.Equal(t, tt.wantIndices, axisIndices) - }) - } + require.Equal(t, 8, mesh.NumDevices()) }) t.Run("DeviceToMesh_WithCustomMapping", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) require.NoError(t, err) - - // Set custom mapping: devices [7, 5, 3, 1] - err = mesh.SetDeviceAssignment(7, 5, 3, 1) + err = mesh.SetLogicalDeviceAssignment(3, 2, 1, 0) require.NoError(t, err) - - tests := []struct { - device int - wantFlat int - wantIndices []int - }{ - {device: 7, wantFlat: 0, wantIndices: []int{0}}, - {device: 5, wantFlat: 1, wantIndices: []int{1}}, - {device: 3, wantFlat: 2, wantIndices: []int{2}}, - {device: 1, wantFlat: 3, wantIndices: []int{3}}, - } - - for _, tt := range tests { - t.Run(string(rune(tt.device)), func(t *testing.T) { - flatIdx, axisIndices, err := mesh.DeviceToMesh(tt.device) - require.NoError(t, err) - assert.Equal(t, tt.wantFlat, flatIdx) - assert.Equal(t, tt.wantIndices, axisIndices) - }) - } - - // Devices not in the mesh should error - _, _, err = mesh.DeviceToMesh(0) - require.Error(t, err) - assert.Contains(t, err.Error(), "not part of the mesh") - }) - - t.Run("DeviceToMesh_NotInMesh", func(t *testing.T) { - mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) - require.NoError(t, err) - - // Device 5 is not in the mesh (only 0-3 are used) - _, _, err = mesh.DeviceToMesh(5) + require.Equal(t, 4, mesh.NumDevices()) + err = mesh.SetLogicalDeviceAssignment(4, 2, 1, 0) require.Error(t, err) - assert.Contains(t, err.Error(), "physical device 5 is not part of the mesh") }) t.Run("ComputeReplicaGroups", func(t *testing.T) { From 23e49cd34511bcde8a9dd2f5c02dea89298bd81a Mon Sep 17 00:00:00 
2001 From: Jan Pfeifer Date: Thu, 20 Nov 2025 09:42:58 +0100 Subject: [PATCH 17/30] Added support for multiple meshes in StableHLO module. --- builder.go | 61 +++++++++++++++++++++++++++++++++++++++++------------ function.go | 12 ++++++++--- 2 files changed, 56 insertions(+), 17 deletions(-) diff --git a/builder.go b/builder.go index 9f9369d..46c4a7f 100644 --- a/builder.go +++ b/builder.go @@ -6,6 +6,7 @@ import ( "io" "slices" + "github.com/gomlx/stablehlo/internal/utils" "github.com/gomlx/stablehlo/types" "github.com/gomlx/stablehlo/types/shardy" "github.com/pkg/errors" @@ -23,11 +24,12 @@ type Builder struct { // inlineUniqueID is a counter used to generate unique names for inlined functions values. inlineUniqueID int - // Mesh used for Shardy. - mesh *shardy.DeviceMesh + // meshes used for Shardy. + meshes []*shardy.DeviceMesh // numReplicas is the number of replicas for data parallelism. numReplicas int + // numPartitions is the number of partitions for model parallelism. numPartitions int @@ -74,29 +76,53 @@ func (b *Builder) WithNumPartitions(n int) *Builder { return b } -// WithShardy enables distributed computation across the devices selected by the mesh. +// WithShardy enables distributed computation across the devices selected by the given meshes. +// // This is the recommended way to do distributed (across devices) computation, and given the inputs // with sharded information, Shardy will automatically distribute the computation, without you needing // to specify any of the collective operations. // +// Usually, there is only one meshes. But one can split the devices in different meshes. The meshes overlap +// the concrete devices used. 
+// // See details of XLA Shardy in [1] // // [1] https://github.com/openxla/shardy -func (b *Builder) WithShardy(mesh *shardy.DeviceMesh) *Builder { - b.mesh = mesh +func (b *Builder) WithShardy(meshes ...*shardy.DeviceMesh) *Builder { + b.meshes = meshes b.WithNumReplicas(1) - b.WithNumPartitions(mesh.NumDevices()) + numDevices := 0 + for _, mesh := range meshes { + numDevices = max(numDevices, mesh.NumDevices()) + } + b.WithNumPartitions(numDevices) return b } -// Mesh returns the mesh configured with WithShardy. -func (b *Builder) Mesh() *shardy.DeviceMesh { - return b.mesh +// Meshes returns the meshes configured with WithShardy. +func (b *Builder) Meshes() []*shardy.DeviceMesh { + return b.meshes } -// NewShardingSpec creates a new ShardingSpec using the mesh configured with WithShardy. +// NewShardingSpec creates a new ShardingSpec using the first mesh configured with WithShardy. +// It returns nil if no mesh was not configured. +// +// This is a shortcut to NewShardingSpecByMeshIx(0). func (b *Builder) NewShardingSpec() *shardy.ShardingSpec { - return shardy.NewShardingSpec(b.mesh) + if len(b.meshes) == 0 { + return nil + } + return shardy.NewShardingSpec(b.meshes[0]) +} + +// NewShardingSpecByMeshIx creates a new ShardingSpec for the meshIdx (the order given by WithShardy). +// +// It may return nil if meshIdx is out of range. +func (b *Builder) NewShardingSpecByMeshIx(meshIdx int) *shardy.ShardingSpec { + if meshIdx < 0 || meshIdx >= len(b.meshes) { + return nil + } + return shardy.NewShardingSpec(b.meshes[meshIdx]) } // elementWriter represents elements of ToStableHLO that know how to write themselves. 
@@ -194,9 +220,16 @@ func (b *Builder) Write(writer io.Writer) error { } w(" {\n") - // Write Shardy mesh if needed: - if b.mesh != nil { - w("%s%s\n", IndentationStep, b.mesh.ToStableHLO()) + // Write Shardy meshes if needed: + if len(b.meshes) > 0 { + namesUsed := utils.MakeSet[string](len(b.meshes)) + for _, mesh := range b.meshes { + if namesUsed.Has(mesh.Name()) { + return errors.Errorf("duplicate mesh name %q", mesh.Name()) + } + namesUsed.Insert(mesh.Name()) + w("%s%s\n", IndentationStep, mesh.ToStableHLO()) + } } // Write non-inline functions: diff --git a/function.go b/function.go index 3801df2..0636291 100644 --- a/function.go +++ b/function.go @@ -4,7 +4,9 @@ import ( "fmt" "io" "reflect" + "slices" "strconv" + "strings" "github.com/gomlx/gopjrt/dtypes" "github.com/gomlx/stablehlo/internal/optypes" @@ -161,9 +163,13 @@ func (fn *Function) NamedInputWithShardingAndAttributes(name string, shape shape value.Attributes = make(map[string]any) } value.Attributes["sdy.sharding"] = literalStr(shardingSpec.ToValueAttribute(value.shape)) - if shardingSpec.Mesh != fn.Builder.mesh { - return nil, errors.Errorf("sharding spec mesh %s doesn't match the stablehlo.Builder mesh %s", - shardingSpec.Mesh, fn.Builder.mesh) + if slices.Index(fn.Builder.meshes, shardingSpec.Mesh) == -1 { + meshesNames := make([]string, len(fn.Builder.meshes)) + for _, mesh := range fn.Builder.meshes { + meshesNames = append(meshesNames, mesh.Name()) + } + return nil, errors.Errorf("sharding spec meshe %q doesn't match any of the stablehlo.Builder meshes (%s)", + shardingSpec.Mesh, strings.Join(meshesNames, ", ")) } if err := shardingSpec.ValidateShape(shape); err != nil { return nil, err From 1eb49422f67c9284b42c55c911e4ab35ad28e66c Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Thu, 20 Nov 2025 10:25:07 +0100 Subject: [PATCH 18/30] Fixed test to use DeviceAssignment correctly. 
--- tests/gopjrt/shardy_test.go | 38 ++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/tests/gopjrt/shardy_test.go b/tests/gopjrt/shardy_test.go index e79dbad..dcc9578 100644 --- a/tests/gopjrt/shardy_test.go +++ b/tests/gopjrt/shardy_test.go @@ -18,11 +18,11 @@ func TestShardy(t *testing.T) { // compileAndExecute program with PJRT. All inputs are donated. func shardyCompileAndExecute(t *testing.T, client *pjrt.Client, program []byte, - mesh *shardy.DeviceMesh, inputs ...*pjrt.Buffer) []*pjrt.Buffer { + deviceAssignment []int, inputs ...*pjrt.Buffer) []*pjrt.Buffer { loadedExec, err := client.Compile(). WithStableHLO(program). - WithShardy(mesh.NumDevices()). - WithDeviceAssignment(mesh.DeviceAssignment()). + WithShardy(len(deviceAssignment)). + WithDeviceAssignment(deviceAssignment). Done() require.NoErrorf(t, err, "failed to compile program: \n%s", program) defer func() { @@ -44,6 +44,10 @@ func testShardy(t *testing.T, client *pjrt.Client) { t.Skipf("Skipping test: not enough devices: %d < %d", numDevices, numReplicas) return } + deviceAssignment := make([]int, numReplicas) + for i := range numReplicas { + deviceAssignment[i] = i + } t.Run("input-data-sharding", func(t *testing.T) { mesh := must1(shardy.NewDeviceMesh("data_mesh", []int{2}, []string{"data"})) @@ -60,11 +64,15 @@ func testShardy(t *testing.T, client *pjrt.Client) { must(fn.Return(output)) program := must1(builder.Build()) fmt.Printf("%s program:\n%s", t.Name(), program) - x0 := must1(client.BufferFromHost().ToDeviceNum(0).FromFlatDataWithDimensions( - []float32{0, 1, 2}, []int{1, 3}).Done()) - x1 := must1(client.BufferFromHost().ToDeviceNum(1).FromFlatDataWithDimensions( - []float32{0, 0.1, 0.2}, []int{1, 3}).Done()) - outputs := shardyCompileAndExecute(t, client, program, mesh, x0, x1) + x0 := must1(client.BufferFromHost(). + ToDeviceNum(deviceAssignment[0]). + FromFlatDataWithDimensions([]float32{0, 1, 2}, []int{1, 3}). 
+ Done()) + x1 := must1(client.BufferFromHost(). + ToDeviceNum(deviceAssignment[1]). + FromFlatDataWithDimensions([]float32{0, 0.1, 0.2}, []int{1, 3}). + Done()) + outputs := shardyCompileAndExecute(t, client, program, deviceAssignment, x0, x1) requireBuffersEqual(t, []FlatAndDims{ {[]float32{3.3}, nil}, {[]float32{3.3}, nil}, @@ -86,11 +94,15 @@ func testShardy(t *testing.T, client *pjrt.Client) { must(fn.Return(output)) program := must1(builder.Build()) fmt.Printf("%s program:\n%s", t.Name(), program) - x0 := must1(client.BufferFromHost().ToDeviceNum(0).FromFlatDataWithDimensions( - []float32{0, 1, 2}, []int{1, 3}).Done()) - x1 := must1(client.BufferFromHost().ToDeviceNum(1).FromFlatDataWithDimensions( - []float32{0, 0.1, 0.2}, []int{1, 3}).Done()) - outputs := shardyCompileAndExecute(t, client, program, mesh, x0, x1) + x0 := must1(client.BufferFromHost(). + ToDeviceNum(deviceAssignment[0]). + FromFlatDataWithDimensions([]float32{0, 1, 2}, []int{1, 3}). + Done()) + x1 := must1(client.BufferFromHost(). + ToDeviceNum(deviceAssignment[1]). + FromFlatDataWithDimensions([]float32{0, 0.1, 0.2}, []int{1, 3}). + Done()) + outputs := shardyCompileAndExecute(t, client, program, deviceAssignment, x0, x1) requireBuffersEqual(t, []FlatAndDims{ {[]float32{3}, []int{1}}, {[]float32{0.3}, []int{1}}, From 46f51151ec4bcec8665f6330766c507ed30d518a Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Thu, 20 Nov 2025 10:47:42 +0100 Subject: [PATCH 19/30] Added support for logical device assignment. Updated tests. 
--- stablehlo_test.go | 4 +++- tests/gopjrt/shardy_test.go | 4 +++- types/shardy/devicemesh.go | 39 ++++++++++++++++++++++++++----------- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/stablehlo_test.go b/stablehlo_test.go index 5def2c7..59bd27b 100644 --- a/stablehlo_test.go +++ b/stablehlo_test.go @@ -47,6 +47,8 @@ func TestBuilder(t *testing.T) { b := New(t.Name()) mesh, err := shardy.NewDeviceMesh("mesh", []int{4, 2}, []string{"data", "model"}) require.NoError(t, err) + err = mesh.SetLogicalDeviceAssignment(7, 6, 5, 4, 3, 2, 1, 0) + require.NoError(t, err) b.WithShardy(mesh) fn := b.Main() @@ -77,7 +79,7 @@ func TestBuilder(t *testing.T) { program := string(must(b.Build())) fmt.Printf("%s program:\n%s", t.Name(), program) want := `module @TestBuilder_Sharding attributes {stablehlo.num_replicas = 1, stablehlo.num_partitions = 8} { - sdy.mesh @mesh = <["data"=4, "model"=2]> + sdy.mesh @mesh = <["data"=4, "model"=2], device_ids=[7, 6, 5, 4, 3, 2, 1, 0]> func.func @main(%arg0: tensor<16x128xf32> { sdy.sharding = #sdy.sharding<@mesh, [{"data"}, {}]> }, %arg1: tensor<128x256xf32> { sdy.sharding = #sdy.sharding<@mesh, [{"model"}, {}]> }) -> tensor<16x256xf32> { jax.result_info = "result", sdy.sharding = #sdy.sharding<@mesh, [{"data"}, {}]> diff --git a/tests/gopjrt/shardy_test.go b/tests/gopjrt/shardy_test.go index dcc9578..fbe941a 100644 --- a/tests/gopjrt/shardy_test.go +++ b/tests/gopjrt/shardy_test.go @@ -50,7 +50,9 @@ func testShardy(t *testing.T, client *pjrt.Client) { } t.Run("input-data-sharding", func(t *testing.T) { - mesh := must1(shardy.NewDeviceMesh("data_mesh", []int{2}, []string{"data"})) + mesh := must1(shardy.NewDeviceMesh("mesh", []int{2}, []string{"data"})) + // Reverse logical device assignment: notice it doesn't affect the order of the inputs. 
+ must(mesh.SetLogicalDeviceAssignment(1, 0)) builder := stablehlo.New(t.Name()).WithShardy(mesh) fn := builder.Main() x := must1(fn.NamedInputWithSharding("arg0", shapes.Make(dtypes.F32, 2, 3), diff --git a/types/shardy/devicemesh.go b/types/shardy/devicemesh.go index 8978ce1..f74a2a4 100644 --- a/types/shardy/devicemesh.go +++ b/types/shardy/devicemesh.go @@ -82,15 +82,11 @@ func NewDeviceMesh(name string, shape []int, axisNames []string) (*DeviceMesh, e } m := &DeviceMesh{ - name: name, - axesNames: axisNames, - axesSizes: shape, - nameToAxis: nameToAxis, - numDevices: numDevices, - logicalDeviceAssignment: make([]int, numDevices), - } - for i := range m.logicalDeviceAssignment { - m.logicalDeviceAssignment[i] = i + name: name, + axesNames: axisNames, + axesSizes: shape, + nameToAxis: nameToAxis, + numDevices: numDevices, } return m, nil } @@ -150,6 +146,10 @@ func (m *DeviceMesh) String() string { // // It returns an error if logicalDeviceAssignment has invalid device numbers or len(devices) != NumDevices(). func (m *DeviceMesh) SetLogicalDeviceAssignment(devices ...int) error { + if len(devices) == 0 { + m.logicalDeviceAssignment = nil + return nil + } if len(devices) != m.numDevices { return errors.Errorf("devices must have %d elements, got %d", m.numDevices, len(devices)) } @@ -164,12 +164,18 @@ func (m *DeviceMesh) SetLogicalDeviceAssignment(devices ...int) error { m.numDevices-1, device) } } - copy(m.logicalDeviceAssignment, devices) + m.logicalDeviceAssignment = slices.Clone(devices) return nil } // LogicalDeviceAssignment returns the list of devices in the mesh, in the order they appear in the mesh. +// +// It can return nil, if no assignment was set with SetLogicalDeviceAssignment() -- in which case it will +// default to a sequential assignment starting from 0. 
func (m *DeviceMesh) LogicalDeviceAssignment() []int { + if m.logicalDeviceAssignment == nil { + return nil + } return slices.Clone(m.logicalDeviceAssignment) } @@ -270,6 +276,17 @@ func (m *DeviceMesh) ToStableHLO() string { } w("%q=%d", axisName, m.axesSizes[i]) } - w("]>") + w("]") + if len(m.logicalDeviceAssignment) > 0 { + w(", device_ids=[") + for i, logicalDeviceId := range m.logicalDeviceAssignment { + if i > 0 { + w(", ") + } + w("%d", logicalDeviceId) + } + w("]") + } + w(">") return buf.String() } From 42c313d986ad10f5b413cf075c509734d5072e78 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Thu, 20 Nov 2025 10:49:02 +0100 Subject: [PATCH 20/30] Reordered code, moving all distributed code to the end of builder.go. --- builder.go | 138 ++++++++++++++++++++++++++--------------------------- 1 file changed, 69 insertions(+), 69 deletions(-) diff --git a/builder.go b/builder.go index 46c4a7f..a26587f 100644 --- a/builder.go +++ b/builder.go @@ -56,75 +56,6 @@ func New(name string) *Builder { } } -// WithNumReplicas sets the number of replicas (for data parallelism). -// This is added as an attribute to the StableHLO module. -// -// Consider using WithShardy for distributed computation instead: other forms of distributed -// (collective) computation across devices are not tested and may not work. -func (b *Builder) WithNumReplicas(n int) *Builder { - b.numReplicas = n - return b -} - -// WithNumPartitions sets the number of partitions (for model parallelism). -// This is added as an attribute to the StableHLO module. -// -// Consider using WithShardy for distributed computation instead: other forms of distributed -// (collective) computation across devices are not tested and may not work. -func (b *Builder) WithNumPartitions(n int) *Builder { - b.numPartitions = n - return b -} - -// WithShardy enables distributed computation across the devices selected by the given meshes. 
-// -// This is the recommended way to do distributed (across devices) computation, and given the inputs -// with sharded information, Shardy will automatically distribute the computation, without you needing -// to specify any of the collective operations. -// -// Usually, there is only one meshes. But one can split the devices in different meshes. The meshes overlap -// the concrete devices used. -// -// See details of XLA Shardy in [1] -// -// [1] https://github.com/openxla/shardy -func (b *Builder) WithShardy(meshes ...*shardy.DeviceMesh) *Builder { - b.meshes = meshes - b.WithNumReplicas(1) - numDevices := 0 - for _, mesh := range meshes { - numDevices = max(numDevices, mesh.NumDevices()) - } - b.WithNumPartitions(numDevices) - return b -} - -// Meshes returns the meshes configured with WithShardy. -func (b *Builder) Meshes() []*shardy.DeviceMesh { - return b.meshes -} - -// NewShardingSpec creates a new ShardingSpec using the first mesh configured with WithShardy. -// It returns nil if no mesh was not configured. -// -// This is a shortcut to NewShardingSpecByMeshIx(0). -func (b *Builder) NewShardingSpec() *shardy.ShardingSpec { - if len(b.meshes) == 0 { - return nil - } - return shardy.NewShardingSpec(b.meshes[0]) -} - -// NewShardingSpecByMeshIx creates a new ShardingSpec for the meshIdx (the order given by WithShardy). -// -// It may return nil if meshIdx is out of range. -func (b *Builder) NewShardingSpecByMeshIx(meshIdx int) *shardy.ShardingSpec { - if meshIdx < 0 || meshIdx >= len(b.meshes) { - return nil - } - return shardy.NewShardingSpec(b.meshes[meshIdx]) -} - // elementWriter represents elements of ToStableHLO that know how to write themselves. type elementWriter interface { Write(w io.Writer, indentation string) error @@ -299,3 +230,72 @@ func (b *Builder) getChannelHandle(config *types.CollectiveConfig) literalStr { return literalStrF("#stablehlo.channel_handle", id, typ) } + +// WithNumReplicas sets the number of replicas (for data parallelism). 
+// This is added as an attribute to the StableHLO module. +// +// Consider using WithShardy for distributed computation instead: other forms of distributed +// (collective) computation across devices are not tested and may not work. +func (b *Builder) WithNumReplicas(n int) *Builder { + b.numReplicas = n + return b +} + +// WithNumPartitions sets the number of partitions (for model parallelism). +// This is added as an attribute to the StableHLO module. +// +// Consider using WithShardy for distributed computation instead: other forms of distributed +// (collective) computation across devices are not tested and may not work. +func (b *Builder) WithNumPartitions(n int) *Builder { + b.numPartitions = n + return b +} + +// WithShardy enables distributed computation across the devices selected by the given meshes. +// +// This is the recommended way to do distributed (across devices) computation, and given the inputs +// with sharded information, Shardy will automatically distribute the computation, without you needing +// to specify any of the collective operations. +// +// Usually, there is only one meshes. But one can split the devices in different meshes. The meshes overlap +// the concrete devices used. +// +// See details of XLA Shardy in [1] +// +// [1] https://github.com/openxla/shardy +func (b *Builder) WithShardy(meshes ...*shardy.DeviceMesh) *Builder { + b.meshes = meshes + b.WithNumReplicas(1) + numDevices := 0 + for _, mesh := range meshes { + numDevices = max(numDevices, mesh.NumDevices()) + } + b.WithNumPartitions(numDevices) + return b +} + +// Meshes returns the meshes configured with WithShardy. +func (b *Builder) Meshes() []*shardy.DeviceMesh { + return b.meshes +} + +// NewShardingSpec creates a new ShardingSpec using the first mesh configured with WithShardy. +// It returns nil if no mesh was not configured. +// +// This is a shortcut to NewShardingSpecByMeshIx(0). 
+func (b *Builder) NewShardingSpec() *shardy.ShardingSpec { + if len(b.meshes) == 0 { + return nil + } + return shardy.NewShardingSpec(b.meshes[0]) +} + +// NewShardingSpecByMeshIx creates a new ShardingSpec for the meshIdx (the order given by WithShardy). +// +// It may return nil if meshIdx is out of range. +func (b *Builder) NewShardingSpecByMeshIx(meshIdx int) *shardy.ShardingSpec { + if meshIdx < 0 || meshIdx >= len(b.meshes) { + return nil + } + return shardy.NewShardingSpec(b.meshes[meshIdx]) +} From b84e8cd0e4a466e500f3d15d7e95875f5b72c50e Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Thu, 20 Nov 2025 11:35:26 +0100 Subject: [PATCH 21/30] Fixed names of parameters --- types/shardy/devicemesh.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/types/shardy/devicemesh.go b/types/shardy/devicemesh.go index f74a2a4..9b59e33 100644 --- a/types/shardy/devicemesh.go +++ b/types/shardy/devicemesh.go @@ -44,12 +44,12 @@ type DeviceMesh struct { // // The default mapping of logical devices numbers to the mesh is sequential, starting from 0, but it can be // changed with the DeviceMesh.SetLogicalDeviceAssignment() method. 
-func NewDeviceMesh(name string, shape []int, axisNames []string) (*DeviceMesh, error) { - if len(shape) != len(axisNames) { +func NewDeviceMesh(name string, axesSizes []int, axesNames []string) (*DeviceMesh, error) { + if len(axesSizes) != len(axesNames) { return nil, errors.Errorf("axesSizes and axesNames must have the same length, got %d and %d", - len(shape), len(axisNames)) + len(axesSizes), len(axesNames)) } - if len(shape) == 0 { + if len(axesSizes) == 0 { return nil, errors.New("DeviceMesh axesSizes cannot be empty") } @@ -59,9 +59,9 @@ func NewDeviceMesh(name string, shape []int, axisNames []string) (*DeviceMesh, e "DeviceMesh name %q is not a valid StableHLO identifier, suggestion %q -- or use "+ "stablehlo.NormalizeIdentifier()", name, utils.NormalizeIdentifier(name)) } - axisNames = slices.Clone(axisNames) - for i, axisName := range axisNames { - if axisNames[i] != utils.NormalizeIdentifier(axisName) { + axesNames = slices.Clone(axesNames) + for i, axisName := range axesNames { + if axesNames[i] != utils.NormalizeIdentifier(axisName) { return nil, errors.Errorf( "DeviceMesh axis name %q at index %d is not a valid StableHLO identifier, suggestion %q -- or use "+ "stablehlo.NormalizeIdentifier()", axisName, i, utils.NormalizeIdentifier(axisName)) @@ -69,8 +69,8 @@ func NewDeviceMesh(name string, shape []int, axisNames []string) (*DeviceMesh, e } numDevices := 1 - nameToAxis := make(map[string]int, len(shape)) - for i, name := range axisNames { + nameToAxis := make(map[string]int, len(axesSizes)) + for i, name := range axesNames { if name == "" { return nil, errors.Errorf("DeviceMesh axis name at index %d cannot be empty", i) } @@ -78,13 +78,13 @@ func NewDeviceMesh(name string, shape []int, axisNames []string) (*DeviceMesh, e return nil, errors.Errorf("DeviceMesh axis name %q is duplicated", name) } nameToAxis[name] = i - numDevices *= shape[i] + numDevices *= axesSizes[i] } m := &DeviceMesh{ name: name, - axesNames: axisNames, - axesSizes: shape, + 
axesNames: axesNames, + axesSizes: axesSizes, nameToAxis: nameToAxis, numDevices: numDevices, } From 95c3ac0f18260fd258c77b722d54e8a494615384 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Fri, 21 Nov 2025 08:29:34 +0100 Subject: [PATCH 22/30] s/DeviceMesh.Shape/DeviceMesh.AxesSizes --- types/shardy/devicemesh.go | 10 +++++----- types/shardy/devicemesh_test.go | 14 +++++++------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/types/shardy/devicemesh.go b/types/shardy/devicemesh.go index 9b59e33..37e08a9 100644 --- a/types/shardy/devicemesh.go +++ b/types/shardy/devicemesh.go @@ -105,13 +105,13 @@ func (m *DeviceMesh) Rank() int { return len(m.axesSizes) } -// AxisNames returns a copy of the mesh's axis names. -func (m *DeviceMesh) AxisNames() []string { +// AxesNames returns a copy of the mesh's axis names. +func (m *DeviceMesh) AxesNames() []string { return slices.Clone(m.axesNames) } -// Shape returns a copy of the mesh's axesSizes. -func (m *DeviceMesh) Shape() []int { +// AxesSizes returns a copy of the mesh's axesSizes. +func (m *DeviceMesh) AxesSizes() []int { shape := make([]int, len(m.axesSizes)) copy(shape, m.axesSizes) return shape @@ -170,7 +170,7 @@ func (m *DeviceMesh) SetLogicalDeviceAssignment(devices ...int) error { // LogicalDeviceAssignment returns the list of devices in the mesh, in the order they appear in the mesh. // -// It can return nil, if no assignment was set with SetLogicalDeviceAssignment() -- in which case it will +// It can return nil if no assignment was set with SetLogicalDeviceAssignment() -- in which case it will // default to a sequential assignment starting from 0. 
func (m *DeviceMesh) LogicalDeviceAssignment() []int { if m.logicalDeviceAssignment == nil { diff --git a/types/shardy/devicemesh_test.go b/types/shardy/devicemesh_test.go index 463252a..f8dda8a 100644 --- a/types/shardy/devicemesh_test.go +++ b/types/shardy/devicemesh_test.go @@ -107,28 +107,28 @@ func TestDeviceMesh(t *testing.T) { } }) - t.Run("AxisNames", func(t *testing.T) { + t.Run("AxesNames", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 4}, []string{"x", "y"}) require.NoError(t, err) - axisNames := mesh.AxisNames() + axisNames := mesh.AxesNames() assert.Equal(t, []string{"x", "y"}, axisNames) // Verify it returns a copy axisNames[0] = "modified" - assert.Equal(t, []string{"x", "y"}, mesh.AxisNames()) + assert.Equal(t, []string{"x", "y"}, mesh.AxesNames()) }) t.Run("Shape", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 4}, []string{"x", "y"}) require.NoError(t, err) - shape := mesh.Shape() - assert.Equal(t, []int{2, 4}, shape) + axesSizes := mesh.AxesSizes() + assert.Equal(t, []int{2, 4}, axesSizes) // Verify it returns a copy - shape[0] = 99 - assert.Equal(t, []int{2, 4}, mesh.Shape()) + axesSizes[0] = 99 + assert.Equal(t, []int{2, 4}, mesh.AxesSizes()) }) t.Run("AxisSize", func(t *testing.T) { From 68ba74721b6305a6745f2a9044ebe9a7afb55cfc Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Fri, 21 Nov 2025 16:09:10 +0100 Subject: [PATCH 23/30] Fixed AddShardedAxis. --- types/shardy/shardingspec.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/types/shardy/shardingspec.go b/types/shardy/shardingspec.go index 4f2b17f..8b9e5ff 100644 --- a/types/shardy/shardingspec.go +++ b/types/shardy/shardingspec.go @@ -68,9 +68,9 @@ func NewShardingSpec(mesh *DeviceMesh) *ShardingSpec { // AddShardedAxis adds a new sharded axis to the ShardingSpec using one or more mesh axes. // // It returns itself, so calls can be chained. 
-func (s *ShardingSpec) AddShardedAxis(meshAxisName string, moreMeshAxesNames ...string) *ShardingSpec { - axisSpec := TensorAxisSpec{MeshAxes: []MeshAxisSpec{{AxisName: meshAxisName}}} - for _, meshAxisName := range moreMeshAxesNames { +func (s *ShardingSpec) AddShardedAxis(meshAxesNames ...string) *ShardingSpec { + axisSpec := TensorAxisSpec{} + for _, meshAxisName := range meshAxesNames { axisSpec.MeshAxes = append(axisSpec.MeshAxes, MeshAxisSpec{AxisName: meshAxisName}) } s.Axes = append(s.Axes, axisSpec) From f0919531c2de53753bab34d2b030018271750ce1 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Sat, 22 Nov 2025 08:15:25 +0100 Subject: [PATCH 24/30] ReturnWithShardingAndAttributes accept empty sharding specs list. --- function.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/function.go b/function.go index 0636291..addd9e9 100644 --- a/function.go +++ b/function.go @@ -260,10 +260,14 @@ func (fn *Function) Return(values ...*Value) error { // // The shardingSpecs slice of ShardingSpecs must have the same length as the values slice. // Each ShardingSpec can be nil, in which case the default sharding is replicated across all devices. +// If shardingSpecs is nil, this behaves just like ReturnWithAttributes. // -// The attributes slice of maps can be set to nil, if there are no attributes. +// The attributes slice of maps can be set to nil if there are no attributes. 
func (fn *Function) ReturnWithShardingAndAttributes(values []*Value, shardingSpecs []*shardy.ShardingSpec, attributes []map[string]any) error { + if len(shardingSpecs) == 0 { + return fn.ReturnWithAttributes(values, attributes) + } if len(values) != len(shardingSpecs) { return errors.Errorf("Function.ReturnWithShardingAndAttributes requires the same number of values and sharding specs, got %d and %d", len(values), len(shardingSpecs)) } From 86595fbce8d74237e803deac5b097e80c37613c4 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Sat, 22 Nov 2025 08:23:45 +0100 Subject: [PATCH 25/30] ReturnWithAttributes accept nil attributes. --- function.go | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/function.go b/function.go index addd9e9..de4fd21 100644 --- a/function.go +++ b/function.go @@ -251,8 +251,7 @@ func (fn *Function) ConstantFromFlatAndDimensions(flat any, dimensions ...int) ( // If you are doing distributed computation, you can use WithReturnShardingSpecs to specify // the sharding requirements for each of the return values. 
func (fn *Function) Return(values ...*Value) error { - attributes := make([]map[string]any, len(values)) - return fn.ReturnWithAttributes(values, attributes) + return fn.ReturnWithAttributes(values, nil) } // ReturnWithShardingAndAttributes is a convenience function to call ReturnWithAttributes with the given sharding @@ -295,8 +294,10 @@ func (fn *Function) ReturnWithAttributes(values []*Value, attributes []map[strin if len(values) == 0 { return errors.New("Function.Return requires at least one return value") } - if len(values) != len(attributes) { - return errors.Errorf("Function.ReturnWithAttributes requires the same number of values and attributes, got %d and %d", len(values), len(attributes)) + if len(attributes) > 0 && len(values) != len(attributes) { + return errors.Errorf( + "if attributes is defined (!=nil) Function.ReturnWithAttributes requires the same number of "+ + "values and attributes, got %d and %d", len(values), len(attributes)) } fn.Returned = true outputValues := make([]*Value, len(values)) @@ -305,14 +306,15 @@ func (fn *Function) ReturnWithAttributes(values []*Value, attributes []map[strin return errors.New("Function.Return given values that are not owned by the function") } outputValues[i] = &Value{ - fn: fn, - name: value.name, - shape: value.shape, - Attributes: attributes[i], + fn: fn, + name: value.name, + shape: value.shape, + } + if len(attributes) > 0 { + outputValues[i].Attributes = attributes[i] } } fn.Outputs = outputValues - stmt := &Statement{ Builder: fn.Builder, Function: fn, From 81ce241b4ee63d16ed7173bd059905e956a0ab99 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Sun, 23 Nov 2025 07:45:16 +0100 Subject: [PATCH 26/30] Collective operations: don't set up channel_handle if not explicitly requested: it's not compatible with SPMD. 
--- collective.go | 25 +++++++++++++++++++------ types/ops.go | 8 ++++---- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/collective.go b/collective.go index 123a7ed..6f8dc3d 100644 --- a/collective.go +++ b/collective.go @@ -49,7 +49,7 @@ func formatReplicaGroups(groups [][]int) literalStr { // notice it's not the device numbers by the replica numbers (there is an indirection). // Except if the config sets UseGlobalDeviceIDs, in which case they are interpreted as device // numbers. E.g., `[[0, 1, 2, 3]]`. -// - config: Optional configuration of the channels to be used. This is not needed for SPMD programs. +// - config: Optional configuration of the channels to be used. This is shouldn't be used for SPMD programs. // // Consider using Builder.WithShardy for distributed computation instead: other forms of distributed // (collective) computation across devices are not tested and may not work. @@ -80,7 +80,9 @@ func CollectiveBroadcast(operand *Value, replicaGroups [][]int, config ...*types stmt := fn.addOp(op, outputShape, operand) stmt.Attributes = map[string]any{ "replica_groups": formatReplicaGroups(replicaGroups), - "channel_handle": fn.Builder.getChannelHandle(cfg), + } + if cfg != nil { + stmt.Attributes["channel_handle"] = fn.Builder.getChannelHandle(cfg) } return stmt.Outputs[0], nil } @@ -144,7 +146,9 @@ func AllReduce(operands []*Value, replicaGroups [][]int, computation *Function, stmt := fn.addMultiOp(op, outputShapes, operands) stmt.Attributes = map[string]any{ "replica_groups": formatReplicaGroups(replicaGroups), - "channel_handle": fn.Builder.getChannelHandle(cfg), + } + if cfg != nil { + stmt.Attributes["channel_handle"] = fn.Builder.getChannelHandle(cfg) } if cfg != nil && cfg.UseGlobalDeviceIDs { stmt.Attributes["use_global_device_ids"] = true @@ -185,7 +189,9 @@ func AllGather(operand *Value, replicaGroups [][]int, allGatherDim int, config . 
stmt.Attributes = map[string]any{ "replica_groups": formatReplicaGroups(replicaGroups), "all_gather_dim": int64(allGatherDim), - "channel_handle": fn.Builder.getChannelHandle(cfg), + } + if cfg != nil { + stmt.Attributes["channel_handle"] = fn.Builder.getChannelHandle(cfg) } if cfg != nil && cfg.UseGlobalDeviceIDs { stmt.Attributes["use_global_device_ids"] = true @@ -230,7 +236,9 @@ func AllToAll(operand *Value, replicaGroups [][]int, splitDimension, concatDimen "split_dimension": int64(splitDimension), "concat_dimension": int64(concatDimension), "split_count": int64(splitCount), - "channel_handle": fn.Builder.getChannelHandle(cfg), + } + if cfg != nil { + stmt.Attributes["channel_handle"] = fn.Builder.getChannelHandle(cfg) } if cfg != nil && cfg.UseGlobalDeviceIDs { stmt.Attributes["use_global_device_ids"] = true @@ -288,7 +296,12 @@ func CollectivePermute(operand *Value, sourceTargetPairs [][2]int, config ...*ty stmt := fn.addOp(op, outputShape, operand) stmt.Attributes = map[string]any{ "source_target_pairs": formatSourceTargetPairs(sourceTargetPairs), - "channel_handle": fn.Builder.getChannelHandle(cfg), + } + if cfg != nil { + stmt.Attributes["channel_handle"] = fn.Builder.getChannelHandle(cfg) + } + if cfg != nil && cfg.UseGlobalDeviceIDs { + stmt.Attributes["use_global_device_ids"] = true } return stmt.Outputs[0], nil } diff --git a/types/ops.go b/types/ops.go index e21dd4b..021e9f7 100644 --- a/types/ops.go +++ b/types/ops.go @@ -222,15 +222,15 @@ const ( // CollectiveConfig provides advanced, optional configuration for collective operations. // Pass this as the last (optional) argument to collective ops. type CollectiveConfig struct { + // ChannelType specifies the communication dimension. + // Defaults to CrossReplica (0). + ChannelType ChannelType + // ChannelID, if non-nil, forces a specific channel ID (the 'handle'). // If nil, a unique ID will be automatically generated. 
// This is **required** for MPMD (multi-program, multi-data) to manually link ops across programs. ChannelID *int - // ChannelType specifies the communication dimension. - // Defaults to CrossReplica (0). - ChannelType ChannelType - // UseGlobalDeviceIDs changes the interpretation of replica_groups // from replica IDs to global device IDs. // This only applies to AllReduce, not CollectiveBroadcast. From 16b9a7c21e2a60cc34ff77bd32a618b9f94c2587 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Sun, 23 Nov 2025 07:53:39 +0100 Subject: [PATCH 27/30] Updated CHANGELOG. --- docs/CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index a5ea621..86e6321 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,7 +1,10 @@ # v0.2.0: ... - `Function.Input` and `Function.NamedInput`: (change in API) they now may return an error, if the name is duplicate. -- `AllReduce` now supports arbitrary number of inputs, to be reduced at once. +- Collective ops: + - `AllReduce` now supports arbitrary number of inputs, to be reduced at once. + - `channel_handle` attribute is not generated by default: it is not compatible with SPMD, and probably will + only be useful once `Send`/`Recv` are supported. - Added XLA Shardy support: - Added `shardy.DeviceMesh` and `shardy.ShardingSpec` types. - Added `Builder.WithShardy(mesh)` From e3439a217bbcc2b98064c081984ee5ba6c576dae Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Sun, 23 Nov 2025 12:42:31 +0100 Subject: [PATCH 28/30] Fixed attributes for outputs of functions. 
--- function.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/function.go b/function.go index de4fd21..02e138d 100644 --- a/function.go +++ b/function.go @@ -400,7 +400,8 @@ func (fn *Function) Write(writer io.Writer, indentation string) error { w(") :\n") } else if normalFunction { w(") -> ") - if len(fn.Outputs) > 1 { + encloseOutputInParenthesis := len(fn.Outputs) > 1 || (len(fn.Outputs) == 1 && len(fn.Outputs[0].Attributes) > 0) + if encloseOutputInParenthesis { w("(") } for i, output := range fn.Outputs { @@ -410,7 +411,7 @@ func (fn *Function) Write(writer io.Writer, indentation string) error { w(output.shape.ToStableHLO()) writeAttributes(writer, indentation, output.Attributes, w) } - if len(fn.Outputs) > 1 { + if encloseOutputInParenthesis { w(")") } w(" {\n") From a9e5561bcb8b817e8ad4388c3dcad76ccc38ddac Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Tue, 25 Nov 2025 17:35:01 +0100 Subject: [PATCH 29/30] Bumping to v0.2.0-rc0 --- docs/CHANGELOG.md | 2 +- go.mod | 2 +- go.sum | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 86e6321..82b9ff4 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,4 +1,4 @@ -# v0.2.0: ... +# v0.2.0: (Release Candidate) Adding support for XLA Shardy - `Function.Input` and `Function.NamedInput`: (change in API) they now may return an error, if the name is duplicate. 
- Collective ops: diff --git a/go.mod b/go.mod index 0f56f5e..c6a2de7 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.24.0 toolchain go1.24.6 require ( - github.com/gomlx/gopjrt v0.9.2-0.20251113071311-1488e6396f1b + github.com/gomlx/gopjrt v0.10.0-rc0 github.com/janpfeifer/must v0.2.0 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.11.1 diff --git a/go.sum b/go.sum index 3da24ba..806970d 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/gomlx/gopjrt v0.9.2-0.20251113071311-1488e6396f1b h1:2jh5hJmb6YzGWL+rKUDYUpMdcmeeA+jDkKpfmevxl8g= github.com/gomlx/gopjrt v0.9.2-0.20251113071311-1488e6396f1b/go.mod h1:c8UENVGnxIDdihEL5HinlAdgR7RxTbEPLBppiMQF1ew= +github.com/gomlx/gopjrt v0.10.0-rc0 h1:EpF+JJYl3AUvU5ToKSfsuFnSPBxkPjbor93Ziak7OGA= +github.com/gomlx/gopjrt v0.10.0-rc0/go.mod h1:c8UENVGnxIDdihEL5HinlAdgR7RxTbEPLBppiMQF1ew= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/janpfeifer/go-benchmarks v0.1.1 h1:gLLy07/JrOKSnMWeUxSnjTdhkglgmrNR2IBDnR4kRqw= From 95c070b5aaf10aa77a0210b92a075baec4d05c8f Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Tue, 25 Nov 2025 17:40:03 +0100 Subject: [PATCH 30/30] Fixed test: added missing parenthesis for outputs with attributes. 
--- stablehlo_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stablehlo_test.go b/stablehlo_test.go index 59bd27b..6ed5fc6 100644 --- a/stablehlo_test.go +++ b/stablehlo_test.go @@ -80,10 +80,10 @@ func TestBuilder(t *testing.T) { fmt.Printf("%s program:\n%s", t.Name(), program) want := `module @TestBuilder_Sharding attributes {stablehlo.num_replicas = 1, stablehlo.num_partitions = 8} { sdy.mesh @mesh = <["data"=4, "model"=2], device_ids=[7, 6, 5, 4, 3, 2, 1, 0]> - func.func @main(%arg0: tensor<16x128xf32> { sdy.sharding = #sdy.sharding<@mesh, [{"data"}, {}]> }, %arg1: tensor<128x256xf32> { sdy.sharding = #sdy.sharding<@mesh, [{"model"}, {}]> }) -> tensor<16x256xf32> { + func.func @main(%arg0: tensor<16x128xf32> { sdy.sharding = #sdy.sharding<@mesh, [{"data"}, {}]> }, %arg1: tensor<128x256xf32> { sdy.sharding = #sdy.sharding<@mesh, [{"model"}, {}]> }) -> (tensor<16x256xf32> { jax.result_info = "result", sdy.sharding = #sdy.sharding<@mesh, [{"data"}, {}]> - } { + }) { %0 = "stablehlo.tanh"(%arg0) : (tensor<16x128xf32>) -> tensor<16x128xf32> %1 = "stablehlo.dot_general"(%0, %arg1) { dot_dimension_numbers = #stablehlo.dot<