diff --git a/go.mod b/go.mod index c6a2de7..2ccfa06 100644 --- a/go.mod +++ b/go.mod @@ -8,22 +8,18 @@ require ( github.com/gomlx/gopjrt v0.10.0-rc0 github.com/janpfeifer/must v0.2.0 github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.11.1 github.com/x448/float16 v0.8.4 k8s.io/klog/v2 v2.130.1 ) require ( - github.com/davecgh/go-spew v1.1.1 // indirect github.com/dmarkham/enumer v1.6.1 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/pascaldekloe/name v1.0.0 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/mod v0.27.0 // indirect golang.org/x/sync v0.16.0 // indirect golang.org/x/tools v0.36.0 // indirect google.golang.org/protobuf v1.36.10 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) tool github.com/dmarkham/enumer diff --git a/go.sum b/go.sum index 806970d..50009f0 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,6 @@ github.com/dmarkham/enumer v1.6.1 h1:aSc9awYtZL07TUueWs40QcHtxTvHTAwG0EqrNsK45w4 github.com/dmarkham/enumer v1.6.1/go.mod h1:yixql+kDDQRYqcuBM2n9Vlt7NoT9ixgXhaXry8vmRg8= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/gomlx/gopjrt v0.9.2-0.20251113071311-1488e6396f1b h1:2jh5hJmb6YzGWL+rKUDYUpMdcmeeA+jDkKpfmevxl8g= -github.com/gomlx/gopjrt v0.9.2-0.20251113071311-1488e6396f1b/go.mod h1:c8UENVGnxIDdihEL5HinlAdgR7RxTbEPLBppiMQF1ew= github.com/gomlx/gopjrt v0.10.0-rc0 h1:EpF+JJYl3AUvU5ToKSfsuFnSPBxkPjbor93Ziak7OGA= github.com/gomlx/gopjrt v0.10.0-rc0/go.mod h1:c8UENVGnxIDdihEL5HinlAdgR7RxTbEPLBppiMQF1ew= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -34,8 +32,6 @@ golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= diff --git a/internal/utils/set_test.go b/internal/utils/set_test.go index 776cc90..9d553a8 100644 --- a/internal/utils/set_test.go +++ b/internal/utils/set_test.go @@ -2,38 +2,70 @@ package utils import ( "testing" - - "github.com/stretchr/testify/assert" ) func TestSet(t *testing.T) { // Sets are created empty. s := MakeSet[int](10) - assert.Len(t, s, 0) + if len(s) != 0 { + t.Errorf("expected len 0, got %d", len(s)) + } // Check inserting and recovery. s.Insert(3, 7) - assert.Len(t, s, 2) - assert.True(t, s.Has(3)) - assert.True(t, s.Has(7)) - assert.False(t, s.Has(5)) + if len(s) != 2 { + t.Errorf("expected len 2, got %d", len(s)) + } + if !s.Has(3) { + t.Errorf("expected s.Has(3) to be true") + } + if !s.Has(7) { + t.Errorf("expected s.Has(7) to be true") + } + if s.Has(5) { + t.Errorf("expected s.Has(5) to be false") + } s2 := SetWith(5, 7) - assert.Len(t, s2, 2) - assert.True(t, s2.Has(5)) - assert.True(t, s2.Has(7)) - assert.False(t, s2.Has(3)) + if len(s2) != 2 { + t.Errorf("expected len 2, got %d", len(s2)) + } + if !s2.Has(5) { + t.Errorf("expected s2.Has(5) to be true") + } + if !s2.Has(7) { + t.Errorf("expected s2.Has(7) to be true") + } + if s2.Has(3) { + t.Errorf("expected s2.Has(3) to be false") + } s3 := s.Sub(s2) - assert.Len(t, s3, 1) - assert.True(t, s3.Has(3)) + if len(s3) != 1 { + t.Errorf("expected len 1, got %d", len(s3)) + } + if !s3.Has(3) { + t.Errorf("expected s3.Has(3) to be true") + } delete(s, 7) - assert.Len(t, s, 1) - assert.True(t, s.Has(3)) - assert.False(t, s.Has(7)) - assert.True(t, s.Equal(s3)) - assert.False(t, s.Equal(s2)) + if len(s) != 1 { + t.Errorf("expected len 1, got %d", len(s)) + } + if !s.Has(3) { + t.Errorf("expected s.Has(3) to be true") + } + if s.Has(7) { + t.Errorf("expected s.Has(7) to be false") + } + if !s.Equal(s3) { + t.Errorf("expected s.Equal(s3) to be true") + } + if s.Equal(s2) { + t.Errorf("expected s.Equal(s2) to be false") + } s4 := SetWith(-3) - assert.False(t, s.Equal(s4)) + if s.Equal(s4) { + t.Errorf("expected s.Equal(s4) to be false") + } } diff --git a/shapeinference/convolve_test.go b/shapeinference/convolve_test.go index 84c4a98..91eb094 100644 --- a/shapeinference/convolve_test.go +++ b/shapeinference/convolve_test.go @@ -1,11 +1,10 @@ package shapeinference import ( + "strings" "testing" "github.com/gomlx/stablehlo/types/shapes" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestConvolve(t *testing.T) { @@ -238,11 +237,20 @@ func TestConvolve(t *testing.T) { tc.outputBatch, tc.outputChannels, tc.outputSpatial, tc.channelGroupCount, tc.batchGroupCount) if tc.expectedError != "" { - require.ErrorContains(t, err, tc.expectedError) + if err == nil { + t.Fatalf("expected error containing %q, got nil", tc.expectedError) + } + if !strings.Contains(err.Error(), tc.expectedError) { + t.Fatalf("expected error containing %q, got %q", tc.expectedError, err.Error()) + } return } - require.NoError(t, err) - assert.Equal(t, tc.output, output) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !tc.output.Equal(output) { + t.Errorf("expected output %v, got %v", tc.output, output) + } }) } } diff --git a/shapeinference/shapeinference_test.go b/shapeinference/shapeinference_test.go index 057181f..64ac10f 100644 --- a/shapeinference/shapeinference_test.go +++ b/shapeinference/shapeinference_test.go @@ -2,13 +2,12 @@ package shapeinference import ( "fmt" + "strings" "testing" "github.com/gomlx/gopjrt/dtypes" "github.com/gomlx/stablehlo/internal/optypes" "github.com/gomlx/stablehlo/types/shapes" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // Aliases @@ -34,76 +33,116 @@ func TestBinaryOp(t *testing.T) { // Invalid data types check. var err error _, err = BinaryOp(optypes.And, S(F32), S(F32)) - require.Error(t, err) - _, err = BinaryOp(optypes.Multiply, S(Bool, 1), S(Bool, 1)) - require.Error(t, err) + if err == nil { + t.Error("expected error for And(F32, F32), got nil") + } _, err = BinaryOp(optypes.Multiply, S(Bool, 1), S(Bool, 1)) - require.Error(t, err) + if err == nil { + t.Error("expected error for Multiply(Bool, Bool), got nil") + } _, err = BinaryOp(optypes.Xor, S(F32, 1), S(F32, 1)) - require.Error(t, err) + if err == nil { + t.Error("expected error for Xor(F32, F32), got nil") + } // Invalid operation type (not binary op). _, err = BinaryOp(optypes.Exponential, S(F32), S(F32)) - require.Error(t, err) + if err == nil { + t.Error("expected error for Exponential(F32, F32), got nil") + } // The same shape should be ok. var output shapes.Shape intMatrixShape := S(I8, 3, 3) output, err = BinaryOp(optypes.Or, intMatrixShape, intMatrixShape) - require.NoError(t, err) - require.True(t, intMatrixShape.Equal(output)) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if !intMatrixShape.Equal(output) { + t.Errorf("expected output shape %s, got %s", intMatrixShape, output) + } // Scalar with matrix. scalarShape := S(F32) matrixShape := S(F32, 2, 3) //expectedShape := S(F32, 2, 3) output, err = BinaryOp(optypes.Add, scalarShape, scalarShape) - require.NoError(t, err) - require.True(t, scalarShape.Equal(output)) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if !scalarShape.Equal(output) { + t.Errorf("expected output shape %s, got %s", scalarShape, output) + } _, err = BinaryOp(optypes.Add, scalarShape, matrixShape) - require.Error(t, err) + if err == nil { + t.Error("expected error for Add(scalar, matrix), got nil") + } //require.True(t, expectedShape.Equal(output)) // Broadcasting: not provided in StableHLO. shape1 := S(F32, 2, 1, 3) shape2 := S(F32, 1, 4, 3) _, err = BinaryOp(optypes.Add, shape1, shape2) - require.Error(t, err) + if err == nil { + t.Error("expected error for Add(shape1, shape2), got nil") + } //expectedBroadcastShape := S(F32, 2, 4, 3) //require.True(t, expectedBroadcastShape.Equal(must1(BinaryOp(optypes.Multiply, shape1, shape2)))) // Matrix with scalar. _, err = BinaryOp(optypes.Add, matrixShape, scalarShape) - require.Error(t, err) + if err == nil { + t.Error("expected error for Add(matrix, scalar), got nil") + } //require.True(t, expectedShape.Equal(must1(BinaryOp(optypes.Add, matrixShape, scalarShape)))) // Invalid broadcasting shapes. invalidShape1 := S(F32, 2, 3) invalidShape2 := S(F32, 3, 2) _, err = BinaryOp(optypes.Add, invalidShape1, invalidShape2) - require.Error(t, err) + if err == nil { + t.Error("expected error for Add(invalidShape1, invalidShape2), got nil") + } +} + +func panics(t *testing.T, f func()) { + t.Helper() + defer func() { + if r := recover(); r == nil { + t.Error("expected panic, but code did not panic") + } + }() + f() } func TestUnaryOp(t *testing.T) { // Invalid data types check. - require.Panics(t, func() { must1(UnaryOp(optypes.Not, S(F32))) }) - require.Panics(t, func() { must1(UnaryOp(optypes.Not, S(dtypes.Complex64))) }) - require.Panics(t, func() { must1(UnaryOp(optypes.Negate, S(Bool))) }) + panics(t, func() { must1(UnaryOp(optypes.Not, S(F32))) }) + panics(t, func() { must1(UnaryOp(optypes.Not, S(dtypes.Complex64))) }) + panics(t, func() { must1(UnaryOp(optypes.Negate, S(Bool))) }) // Invalid operation type (not unary op). - require.Panics(t, func() { must1(UnaryOp(optypes.Add, S(F32))) }) - require.Panics(t, func() { must1(UnaryOp(optypes.Negate, S(U64))) }) + panics(t, func() { must1(UnaryOp(optypes.Add, S(F32))) }) + panics(t, func() { must1(UnaryOp(optypes.Negate, S(U64))) }) // Valid operations boolShape := S(Bool, 2, 3) - require.True(t, boolShape.Equal(must1(UnaryOp(optypes.Not, boolShape)))) + if out := must1(UnaryOp(optypes.Not, boolShape)); !boolShape.Equal(out) { + t.Errorf("expected %s, got %s", boolShape, out) + } intShape := S(I8, 3, 3) - require.True(t, intShape.Equal(must1(UnaryOp(optypes.Not, intShape)))) + if out := must1(UnaryOp(optypes.Not, intShape)); !intShape.Equal(out) { + t.Errorf("expected %s, got %s", intShape, out) + } floatShape := S(F32, 2, 3) - require.True(t, floatShape.Equal(must1(UnaryOp(optypes.Exponential, floatShape)))) - require.True(t, floatShape.Equal(must1(UnaryOp(optypes.Negate, floatShape)))) + if out := must1(UnaryOp(optypes.Exponential, floatShape)); !floatShape.Equal(out) { + t.Errorf("expected %s, got %s", floatShape, out) + } + if out := must1(UnaryOp(optypes.Negate, floatShape)); !floatShape.Equal(out) { + t.Errorf("expected %s, got %s", floatShape, out) + } } func TestGather(t *testing.T) { @@ -120,9 +159,13 @@ func TestGather(t *testing.T) { offsetOutputAxes, collapsedSliceAxes, operandBatchingAxes, startIndicesBatchingAxes, startIndexMap, sliceSizes, false) - require.NoError(t, err) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } fmt.Printf("\tTest 1: outputShape=%s\n", output) - require.NoError(t, output.Check(F32, 3, 3, 2, 1)) + if err := output.Check(F32, 3, 3, 2, 1); err != nil { + t.Errorf("output check failed: %v", err) + } }) t.Run("2", func(t *testing.T) { @@ -139,9 +182,13 @@ func TestGather(t *testing.T) { offsetOutputAxes, collapsedSliceAxes, operandBatchingAxes, startIndicesBatchingAxes, startIndexMap, sliceSizes, false) - require.NoError(t, err) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } fmt.Printf("\tTest 2: outputShape=%s\n", output) - require.NoError(t, output.Check(F32, 7, 3, 1, 8)) + if err := output.Check(F32, 7, 3, 1, 8); err != nil { + t.Errorf("output check failed: %v", err) + } }) t.Run("3", func(t *testing.T) { @@ -158,9 +205,13 @@ func TestGather(t *testing.T) { offsetOutputAxes, collapsedSliceAxes, operandBatchingAxes, startIndicesBatchingAxes, startIndexMap, sliceSizes, false) - require.NoError(t, err) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } fmt.Printf("\tTest 3: outputShape=%s\n", output) - require.NoError(t, output.Check(F32, 8, 16)) + if err := output.Check(F32, 8, 16); err != nil { + t.Errorf("output check failed: %v", err) + } }) // Test from StableHLO's specification example in https://openxla.org/stablehlo/spec#gather @@ -178,9 +229,13 @@ func TestGather(t *testing.T) { offsetOutputAxes, collapsedSliceAxes, operandBatchingAxes, startIndicesBatchingAxes, startIndexMap, sliceSizes, false) - require.NoError(t, err) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } fmt.Printf("\tTest 3: outputShape=%s\n", output) - require.NoError(t, output.Check(F32, 2, 2, 3, 2, 2)) + if err := output.Check(F32, 2, 2, 3, 2, 2); err != nil { + t.Errorf("output check failed: %v", err) + } }) } @@ -206,9 +261,15 @@ func TestScatter(t *testing.T) { operandBatchingAxes, indicesBatchingAxes, scatterAxesToOperandAxes1, indexVectorAxis1, updateComputationInputs1, updateComputationOutputs1) - require.NoError(t, err) - require.Len(t, outputs1, 1) - require.True(t, expected1.Equal(outputs1[0]), "Valid Case 1 Failed: Expected %s, got %s", expected1, outputs1[0]) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(outputs1) != 1 { + t.Fatalf("expected 1 output, got %d", len(outputs1)) + } + if !expected1.Equal(outputs1[0]) { + t.Errorf("Valid Case 1 Failed: Expected %s, got %s", expected1, outputs1[0]) + } // Case 2: Scattering into a higher-rank tensor // Scatter updates of shape [4] into operand[i, j, :], where [i, j] comes from indices. @@ -230,9 +291,15 @@ func TestScatter(t *testing.T) { operandBatchingAxes, indicesBatchingAxes, scatterAxesToOperandAxes2, indexVectorAxis2, updateComputationInputs2, updateComputationOutputs2) - require.NoError(t, err) - require.Len(t, outputs2, 1) - require.True(t, expected2.Equal(outputs2[0]), "Valid Case 2 Failed: Expected %s, got %s", expected2, outputs2[0]) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(outputs2) != 1 { + t.Fatalf("expected 1 output, got %d", len(outputs2)) + } + if !expected2.Equal(outputs2[0]) { + t.Errorf("Valid Case 2 Failed: Expected %s, got %s", expected2, outputs2[0]) + } // Case 3: Different indexVectorAxis // Same as case 2, but indices are [2, 2, 3] -> indexVectorAxis is 1 and different order of axes in the operand. @@ -251,9 +318,15 @@ func TestScatter(t *testing.T) { operandBatchingAxes, indicesBatchingAxes, scatterAxesToOperandAxes3, indexVectorAxis3, updateComputationInputs3, updateComputationOutputs3) - require.NoError(t, err) - require.Len(t, outputs3, 1) - require.True(t, expected3.Equal(outputs3[0]), "Valid Case 3 Failed (IndexVecAxis=1): Expected %s, got %s", expected3, outputs3[0]) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(outputs3) != 1 { + t.Fatalf("expected 1 output, got %d", len(outputs3)) + } + if !expected3.Equal(outputs3[0]) { + t.Errorf("Valid Case 3 Failed (IndexVecAxis=1): Expected %s, got %s", expected3, outputs3[0]) + } // Case 4: No insertedWindowAxes (scattering full slices) // Scatter updates of shape [9] into operand [10, 9] @@ -272,9 +345,15 @@ func TestScatter(t *testing.T) { operandBatchingAxes, indicesBatchingAxes, scatterAxesToOperandAxes4, indexVectorAxis4, updateComputationInputs4, updateComputationOutputs4) - require.NoError(t, err) - require.Len(t, outputs4, 1) - require.True(t, expected4.Equal(outputs4[0]), "Valid Case 4 Failed (No Window): Expected %s, got %s", expected4, outputs4[0]) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(outputs4) != 1 { + t.Fatalf("expected 1 output, got %d", len(outputs4)) + } + if !expected4.Equal(outputs4[0]) { + t.Errorf("Valid Case 4 Failed (No Window): Expected %s, got %s", expected4, outputs4[0]) + } // Case 5: rearranging the output axes: operand5 := S(F32, 2, 5, 2) @@ -291,9 +370,15 @@ func TestScatter(t *testing.T) { operandBatchingAxes, indicesBatchingAxes, scatterAxesToOperandAxes5, indexVectorAxis5, updateComputationInputs5, updateComputationOutputs5) - require.NoError(t, err) - require.Len(t, outputs5, 1) - require.True(t, operand5.Equal(outputs5[0]), "Valid Case 5 Failed (No Window): Expected %s, got %s", operand5, outputs5[0]) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(outputs5) != 1 { + t.Fatalf("expected 1 output, got %d", len(outputs5)) + } + if !operand5.Equal(outputs5[0]) { + t.Errorf("Valid Case 5 Failed (No Window): Expected %s, got %s", operand5, outputs5[0]) + } } func TestSlice(t *testing.T) { @@ -307,8 +392,12 @@ func TestSlice(t *testing.T) { strides1 := []int{1} expected1 := S(F32, 6) output1, err := Slice(operand1, starts1, limits1, strides1) - require.NoError(t, err) - require.True(t, expected1.Equal(output1), "%s Valid Case 1 Failed: Expected %s, got %s", opName, expected1, output1) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !expected1.Equal(output1) { + t.Errorf("%s Valid Case 1 Failed: Expected %s, got %s", opName, expected1, output1) + } // Case 2: 2D slice with stride 1 operand2 := S(I32, 5, 6) @@ -317,8 +406,12 @@ func TestSlice(t *testing.T) { strides2 := []int{1, 1} expected2 := S(I32, 3, 3) output2, err := Slice(operand2, starts2, limits2, strides2) - require.NoError(t, err) - require.True(t, expected2.Equal(output2), "%s Valid Case 2 Failed: Expected %s, got %s", opName, expected2, output2) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !expected2.Equal(output2) { + t.Errorf("%s Valid Case 2 Failed: Expected %s, got %s", opName, expected2, output2) + } // Case 3: 3D slice with different strides operand3 := S(Bool, 10, 8, 6) @@ -330,8 +423,12 @@ func TestSlice(t *testing.T) { // Dim 2: (6-1)/1 = 5 -> 5 elements (indices 1, 2, 3, 4, 5) expected3 := S(Bool, 5, 3, 5) output3, err := Slice(operand3, starts3, limits3, strides3) - require.NoError(t, err) - require.True(t, expected3.Equal(output3), "%s Valid Case 3 Failed: Expected %s, got %s", opName, expected3, output3) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !expected3.Equal(output3) { + t.Errorf("%s Valid Case 3 Failed: Expected %s, got %s", opName, expected3, output3) + } // Case 4: Slice resulting in size 1 dimension operand4 := S(F32, 10) @@ -340,8 +437,12 @@ func TestSlice(t *testing.T) { strides4 := []int{1} expected4 := S(F32, 1) output4, err := Slice(operand4, starts4, limits4, strides4) - require.NoError(t, err) - require.True(t, expected4.Equal(output4), "%s Valid Case 4 Failed: Expected %s, got %s", opName, expected4, output4) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !expected4.Equal(output4) { + t.Errorf("%s Valid Case 4 Failed: Expected %s, got %s", opName, expected4, output4) + } // Case 5: Slice taking full dimension with stride > 1 operand5 := S(I8, 7) @@ -351,8 +452,12 @@ func TestSlice(t *testing.T) { // Dim 0: (7-0)/2 = 3.5 -> 4 elements (indices 0, 2, 4, 6) expected5 := S(I8, 4) output5, err := Slice(operand5, starts5, limits5, strides5) - require.NoError(t, err) - require.True(t, expected5.Equal(output5), "%s Valid Case 5 Failed: Expected %s, got %s", opName, expected5, output5) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !expected5.Equal(output5) { + t.Errorf("%s Valid Case 5 Failed: Expected %s, got %s", opName, expected5, output5) + } // --- Error Cases --- operand := S(F32, 10, 5) // Rank 2 @@ -362,43 +467,63 @@ func TestSlice(t *testing.T) { // Error 1: Invalid operand DType _, err = Slice(shapes.Shape{DType: dtypes.InvalidDType, Dimensions: []int{10}}, []int{0}, []int{5}, []int{1}) - require.Error(t, err, "%s Error Case 1 Failed: Invalid operand DType", opName) + if err == nil { + t.Errorf("%s Error Case 1 Failed: Invalid operand DType", opName) + } // Error 2: Incorrect length for starts _, err = Slice(operand, []int{1}, validLimits, validStrides) - require.Error(t, err, "%s Error Case 2 Failed: len(starts) != rank", opName) + if err == nil { + t.Errorf("%s Error Case 2 Failed: len(starts) != rank", opName) + } // Error 3: Incorrect length for limits _, err = Slice(operand, validStarts, []int{8}, validStrides) - require.Error(t, err, "%s Error Case 3 Failed: len(limits) != rank", opName) + if err == nil { + t.Errorf("%s Error Case 3 Failed: len(limits) != rank", opName) + } // Error 4: Incorrect length for strides _, err = Slice(operand, validStarts, validLimits, []int{1}) - require.Error(t, err, "%s Error Case 4 Failed: len(strides) != rank", opName) + if err == nil { + t.Errorf("%s Error Case 4 Failed: len(strides) != rank", opName) + } // Error 5: Zero stride _, err = Slice(operand, validStarts, validLimits, []int{1, 0}) - require.Error(t, err, "%s Error Case 5 Failed: Zero stride", opName) + if err == nil { + t.Errorf("%s Error Case 5 Failed: Zero stride", opName) + } // Error 6: Negative stride _, err = Slice(operand, validStarts, validLimits, []int{-1, 1}) - require.Error(t, err, "%s Error Case 6 Failed: Negative stride", opName) + if err == nil { + t.Errorf("%s Error Case 6 Failed: Negative stride", opName) + } // Error 7: Start index < 0 _, err = Slice(operand, []int{-1, 1}, validLimits, validStrides) - require.Error(t, err, "%s Error Case 7 Failed: Start < 0", opName) + if err == nil { + t.Errorf("%s Error Case 7 Failed: Start < 0", opName) + } // Error 8: Start index >= dimSize _, err = Slice(operand, []int{10, 1}, validLimits, validStrides) - require.Error(t, err, "%s Error Case 8 Failed: Start >= dimSize", opName) + if err == nil { + t.Errorf("%s Error Case 8 Failed: Start >= dimSize", opName) + } // Error 9: Limit index < start index _, err = Slice(operand, validStarts, []int{0, 4}, validStrides) // limit[0]=0 < start[0]=1 - require.Error(t, err, "%s Error Case 9 Failed: Limit < Start", opName) + if err == nil { + t.Errorf("%s Error Case 9 Failed: Limit < Start", opName) + } // Error 10: Limit index > dimSize _, err = Slice(operand, validStarts, []int{8, 6}, validStrides) // limit[1]=6 > dimSize[1]=5 - require.Error(t, err, "%s Error Case 10 Failed: Limit > dimSize", opName) + if err == nil { + t.Errorf("%s Error Case 10 Failed: Limit > dimSize", opName) + } } func TestArgMinMax(t *testing.T) { @@ -408,51 +533,65 @@ func TestArgMinMax(t *testing.T) { operand1 := S(F32, 10) expected1 := S(I32) output1 := must1(ArgMinMax(operand1, 0, I32)) - require.True(t, expected1.Equal(output1), "Valid Case 1 Failed: Expected %s, got %s", expected1, output1) + if !expected1.Equal(output1) { + t.Errorf("Valid Case 1 Failed: Expected %s, got %s", expected1, output1) + } // Case 2: 2D tensor, single axis operand2 := S(F32, 5, 6) expected2 := S(I8, 5) output2 := must1(ArgMinMax(operand2, 1, expected2.DType)) - require.True(t, expected2.Equal(output2), "Valid Case 2 Failed: Expected %s, got %s", expected2, output2) + if !expected2.Equal(output2) { + t.Errorf("Valid Case 2 Failed: Expected %s, got %s", expected2, output2) + } // Case 3: 3D tensor, multiple axes operand3 := S(F32, 4, 5, 6) expected3 := S(U64, 5, 6) output3 := must1(ArgMinMax(operand3, 0, expected3.DType)) - require.True(t, expected3.Equal(output3), "Valid Case 3 Failed: Expected %s, got %s", expected3, output3) + if !expected3.Equal(output3) { + t.Errorf("Valid Case 3 Failed: Expected %s, got %s", expected3, output3) + } // --- Error Cases --- // Error 1: Invalid operand DType - require.Panics(t, func() { + panics(t, func() { must1(ArgMinMax(shapes.Make(dtypes.InvalidDType, 10), 0, I32)) - }, "Error Case 1 Failed: Invalid operand DType") + }) // Error 2: Invalid axis (out of bounds) - require.Panics(t, func() { + panics(t, func() { must1(ArgMinMax(operand1, 1, I32)) // operand1 is rank 1, axis 1 invalid - }, "Error Case 2 Failed: Invalid axis (out of bounds)") + }) // Error 3: Negative axis - require.Panics(t, func() { + panics(t, func() { must1(ArgMinMax(operand2, -1, I32)) - }, "Error Case 3 Failed: Negative axis") + }) } func TestIsFinite(t *testing.T) { // Positive case: float64 tensor. f64Shape := S(dtypes.Float64, 2, 3) output, err := IsFinite(f64Shape) - require.NoError(t, err) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } expected := S(Bool, 2, 3) - require.True(t, expected.Equal(output)) + if !expected.Equal(output) { + t.Errorf("expected %s, got %s", expected, output) + } // Check non-float type. _, err = IsFinite(S(Bool)) - require.Error(t, err) + if err == nil { + t.Error("expected error for IsFinite(Bool), got nil") + } _, err = IsFinite(S(I32)) - require.Error(t, err) + if err == nil { + t.Error("expected error for IsFinite(I32), got nil") + } } func TestReduceWindow(t *testing.T) { @@ -664,15 +803,23 @@ func TestReduceWindow(t *testing.T) { ) if tc.expectError { - require.Error(t, err, "Expected an error for test case: %s", tc.name) - if tc.errorMessageContains != "" { - assert.Contains(t, err.Error(), tc.errorMessageContains, "Error message mismatch for: %s", tc.name) + if err == nil { + t.Errorf("Expected an error for test case: %s", tc.name) + } + if tc.errorMessageContains != "" && err != nil { + if !strings.Contains(err.Error(), tc.errorMessageContains) { + t.Errorf("Error message mismatch for: %s, expected to contain %q, got %q", tc.name, tc.errorMessageContains, err.Error()) + } } } else { - require.NoError(t, err, "Did not expect an error for test case: %s (error was: %v)", tc.name, err) - assert.True(t, tc.expectedShape.Equal(outputShape[0]), - "Mismatch in output shape for test case: %s. Expected %s, Got %s", - tc.name, tc.expectedShape, outputShape) + if err != nil { + t.Errorf("Did not expect an error for test case: %s (error was: %v)", tc.name, err) + } + if len(outputShape) > 0 { + if !tc.expectedShape.Equal(outputShape[0]) { + t.Errorf("Mismatch in output shape for test case: %s. Expected %s, Got %s", tc.name, tc.expectedShape, outputShape) + } + } } }) } @@ -686,12 +833,16 @@ func TestDotGeneral(t *testing.T) { lhs, []int{1}, []int{3, 0}, rhs, []int{3}, []int{0, 2}, F32) - require.NoError(t, err) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } // Batch dims: 5 , 2 // Contracting dims: 3 // Cross dims: 4 (lhs) and 1 (rhs) fmt.Printf("\tdotgeneral.shape=%s\n", output) - assert.NoError(t, output.Check(F32, 5, 2, 4, 1)) + if err := output.Check(F32, 5, 2, 4, 1); err != nil { + t.Errorf("output check failed: %v", err) + } } func TestPad(t *testing.T) { @@ -703,8 +854,12 @@ func TestPad(t *testing.T) { paddingInterior := []int{0} expected := S(F32, 10) output, err := Pad(operand, fillValue, paddingStart, paddingEnd, paddingInterior) - require.NoError(t, err) - require.True(t, expected.Equal(output), "Expected %s, got %s", expected, output) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !expected.Equal(output) { + t.Errorf("Expected %s, got %s", expected, output) + } }) t.Run("2DWithInterior", func(t *testing.T) { @@ -715,8 +870,12 @@ func TestPad(t *testing.T) { paddingInterior := []int{1, 1} expected := S(F32, 6, 9) output, err := Pad(operand, fillValue, paddingStart, paddingEnd, paddingInterior) - require.NoError(t, err) - require.True(t, expected.Equal(output), "Expected %s, got %s", expected, output) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !expected.Equal(output) { + t.Errorf("Expected %s, got %s", expected, output) + } }) t.Run("3DPadding", func(t *testing.T) { @@ -727,8 +886,12 @@ func TestPad(t *testing.T) { paddingInterior := []int{0, 0, 0} expected := S(F32, 4, 5, 3) output, err := Pad(operand, fillValue, paddingStart, paddingEnd, paddingInterior) - require.NoError(t, err) - require.True(t, expected.Equal(output), "Expected %s, got %s", expected, output) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !expected.Equal(output) { + t.Errorf("Expected %s, got %s", expected, output) + } }) t.Run("ErrorWrongFillValueDType", func(t *testing.T) { @@ -738,7 +901,9 @@ func TestPad(t *testing.T) { paddingEnd := []int{1} paddingInterior := []int{0} _, err := Pad(operand, fillValue, paddingStart, paddingEnd, paddingInterior) - require.Error(t, err) + if err == nil { + t.Error("expected error for Pad with wrong fill value dtype, got nil") + } }) t.Run("ErrorNonScalarFillValue", func(t *testing.T) { @@ -748,7 +913,9 @@ func TestPad(t *testing.T) { paddingEnd := []int{1} paddingInterior := []int{0} _, err := Pad(operand, fillValue, paddingStart, paddingEnd, paddingInterior) - require.Error(t, err) + if err == nil { + t.Error("expected error for Pad with non-scalar fill value, got nil") + } }) t.Run("ErrorMismatchedRank", func(t *testing.T) { @@ -758,7 +925,9 @@ func TestPad(t *testing.T) { paddingEnd := []int{1, 1} paddingInterior := []int{0, 0} _, err := Pad(operand, fillValue, paddingStart, paddingEnd, paddingInterior) - require.Error(t, err) + if err == nil { + t.Error("expected error for Pad with mismatched rank, got nil") + } }) t.Run("NegativePadding", func(t *testing.T) { @@ -769,8 +938,12 @@ func TestPad(t *testing.T) { paddingInterior := []int{0} expected := S(F32, 1) output, err := Pad(operand, fillValue, paddingStart, paddingEnd, paddingInterior) - require.NoError(t, err) - require.True(t, expected.Equal(output), "Expected %s, got %s", expected, output) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !expected.Equal(output) { + t.Errorf("Expected %s, got %s", expected, output) + } }) t.Run("ErrorNegativeInterior", func(t *testing.T) { @@ -780,7 +953,9 @@ func TestPad(t *testing.T) { paddingEnd := []int{0} paddingInterior := []int{-1} _, err := Pad(operand, fillValue, paddingStart, paddingEnd, paddingInterior) - require.Error(t, err) + if err == nil { + t.Error("expected error for Pad with negative interior, got nil") + } }) } @@ -790,27 +965,43 @@ func TestCollectiveOps(t *testing.T) { t.Run("AllGather", func(t *testing.T) { output, err := AllGather(operand, replicaGroups, 1) - require.NoError(t, err) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } expected := S(F32, 2, 8) - require.True(t, expected.Equal(output), "Expected %s, got %s", expected, output) + if !expected.Equal(output) { + t.Errorf("Expected %s, got %s", expected, output) + } _, err = AllGather(operand, replicaGroups, 2) - require.Error(t, err) + if err == nil { + t.Error("expected error for AllGather with invalid dimension, got nil") + } }) t.Run("AllToAll", func(t *testing.T) { output, err := AllToAll(operand, replicaGroups, 1, 0, 2) - require.NoError(t, err) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } expected := S(F32, 4, 2) - require.True(t, expected.Equal(output), "Expected %s, got %s", expected, output) + if !expected.Equal(output) { + t.Errorf("Expected %s, got %s", expected, output) + } _, err = AllToAll(operand, replicaGroups, 2, 0, 2) - require.Error(t, err) + if err == nil { + t.Error("expected error for AllToAll with invalid dimension, got nil") + } }) t.Run("CollectivePermute", func(t *testing.T) { output, err := CollectivePermute(operand, [][2]int{{0, 1}}) - require.NoError(t, err) - require.True(t, operand.Equal(output), "Expected %s, got %s", operand, output) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !operand.Equal(output) { + t.Errorf("Expected %s, got %s", operand, output) + } }) } diff --git a/stablehlo_test.go b/stablehlo_test.go index 6ed5fc6..595f1f3 100644 --- a/stablehlo_test.go +++ b/stablehlo_test.go @@ -2,13 +2,12 @@ package stablehlo import ( "fmt" + "strings" "testing" "github.com/gomlx/gopjrt/dtypes" "github.com/gomlx/stablehlo/types/shapes" "github.com/gomlx/stablehlo/types/shardy" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func must[T any](value T, err error) T { @@ -25,7 +24,9 @@ func TestBuilder(t *testing.T) { c1 := must(fn.ConstantFromScalar(1.0)) c2 := must(fn.ConstantFromScalar(2.0)) sum := must(Add(c1, c2)) - require.NoError(t, fn.Return(sum)) + if err := fn.Return(sum); err != nil { + t.Fatalf("expected no error, got %v", err) + } program := string(must(b.Build())) fmt.Printf("%s program:\n%s", t.Name(), program) want := `module @TestBuilder_no_inputs { @@ -46,9 +47,13 @@ func TestBuilder(t *testing.T) { t.Run("Sharding", func(t *testing.T) { b := New(t.Name()) mesh, err := shardy.NewDeviceMesh("mesh", []int{4, 2}, []string{"data", "model"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } err = mesh.SetLogicalDeviceAssignment(7, 6, 5, 4, 3, 2, 1, 0) - require.NoError(t, err) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } b.WithShardy(mesh) fn := b.Main() @@ -74,7 +79,9 @@ func TestBuilder(t *testing.T) { []map[string]any{ {"jax.result_info": "result"}, }) - require.NoError(t, err) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } program := string(must(b.Build())) fmt.Printf("%s program:\n%s", t.Name(), program) @@ -98,7 +105,9 @@ func TestBuilder(t *testing.T) { } } ` - require.Equal(t, want, program) + if want != program { + t.Fatalf("programs don't match.\nWant:\n%s\nGot:\n%s", want, program) + } }) t.Run("with inputs", func(t *testing.T) { @@ -109,7 +118,9 @@ func TestBuilder(t *testing.T) { lhs := must(fn.NamedInput("lhs", shape)) rhs := must(fn.NamedInput("rhs", shape)) sum := must(Add(lhs, rhs)) - require.NoError(t, fn.Return(sum)) + if err := fn.Return(sum); err != nil { + t.Fatalf("expected no error, got %v", err) + } program := string(must(builder.Build())) fmt.Printf("%s program:\n%s", t.Name(), program) want := `module @TestBuilder_with_inputs { @@ -131,10 +142,16 @@ func TestBuilder_Errors(t *testing.T) { b := New("test_program") fn := b.NewFunction("not_main", nil) c1 := must(fn.ConstantFromScalar(1.0)) - require.NoError(t, fn.Return(c1)) + if err := fn.Return(c1); err != nil { + t.Fatalf("expected no error, got %v", err) + } _, err := b.Build() - require.Error(t, err) - assert.Contains(t, err.Error(), "program must have a main function") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "program must have a main function") { + t.Fatalf("error message %q does not contain expected substring", err.Error()) + } }) } @@ -148,6 +165,8 @@ func TestNormalizeIdentifier(t *testing.T) { } for _, tc := range testCases { got := NormalizeIdentifier(tc.input) - assert.Equal(t, tc.want, got) + if got != tc.want { + t.Errorf("NormalizeIdentifier(%q) = %q, want %q", tc.input, got, tc.want) + } } } diff --git a/tests/gopjrt/collective_test.go b/tests/gopjrt/collective_test.go index 5c8106d..0360bf2 100644 --- a/tests/gopjrt/collective_test.go +++ b/tests/gopjrt/collective_test.go @@ -10,7 +10,6 @@ import ( "github.com/gomlx/gopjrt/pjrt" . "github.com/gomlx/stablehlo" "github.com/gomlx/stablehlo/types/shapes" - "github.com/stretchr/testify/require" ) var flagCollectiveBroadcast = flag.Bool("collective_broadcast", false, "Run collective broadcast test: it is not implemented in PJRT CPU, so it is skipped by default.") @@ -58,9 +57,15 @@ func testCollectiveOps(t *testing.T, client *pjrt.Client) { // Execute expects a flat list of inputs, one for each argument of main(), // mapped to devices in order. e, err := client.Compile().WithStableHLO(program).WithSPMD(numReplicas).Done() - require.NoErrorf(t, err, "failed to compile program: \n%s", program) + if err != nil { + t.Errorf("failed to compile program: \n%s\nError: %v", program, err) + return + } outputBuffers, err := e.Execute(input0, input1).DonateAll().Done() - require.NoErrorf(t, err, "failed to execute program: \n%s", program) + if err != nil { + t.Errorf("failed to execute program: \n%s\nError: %v", program, err) + return + } // Check outputs: all replicas should have the data from replica 0. want := []FlatAndDims{ @@ -97,9 +102,15 @@ func testCollectiveOps(t *testing.T, client *pjrt.Client) { // Execute expects a flat list of inputs, one for each argument of main(), // mapped to devices in order. e, err := client.Compile().WithStableHLO(program).WithSPMD(numReplicas).Done() - require.NoErrorf(t, err, "failed to compile program: \n%s", program) + if err != nil { + t.Errorf("failed to compile program: \n%s\nError: %v", program, err) + return + } outputBuffers, err := e.Execute(inputX0, inputX1).DonateAll().Done() - require.NoErrorf(t, err, "failed to execute program: \n%s", program) + if err != nil { + t.Errorf("failed to execute program: \n%s\nError: %v", program, err) + return + } // Check outputs: all replicas should have the sum. want := []FlatAndDims{ @@ -141,9 +152,15 @@ func testCollectiveOps(t *testing.T, client *pjrt.Client) { // Execute expects a flat list of inputs, one for each argument of main(), // mapped to devices in order. e, err := client.Compile().WithStableHLO(program).WithSPMD(numReplicas).Done() - require.NoErrorf(t, err, "failed to compile program: \n%s", program) + if err != nil { + t.Errorf("failed to compile program: \n%s\nError: %v", program, err) + return + } outputBuffers, err := e.Execute(inputX0, inputY0, inputX1, inputY1).DonateAll().Done() - require.NoErrorf(t, err, "failed to execute program: \n%s", program) + if err != nil { + t.Errorf("failed to execute program: \n%s\nError: %v", program, err) + return + } // Check outputs: all replicas should have the sum. want := []FlatAndDims{ @@ -170,9 +187,15 @@ func testCollectiveOps(t *testing.T, client *pjrt.Client) { []float32{2.0, 20.0}, []int{2}).ToDeviceNum(replicaGroups[0][1]).Done()) e, err := client.Compile().WithStableHLO(program).WithSPMD(numReplicas).Done() - require.NoErrorf(t, err, "failed to compile program: \n%s", program) + if err != nil { + t.Errorf("failed to compile program: \n%s\nError: %v", program, err) + return + } outputBuffers, err := e.Execute(input0, input1).DonateAll().Done() - require.NoErrorf(t, err, "failed to execute program: \n%s", program) + if err != nil { + t.Errorf("failed to execute program: \n%s\nError: %v", program, err) + return + } want := []FlatAndDims{ {[]float32{1.0, 10.0, 2.0, 20.0}, []int{4}}, @@ -196,9 +219,15 @@ func testCollectiveOps(t *testing.T, client *pjrt.Client) { []float32{10.0, 20.0, 30.0, 40.0}, []int{4}).ToDeviceNum(replicaGroups[0][1]).Done()) e, err := client.Compile().WithStableHLO(program).WithSPMD(numReplicas).Done() - require.NoErrorf(t, err, "failed to compile program: \n%s", program) + if err != nil { + t.Errorf("failed to compile program: \n%s\nError: %v", program, err) + return + } outputBuffers, err := e.Execute(input0, input1).DonateAll().Done() - require.NoErrorf(t, err, "failed to execute program: \n%s", program) + if err != nil { + t.Errorf("failed to execute program: \n%s\nError: %v", program, err) + return + } want := []FlatAndDims{ {[]float32{1.0, 2.0, 10.0, 20.0}, []int{4}}, @@ -226,9 +255,15 @@ func testCollectiveOps(t *testing.T, client *pjrt.Client) { []float32{2.0, 20.0}, []int{2}).ToDeviceNum(replicaGroups[0][1]).Done()) e, err := client.Compile().WithStableHLO(program).WithSPMD(numReplicas).Done() - require.NoErrorf(t, err, "failed to compile program: \n%s", program) + if err != nil { + t.Errorf("failed to compile program: \n%s\nError: %v", program, err) + return + } outputBuffers, err := e.Execute(input0, input1).DonateAll().Done() - require.NoErrorf(t, err, "failed to execute program: \n%s", program) + if err != nil { + t.Errorf("failed to execute program: \n%s\nError: %v", program, err) + return + } want := []FlatAndDims{ {[]float32{2.0, 20.0}, []int{2}}, diff --git a/tests/gopjrt/gopjrt_test.go b/tests/gopjrt/gopjrt_test.go index 30f48f0..b0ab3f0 100644 --- a/tests/gopjrt/gopjrt_test.go +++ b/tests/gopjrt/gopjrt_test.go @@ -15,7 +15,6 @@ import ( . "github.com/gomlx/stablehlo" "github.com/gomlx/stablehlo/types" "github.com/gomlx/stablehlo/types/shapes" - "github.com/stretchr/testify/require" "k8s.io/klog/v2" ) @@ -76,14 +75,20 @@ func pjrtClientsIterator(t *testing.T) iter.Seq2[string, *pjrt.Client] { return func(yield func(string, *pjrt.Client) bool) { for _, pluginName := range getPluginNames() { plugin, err := pjrt.GetPlugin(pluginName) - require.NoError(t, err, "failed to load plugin %q", pluginName) + if err != nil { + t.Fatalf("failed to load plugin %q: %v", pluginName, err) + } klog.Infof("Plugin: %s", plugin) client, err := plugin.NewClient(nil) - require.NoError(t, err, "failed to create client for plugin %q", pluginName) + if err != nil { + t.Fatalf("failed to create client for plugin %q: %v", pluginName, err) + } klog.Infof("Client %s (version %s): %d devices", client.Platform(), client.PlatformVersion(), client.NumDevices()) done := yield(pluginName, client) - require.NoError(t, client.Destroy()) + if err := client.Destroy(); err != nil { + t.Fatalf("failed to destroy client: %v", err) + } if done { return } @@ -107,7 +112,9 @@ func iterateClientsAndTest(t *testing.T, testFn func(*testing.T, *pjrt.Client)) // compileAndExecute program with PJRT. All inputs are donated. func compileAndExecute(t *testing.T, client *pjrt.Client, program []byte, inputs ...*pjrt.Buffer) []*pjrt.Buffer { loadedExec, err := client.Compile().WithStableHLO(program).Done() - require.NoErrorf(t, err, "failed to compile program: \n%s", program) + if err != nil { + t.Fatalf("failed to compile program: \n%s\nError: %v", program, err) + } defer func() { err := loadedExec.Destroy() if err != nil { @@ -115,7 +122,9 @@ func compileAndExecute(t *testing.T, client *pjrt.Client, program []byte, inputs } }() outputBuffers, err := loadedExec.Execute(inputs...).DonateAll().Done() - require.NoErrorf(t, err, "failed to execute program: \n%s", program) + if err != nil { + t.Fatalf("failed to execute program: \n%s\nError: %v", program, err) + } return outputBuffers } @@ -135,21 +144,51 @@ func requireBuffersEqual(t *testing.T, expected []FlatAndDims, got []*pjrt.Buffe } } }() - require.Len(t, got, len(expected)) + if len(got) != len(expected) { + t.Fatalf("expected %d outputs, got %d", len(expected), len(got)) + } for i, b := range got { gotFlat, gotDims, err := b.ToFlatDataAndDimensions() + if err != nil { + t.Fatalf("failed to get buffer contents for output #%d: %v", i, err) + } expectedShape, err := shapes.FromAnyValue(expected[i].Flat) - require.NoErrorf(t, err, "failed to get shape for output #%d: %v", i, expected[i].Flat) + if err != nil { + t.Fatalf("failed to get shape for output #%d: %v\nValue: %v", i, err, expected[i].Flat) + } dtype := expectedShape.DType fmt.Printf("\t - output #%d:\n\t - Got: dims=%v, flat_values=%v\n", i, gotDims, gotFlat) fmt.Printf("\t - Want(%s): dims=%v, flat_values=%v\n", dtype, expected[i].Dims, expected[i].Flat) - require.NoErrorf(t, err, "failed to get buffer contents for output #%d, expected flat value %v", i, expected[i].Flat) - require.Equalf(t, expected[i].Dims, gotDims, "output #%d dims don't match", i) + + if !reflect.DeepEqual(expected[i].Dims, gotDims) { + // Handle nil vs empty slice difference. + if len(expected[i].Dims) != 0 || len(gotDims) != 0 { + t.Errorf("output #%d dims don't match: want %v, got %v", i, expected[i].Dims, gotDims) + } + } + switch dtype { case dtypes.Float64, dtypes.Float32: - require.InDeltaSlicef(t, expected[i].Flat, gotFlat, 1e-4, "output #%d flat values don't match", i) + // For floats use InDelta-like comparison. + expVal := reflect.ValueOf(expected[i].Flat) + gotVal := reflect.ValueOf(gotFlat) + if expVal.Len() != gotVal.Len() { + t.Errorf("output #%d flat values length mismatch: want %d, got %d", i, expVal.Len(), gotVal.Len()) + continue + } + for j := 0; j < expVal.Len(); j++ { + e := expVal.Index(j).Float() + g := gotVal.Index(j).Float() + diff := math.Abs(e - g) + if diff > 1e-4 { + t.Errorf("output #%d flat values don't match at index %d: want %v, got %v (diff %v)", i, j, e, g, diff) + break // Stop after first error to avoid spam + } + } default: - require.Equalf(t, expected[i].Flat, gotFlat, "output #%d flat values don't match", i) + if !reflect.DeepEqual(expected[i].Flat, gotFlat) { + t.Errorf("output #%d flat values don't match: want %v, got %v", i, expected[i].Flat, gotFlat) + } } } } @@ -482,8 +521,12 @@ func testOps(t *testing.T, client *pjrt.Client) { fmt.Printf("%s program:\n%s", t.Name(), withLines(program)) outputs := compileAndExecute(t, client, program) flat, dims, err := outputs[0].ToFlatDataAndDimensions() - require.NoError(t, err) - require.Equal(t, []int{numSamples}, dims) + if err != nil { + t.Fatalf("ToFlatDataAndDimensions error: %v", err) + } + if !reflect.DeepEqual([]int{numSamples}, dims) { + t.Errorf("dims mismatch: want [%d], got %v", numSamples, dims) + } noise := flat.([]uint64) // Count bits in each uint64 var totalBits int @@ -494,8 +537,12 @@ func testOps(t *testing.T, client *pjrt.Client) { expectedBits := 32 * numSamples fmt.Printf("\tgot %d bits set, expected %d\n", totalBits, expectedBits) margin := 2 * numSamples - require.Greater(t, totalBits, expectedBits-margin) - require.Less(t, totalBits, expectedBits+margin) + if totalBits <= expectedBits-margin { + t.Errorf("totalBits %d <= expectedBits-margin %d", totalBits, expectedBits-margin) + } + if totalBits >= expectedBits+margin { + t.Errorf("totalBits %d >= expectedBits+margin %d", totalBits, expectedBits+margin) + } }) } @@ -620,23 +667,35 @@ func testOps(t *testing.T, client *pjrt.Client) { gotDims := must1(outputs[0].Dimensions()) fmt.Printf("\t- FFTForward output dims: %v\n", gotDims) - require.Equal(t, []int{3, 4, 10}, gotDims) + if !reflect.DeepEqual([]int{3, 4, 10}, gotDims) { + t.Errorf("FFTForward dims mismatch: want %v, got %v", []int{3, 4, 10}, gotDims) + } gotDims = must1(outputs[1].Dimensions()) fmt.Printf("\t- FFTInverse output dims: %v\n", gotDims) - require.Equal(t, []int{3, 4, 10}, gotDims) + if !reflect.DeepEqual([]int{3, 4, 10}, gotDims) { + t.Errorf("FFTInverse dims mismatch: want %v, got %v", []int{3, 4, 10}, gotDims) + } gotDims = must1(outputs[2].Dimensions()) gotDType := must1(outputs[2].DType()) fmt.Printf("\t- FFTForwardReal output dtype %s, dims: %v\n", gotDType, gotDims) - require.Equal(t, []int{3, 4, 10/2 + 1}, gotDims) - require.Equal(t, dtypes.Complex64, gotDType) + if !reflect.DeepEqual([]int{3, 4, 10/2 + 1}, gotDims) { + t.Errorf("FFTForwardReal dims mismatch: want %v, got %v", []int{3, 4, 10/2 + 1}, gotDims) + } + if gotDType != dtypes.Complex64 { + t.Errorf("FFTForwardReal dtype mismatch: want %v, got %v", dtypes.Complex64, gotDType) + } gotDims = must1(outputs[3].Dimensions()) gotDType = must1(outputs[3].DType()) fmt.Printf("\t- FFTInverseReal output dtype %s, dims: %v\n", gotDType, gotDims) - require.Equal(t, []int{3, 4, 2 * (10 - 1)}, gotDims) - require.Equal(t, dtypes.Float32, gotDType) + if !reflect.DeepEqual([]int{3, 4, 2 * (10 - 1)}, gotDims) { + t.Errorf("FFTInverseReal dims mismatch: want %v, got %v", []int{3, 4, 2 * (10 - 1)}, gotDims) + } + if gotDType != dtypes.Float32 { + t.Errorf("FFTInverseReal dtype mismatch: want %v, got %v", dtypes.Float32, gotDType) + } }) t.Run("ReduceWindow", func(t *testing.T) { @@ -784,7 +843,9 @@ func testOps(t *testing.T, client *pjrt.Client) { scale := must1(fn.ConstantFromFlatAndDimensions([]float32{1, 2, 3}, 3)) offset := must1(fn.ConstantFromFlatAndDimensions([]float32{10, 100, 1000}, 3)) xNorm, batchMean, batchVariance, err := BatchNormTraining(x, scale, offset, 1e-7, -1) - require.NoError(t, err) + if err != nil { + t.Fatalf("BatchNormTraining error: %v", err) + } must(fn.Return(xNorm, batchMean, batchVariance)) program := must1(builder.Build()) fmt.Printf("%s program:\n%s", t.Name(), withLines(program)) @@ -815,7 +876,9 @@ func testOps(t *testing.T, client *pjrt.Client) { gradOutput := must1(fn.ConstantFromScalar(float32(1))) gradOutput = must1(BroadcastInDim(gradOutput, shapes.Make(dtypes.F32, 7, 3), nil)) gradX, gradScale, gradOffset, err := BatchNormGradient(x, scale, mean, variance, gradOutput, 1e-7, -1) - require.NoError(t, err) + if err != nil { + t.Fatalf("BatchNormGradient error: %v", err) + } must(fn.Return(gradX, gradScale, gradOffset)) program := must1(builder.Build()) fmt.Printf("%s program:\n%s", t.Name(), withLines(program)) @@ -1029,7 +1092,9 @@ func testUnaryOps(t *testing.T, client *pjrt.Client) { shape := shapes.Make(dtype) fn := builder.Main() arg, err := fn.Input(shape) - require.NoError(t, err) + if err != nil { + t.Fatalf("fn.Input error: %v", err) + } result := must1(op(arg)) must(fn.Return(result)) program := must1(builder.Build()) @@ -1166,16 +1231,24 @@ func testConstants(t *testing.T, client *pjrt.Client) { builder := New(t.Name()) fn := builder.Main() c, err := fn.ConstantFromScalar(scalar) - require.NoError(t, err) + if err != nil { + t.Fatalf("ConstantFromScalar error: %v", err) + } must(fn.Return(c)) program := must1(builder.Build()) fmt.Printf("%s program:\n%s", t.Name(), withLines(program)) output := compileAndExecute(t, client, program)[0] gotFlat, gotDim, err := output.ToFlatDataAndDimensions() - require.NoError(t, err) - require.Len(t, gotDim, 0) + if err != nil { + t.Fatalf("ToFlatDataAndDimensions error: %v", err) + } + if len(gotDim) != 0 { + t.Errorf("gotDim len %d, want 0", len(gotDim)) + } gotScalar := reflect.ValueOf(gotFlat).Index(0).Interface() - require.Equal(t, scalar, gotScalar) + if !reflect.DeepEqual(scalar, gotScalar) { + t.Errorf("gotScalar %v, want %v", gotScalar, scalar) + } } t.Run("float32", func(t *testing.T) { testScalar(t, float32(3.0)) }) @@ -1191,15 +1264,26 @@ func testConstants(t *testing.T, client *pjrt.Client) { builder := New(t.Name()) fn := builder.Main() c, err := fn.ConstantFromFlatAndDimensions(flat, dimensions...) - require.NoError(t, err) + if err != nil { + t.Fatalf("ConstantFromFlatAndDimensions error: %v", err) + } must(fn.Return(c)) program := must1(builder.Build()) fmt.Printf("%s program:\n%s", t.Name(), withLines(program)) output := compileAndExecute(t, client, program)[0] gotFlat, gotDims, err := output.ToFlatDataAndDimensions() - require.NoError(t, err) - require.Equal(t, dimensions, gotDims) - require.Equal(t, flat, gotFlat) + if err != nil { + t.Fatalf("ToFlatDataAndDimensions error: %v", err) + } + if !reflect.DeepEqual(dimensions, gotDims) { + // Handle nil vs empty slice for dims + if len(dimensions) != 0 || len(gotDims) != 0 { + t.Errorf("dims mismatch: want %v, got %v", dimensions, gotDims) + } + } + if !reflect.DeepEqual(flat, gotFlat) { + t.Errorf("flat data mismatch: want %v, got %v", flat, gotFlat) + } } t.Run("0D-int8", func(t *testing.T) { testTensor(t, []int8{-3}) }) diff --git a/tests/gopjrt/shardy_test.go b/tests/gopjrt/shardy_test.go index fbe941a..daaab8b 100644 --- a/tests/gopjrt/shardy_test.go +++ b/tests/gopjrt/shardy_test.go @@ -9,7 +9,6 @@ import ( "github.com/gomlx/stablehlo" "github.com/gomlx/stablehlo/types/shapes" "github.com/gomlx/stablehlo/types/shardy" - "github.com/stretchr/testify/require" ) func TestShardy(t *testing.T) { @@ -24,7 +23,9 @@ func shardyCompileAndExecute(t *testing.T, client *pjrt.Client, program []byte, WithShardy(len(deviceAssignment)). WithDeviceAssignment(deviceAssignment). Done() - require.NoErrorf(t, err, "failed to compile program: \n%s", program) + if err != nil { + t.Fatalf("failed to compile program: \n%s\nError: %v", program, err) + } defer func() { err := loadedExec.Destroy() if err != nil { @@ -32,7 +33,9 @@ func shardyCompileAndExecute(t *testing.T, client *pjrt.Client, program []byte, } }() outputBuffers, err := loadedExec.Execute(inputs...).DonateAll().Done() - require.NoErrorf(t, err, "failed to execute program: \n%s", program) + if err != nil { + t.Fatalf("failed to execute program: \n%s\nError: %v", program, err) + } return outputBuffers } diff --git a/types/shapes/shapes_test.go b/types/shapes/shapes_test.go index 3729f1b..3250fdb 100644 --- a/types/shapes/shapes_test.go +++ b/types/shapes/shapes_test.go @@ -17,10 +17,10 @@ package shapes import ( + "reflect" "testing" "github.com/gomlx/gopjrt/dtypes" - "github.com/stretchr/testify/require" ) func TestCastAsDType(t *testing.T) { @@ -28,60 +28,140 @@ func TestCastAsDType(t *testing.T) { { want := [][]float32{{1, 2}, {3, 4}, {5, 6}} got := CastAsDType(value, dtypes.Float32) - require.Equal(t, want, got) + if !reflect.DeepEqual(want, got) { + t.Errorf("CastAsDType(..., Float32) = %v, want %v", got, want) + } } { want := [][]complex64{{1, 2}, {3, 4}, {5, 6}} got := CastAsDType(value, dtypes.Complex64) - require.Equal(t, want, got) + if !reflect.DeepEqual(want, got) { + t.Errorf("CastAsDType(..., Complex64) = %v, want %v", got, want) + } } } func TestShape(t *testing.T) { invalidShape := Invalid() - require.False(t, invalidShape.Ok()) + if invalidShape.Ok() { + t.Error("Invalid().Ok() should be false") + } shape0 := Make(dtypes.Float64) - require.True(t, shape0.Ok()) - require.True(t, shape0.IsScalar()) - require.False(t, shape0.IsTuple()) - require.Equal(t, 0, shape0.Rank()) - require.Len(t, shape0.Dimensions, 0) - require.Equal(t, 1, shape0.Size()) - require.Equal(t, 8, int(shape0.Memory())) + if !shape0.Ok() { + t.Error("shape0.Ok() should be true") + } + if !shape0.IsScalar() { + t.Error("shape0.IsScalar() should be true") + } + if shape0.IsTuple() { + t.Error("shape0.IsTuple() should be false") + } + if shape0.Rank() != 0 { + t.Errorf("shape0.Rank() = %d, want 0", shape0.Rank()) + } + if len(shape0.Dimensions) != 0 { + t.Errorf("len(shape0.Dimensions) = %d, want 0", len(shape0.Dimensions)) + } + if shape0.Size() != 1 { + t.Errorf("shape0.Size() = %d, want 1", shape0.Size()) + } + if int(shape0.Memory()) != 8 { + t.Errorf("shape0.Memory() = %d, want 8", int(shape0.Memory())) + } shape1 := Make(dtypes.Float32, 4, 3, 2) - require.True(t, shape1.Ok()) - require.False(t, shape1.IsScalar()) - require.False(t, shape1.IsTuple()) - require.Equal(t, 3, shape1.Rank()) - require.Len(t, shape1.Dimensions, 3) - require.Equal(t, 4*3*2, shape1.Size()) - require.Equal(t, 4*4*3*2, int(shape1.Memory())) + if !shape1.Ok() { + t.Error("shape1.Ok() should be true") + } + if shape1.IsScalar() { + t.Error("shape1.IsScalar() should be false") + } + if shape1.IsTuple() { + t.Error("shape1.IsTuple() should be false") + } + if shape1.Rank() != 3 { + t.Errorf("shape1.Rank() = %d, want 3", shape1.Rank()) + } + if len(shape1.Dimensions) != 3 { + t.Errorf("len(shape1.Dimensions) = %d, want 3", len(shape1.Dimensions)) + } + if shape1.Size() != 4*3*2 { + t.Errorf("shape1.Size() = %d, want %d", shape1.Size(), 4*3*2) + } + if int(shape1.Memory()) != 4*4*3*2 { + t.Errorf("shape1.Memory() = %d, want %d", int(shape1.Memory()), 4*4*3*2) + } +} + +func panics(t *testing.T, f func()) { + t.Helper() + defer func() { + if r := recover(); r == nil { + t.Error("expected panic, but code did not panic") + } + }() + f() +} + +func notPanics(t *testing.T, f func()) { + t.Helper() + defer func() { + if r := recover(); r != nil { + t.Errorf("expected no panic, but code panicked: %v", r) + } + }() + f() } func TestDim(t *testing.T) { shape := Make(dtypes.Float32, 4, 3, 2) - require.Equal(t, 4, shape.Dim(0)) - require.Equal(t, 3, shape.Dim(1)) - require.Equal(t, 2, shape.Dim(2)) - require.Equal(t, 4, shape.Dim(-3)) - require.Equal(t, 3, shape.Dim(-2)) - require.Equal(t, 2, shape.Dim(-1)) - require.Panics(t, func() { _ = shape.Dim(3) }) - require.Panics(t, func() { _ = shape.Dim(-4) }) + if d := shape.Dim(0); d != 4 { + t.Errorf("shape.Dim(0) = %d, want 4", d) + } + if d := shape.Dim(1); d != 3 { + t.Errorf("shape.Dim(1) = %d, want 3", d) + } + if d := shape.Dim(2); d != 2 { + t.Errorf("shape.Dim(2) = %d, want 2", d) + } + if d := shape.Dim(-3); d != 4 { + t.Errorf("shape.Dim(-3) = %d, want 4", d) + } + if d := shape.Dim(-2); d != 3 { + t.Errorf("shape.Dim(-2) = %d, want 3", d) + } + if d := shape.Dim(-1); d != 2 { + t.Errorf("shape.Dim(-1) = %d, want 2", d) + } + panics(t, func() { _ = shape.Dim(3) }) + panics(t, func() { _ = shape.Dim(-4) }) } func TestFromAnyValue(t *testing.T) { shape, err := FromAnyValue([]int32{1, 2, 3}) - require.NoError(t, err) - require.NotPanics(t, func() { shape.Assert(dtypes.Int32, 3) }) + if err != nil { + t.Fatalf("FromAnyValue failed: %v", err) + } + notPanics(t, func() { + if err := shape.Check(dtypes.Int32, 3); err != nil { + panic(err) + } + }) shape, err = FromAnyValue([][][]complex64{{{1, 2, -3}, {3, 4 + 2i, -7 - 1i}}}) - require.NoError(t, err) - require.NotPanics(t, func() { shape.Assert(dtypes.Complex64, 1, 2, 3) }) + if err != nil { + t.Fatalf("FromAnyValue failed: %v", err) + } + notPanics(t, func() { + if err := shape.Check(dtypes.Complex64, 1, 2, 3); err != nil { + panic(err) + } + }) // Irregular shape is not accepted: shape, err = FromAnyValue([][]float32{{1, 2, 3}, {4, 5}}) - require.Errorf(t, err, "irregular shape should have returned an error, instead got shape %s", shape) + if err == nil { + t.Errorf("irregular shape should have returned an error, instead got shape %s", shape) + } } diff --git a/types/shapes/stablehlo_test.go b/types/shapes/stablehlo_test.go index 88ad8ba..d63b5c2 100644 --- a/types/shapes/stablehlo_test.go +++ b/types/shapes/stablehlo_test.go @@ -4,14 +4,17 @@ import ( "testing" "github.com/gomlx/gopjrt/dtypes" - "github.com/stretchr/testify/require" ) func TestToStableHLO(t *testing.T) { shape := Make(dtypes.Float32, 1, 10) - require.Equal(t, "tensor<1x10xf32>", shape.ToStableHLO()) + if got := shape.ToStableHLO(); got != "tensor<1x10xf32>" { + t.Errorf("ToStableHLO() = %q, want %q", got, "tensor<1x10xf32>") + } // Test scalar. shape = Make(dtypes.Int32) - require.Equal(t, "tensor", shape.ToStableHLO()) + if got := shape.ToStableHLO(); got != "tensor" { + t.Errorf("ToStableHLO() = %q, want %q", got, "tensor") + } } diff --git a/types/shardy/devicemesh_test.go b/types/shardy/devicemesh_test.go index f8dda8a..2c2e865 100644 --- a/types/shardy/devicemesh_test.go +++ b/types/shardy/devicemesh_test.go @@ -1,11 +1,11 @@ package shardy_test import ( + "reflect" + "strings" "testing" "github.com/gomlx/stablehlo/types/shardy" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestDeviceMesh(t *testing.T) { @@ -55,11 +55,21 @@ func TestDeviceMesh(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", tt.shape, tt.axisNames) - require.NoError(t, err) - assert.NotNil(t, mesh) - assert.Equal(t, tt.wantRank, mesh.Rank()) - assert.Equal(t, tt.wantNum, mesh.NumDevices()) - assert.Equal(t, tt.wantStableHLO, mesh.ToStableHLO()) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } + if mesh == nil { + t.Fatal("NewDeviceMesh() returned nil") + } + if got := mesh.Rank(); got != tt.wantRank { + t.Errorf("Rank() = %d, want %d", got, tt.wantRank) + } + if got := mesh.NumDevices(); got != tt.wantNum { + t.Errorf("NumDevices() = %d, want %d", got, tt.wantNum) + } + if got := mesh.ToStableHLO(); got != tt.wantStableHLO { + t.Errorf("ToStableHLO() = %q, want %q", got, tt.wantStableHLO) + } }) } }) @@ -100,40 +110,60 @@ func TestDeviceMesh(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", tt.shape, tt.axisNames) - require.Error(t, err) - assert.Nil(t, mesh) - assert.Contains(t, err.Error(), tt.wantErr) + if err == nil { + t.Error("NewDeviceMesh() expected error, got nil") + } + if mesh != nil { + t.Error("NewDeviceMesh() expected nil mesh on error") + } + if err != nil && !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("NewDeviceMesh() error = %q, want substring %q", err.Error(), tt.wantErr) + } }) } }) t.Run("AxesNames", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 4}, []string{"x", "y"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } axisNames := mesh.AxesNames() - assert.Equal(t, []string{"x", "y"}, axisNames) + if !reflect.DeepEqual(axisNames, []string{"x", "y"}) { + t.Errorf("AxesNames() = %v, want %v", axisNames, []string{"x", "y"}) + } // Verify it returns a copy axisNames[0] = "modified" - assert.Equal(t, []string{"x", "y"}, mesh.AxesNames()) + if !reflect.DeepEqual(mesh.AxesNames(), []string{"x", "y"}) { + t.Errorf("AxesNames() modified original, want %v", []string{"x", "y"}) + } }) t.Run("Shape", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 4}, []string{"x", "y"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } axesSizes := mesh.AxesSizes() - assert.Equal(t, []int{2, 4}, axesSizes) + if !reflect.DeepEqual(axesSizes, []int{2, 4}) { + t.Errorf("AxesSizes() = %v, want %v", axesSizes, []int{2, 4}) + } // Verify it returns a copy axesSizes[0] = 99 - assert.Equal(t, []int{2, 4}, mesh.AxesSizes()) + if !reflect.DeepEqual(mesh.AxesSizes(), []int{2, 4}) { + t.Errorf("AxesSizes() modified original, want %v", []int{2, 4}) + } }) t.Run("AxisSize", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 4}, []string{"x", "y"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } tests := []struct { name string @@ -165,11 +195,19 @@ func TestDeviceMesh(t *testing.T) { t.Run(tt.name, func(t *testing.T) { size, err := mesh.AxisSize(tt.axisName) if tt.wantErr { - require.Error(t, err) - assert.Contains(t, err.Error(), "not found") + if err == nil { + t.Error("AxisSize() expected error, got nil") + } + if err != nil && !strings.Contains(err.Error(), "not found") { + t.Errorf("AxisSize() error = %q, want substring %q", err.Error(), "not found") + } } else { - require.NoError(t, err) - assert.Equal(t, tt.wantSize, size) + if err != nil { + t.Errorf("AxisSize() error = %v", err) + } + if size != tt.wantSize { + t.Errorf("AxisSize() = %d, want %d", size, tt.wantSize) + } } }) } @@ -199,15 +237,21 @@ func TestDeviceMesh(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", tt.shape, tt.axisNames) - require.NoError(t, err) - assert.Equal(t, tt.want, mesh.String()) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } + if got := mesh.String(); got != tt.want { + t.Errorf("String() = %q, want %q", got, tt.want) + } }) } }) t.Run("SetDeviceAssignment_Valid", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } tests := []struct { name string @@ -230,14 +274,18 @@ func TestDeviceMesh(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := mesh.SetLogicalDeviceAssignment(tt.devices...) - require.NoErrorf(t, err, "failed test %q", tt.name) + if err != nil { + t.Errorf("SetLogicalDeviceAssignment() error = %v", err) + } }) } }) t.Run("SetDeviceAssignment_Errors", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } tests := []struct { name string @@ -264,111 +312,184 @@ func TestDeviceMesh(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := mesh.SetLogicalDeviceAssignment(tt.devices...) - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantErr) + if err == nil { + t.Error("SetLogicalDeviceAssignment() expected error, got nil") + } + if err != nil && !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("SetLogicalDeviceAssignment() error = %q, want substring %q", err.Error(), tt.wantErr) + } }) } }) t.Run("DeviceToMesh_2D", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 4}, []string{"x", "y"}) - require.NoError(t, err) - require.Equal(t, 8, mesh.NumDevices()) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } + if got := mesh.NumDevices(); got != 8 { + t.Errorf("NumDevices() = %d, want 8", got) + } }) t.Run("DeviceToMesh_3D", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2, 2}, []string{"x", "y", "z"}) - require.NoError(t, err) - require.Equal(t, 8, mesh.NumDevices()) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } + if got := mesh.NumDevices(); got != 8 { + t.Errorf("NumDevices() = %d, want 8", got) + } }) t.Run("DeviceToMesh_WithCustomMapping", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } err = mesh.SetLogicalDeviceAssignment(3, 2, 1, 0) - require.NoError(t, err) - require.Equal(t, 4, mesh.NumDevices()) + if err != nil { + t.Fatalf("SetLogicalDeviceAssignment() error = %v", err) + } + if got := mesh.NumDevices(); got != 4 { + t.Errorf("NumDevices() = %d, want 4", got) + } err = mesh.SetLogicalDeviceAssignment(4, 2, 1, 0) - require.Error(t, err) + if err == nil { + t.Error("SetLogicalDeviceAssignment() expected error for out of range device, got nil") + } }) t.Run("ComputeReplicaGroups", func(t *testing.T) { t.Run("2D mesh batch groups", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2}, []string{"batch", "data"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } // Example from comments: m.ComputeReplicaGroups([]string{"batch"}) -> [][]int{{0, 2}, {1, 3}} groups, err := mesh.ComputeReplicaGroups([]string{"batch"}) - require.NoError(t, err) - assert.Equal(t, [][]int{{0, 2}, {1, 3}}, groups) + if err != nil { + t.Fatalf("ComputeReplicaGroups() error = %v", err) + } + expected := [][]int{{0, 2}, {1, 3}} + if !reflect.DeepEqual(groups, expected) { + t.Errorf("ComputeReplicaGroups() = %v, want %v", groups, expected) + } }) t.Run("2D mesh data groups", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2}, []string{"batch", "data"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } // Example from comments: m.ComputeReplicaGroups([]string{"data"}) -> [][]int{{0, 1}, {2, 3}} groups, err := mesh.ComputeReplicaGroups([]string{"data"}) - require.NoError(t, err) - assert.Equal(t, [][]int{{0, 1}, {2, 3}}, groups) + if err != nil { + t.Fatalf("ComputeReplicaGroups() error = %v", err) + } + expected := [][]int{{0, 1}, {2, 3}} + if !reflect.DeepEqual(groups, expected) { + t.Errorf("ComputeReplicaGroups() = %v, want %v", groups, expected) + } }) t.Run("2D mesh global groups", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2}, []string{"batch", "data"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } // Example from comments: m.ComputeReplicaGroups([]string{"batch", "data"}) -> [][]int{{0, 1, 2, 3}} groups, err := mesh.ComputeReplicaGroups([]string{"batch", "data"}) - require.NoError(t, err) - assert.Equal(t, [][]int{{0, 1, 2, 3}}, groups) + if err != nil { + t.Fatalf("ComputeReplicaGroups() error = %v", err) + } + expected := [][]int{{0, 1, 2, 3}} + if !reflect.DeepEqual(groups, expected) { + t.Errorf("ComputeReplicaGroups() = %v, want %v", groups, expected) + } }) t.Run("1D mesh", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{4}, []string{"replica"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } groups, err := mesh.ComputeReplicaGroups([]string{"replica"}) - require.NoError(t, err) - assert.Equal(t, [][]int{{0, 1, 2, 3}}, groups) + if err != nil { + t.Fatalf("ComputeReplicaGroups() error = %v", err) + } + expected := [][]int{{0, 1, 2, 3}} + if !reflect.DeepEqual(groups, expected) { + t.Errorf("ComputeReplicaGroups() = %v, want %v", groups, expected) + } }) t.Run("3D mesh single axis", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2, 2}, []string{"x", "y", "z"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } // Groups along x axis: should split by y and z groups, err := mesh.ComputeReplicaGroups([]string{"x"}) - require.NoError(t, err) - assert.Equal(t, [][]int{{0, 4}, {1, 5}, {2, 6}, {3, 7}}, groups) + if err != nil { + t.Fatalf("ComputeReplicaGroups() error = %v", err) + } + expected := [][]int{{0, 4}, {1, 5}, {2, 6}, {3, 7}} + if !reflect.DeepEqual(groups, expected) { + t.Errorf("ComputeReplicaGroups() = %v, want %v", groups, expected) + } }) t.Run("3D mesh two axes", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2, 2}, []string{"x", "y", "z"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } // Groups along x and y axes: should split by z groups, err := mesh.ComputeReplicaGroups([]string{"x", "y"}) - require.NoError(t, err) - assert.Equal(t, [][]int{{0, 2, 4, 6}, {1, 3, 5, 7}}, groups) + if err != nil { + t.Fatalf("ComputeReplicaGroups() error = %v", err) + } + expected := [][]int{{0, 2, 4, 6}, {1, 3, 5, 7}} + if !reflect.DeepEqual(groups, expected) { + t.Errorf("ComputeReplicaGroups() = %v, want %v", groups, expected) + } }) t.Run("empty axes list", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2}, []string{"batch", "data"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } // Empty axes list: each device is its own group groups, err := mesh.ComputeReplicaGroups([]string{}) - require.NoError(t, err) - assert.Equal(t, [][]int{{0}, {1}, {2}, {3}}, groups) + if err != nil { + t.Fatalf("ComputeReplicaGroups() error = %v", err) + } + expected := [][]int{{0}, {1}, {2}, {3}} + if !reflect.DeepEqual(groups, expected) { + t.Errorf("ComputeReplicaGroups() = %v, want %v", groups, expected) + } }) t.Run("non-existent axis", func(t *testing.T) { mesh, err := shardy.NewDeviceMesh("mesh", []int{2, 2}, []string{"batch", "data"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } // A non-existent axis should return an error. _, err = mesh.ComputeReplicaGroups([]string{"nonexistent"}) - require.Error(t, err) + if err == nil { + t.Error("ComputeReplicaGroups() expected error, got nil") + } }) }) } diff --git a/types/shardy/shardingspec_test.go b/types/shardy/shardingspec_test.go index 5e97424..1029ae5 100644 --- a/types/shardy/shardingspec_test.go +++ b/types/shardy/shardingspec_test.go @@ -2,13 +2,13 @@ package shardy import ( "testing" - - "github.com/stretchr/testify/require" ) func TestShardSpec_ToStableHLO(t *testing.T) { mesh, err := NewDeviceMesh("test_mesh", []int{4, 2}, []string{"z", "a"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } testCases := []struct { name string spec *ShardingSpec @@ -48,14 +48,18 @@ func TestShardSpec_ToStableHLO(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - require.Equal(t, tc.expected, tc.spec.ToStableHLO()) + if got := tc.spec.ToStableHLO(); got != tc.expected { + t.Errorf("ToStableHLO() = %q, want %q", got, tc.expected) + } }) } } func TestShardSpec_Validate(t *testing.T) { mesh, err := NewDeviceMesh("test_mesh", []int{2, 8}, []string{"z", "a"}) - require.NoError(t, err) + if err != nil { + t.Fatalf("NewDeviceMesh() error = %v", err) + } testCases := []struct { name string spec *ShardingSpec @@ -107,9 +111,13 @@ func TestShardSpec_Validate(t *testing.T) { t.Run(tc.name, func(t *testing.T) { err := tc.spec.Validate() if tc.expectError { - require.Error(t, err) + if err == nil { + t.Error("Validate() expected error, got nil") + } } else { - require.NoError(t, err) + if err != nil { + t.Errorf("Validate() error = %v", err) + } } }) }