Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,18 @@ require (
github.com/gomlx/gopjrt v0.10.0-rc0
github.com/janpfeifer/must v0.2.0
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.11.1
github.com/x448/float16 v0.8.4
k8s.io/klog/v2 v2.130.1
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dmarkham/enumer v1.6.1 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/pascaldekloe/name v1.0.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/mod v0.27.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/tools v0.36.0 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

tool github.com/dmarkham/enumer
4 changes: 0 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ github.com/dmarkham/enumer v1.6.1 h1:aSc9awYtZL07TUueWs40QcHtxTvHTAwG0EqrNsK45w4
github.com/dmarkham/enumer v1.6.1/go.mod h1:yixql+kDDQRYqcuBM2n9Vlt7NoT9ixgXhaXry8vmRg8=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/gomlx/gopjrt v0.9.2-0.20251113071311-1488e6396f1b h1:2jh5hJmb6YzGWL+rKUDYUpMdcmeeA+jDkKpfmevxl8g=
github.com/gomlx/gopjrt v0.9.2-0.20251113071311-1488e6396f1b/go.mod h1:c8UENVGnxIDdihEL5HinlAdgR7RxTbEPLBppiMQF1ew=
github.com/gomlx/gopjrt v0.10.0-rc0 h1:EpF+JJYl3AUvU5ToKSfsuFnSPBxkPjbor93Ziak7OGA=
github.com/gomlx/gopjrt v0.10.0-rc0/go.mod h1:c8UENVGnxIDdihEL5HinlAdgR7RxTbEPLBppiMQF1ew=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
Expand Down Expand Up @@ -34,8 +32,6 @@ golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg=
golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk=
Expand Down
70 changes: 51 additions & 19 deletions internal/utils/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,70 @@ package utils

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestSet(t *testing.T) {
// Sets are created empty.
s := MakeSet[int](10)
assert.Len(t, s, 0)
if len(s) != 0 {
t.Errorf("expected len 0, got %d", len(s))
}

// Check inserting and recovery.
s.Insert(3, 7)
assert.Len(t, s, 2)
assert.True(t, s.Has(3))
assert.True(t, s.Has(7))
assert.False(t, s.Has(5))
if len(s) != 2 {
t.Errorf("expected len 2, got %d", len(s))
}
if !s.Has(3) {
t.Errorf("expected s.Has(3) to be true")
}
if !s.Has(7) {
t.Errorf("expected s.Has(7) to be true")
}
if s.Has(5) {
t.Errorf("expected s.Has(5) to be false")
}

s2 := SetWith(5, 7)
assert.Len(t, s2, 2)
assert.True(t, s2.Has(5))
assert.True(t, s2.Has(7))
assert.False(t, s2.Has(3))
if len(s2) != 2 {
t.Errorf("expected len 2, got %d", len(s2))
}
if !s2.Has(5) {
t.Errorf("expected s2.Has(5) to be true")
}
if !s2.Has(7) {
t.Errorf("expected s2.Has(7) to be true")
}
if s2.Has(3) {
t.Errorf("expected s2.Has(3) to be false")
}

s3 := s.Sub(s2)
assert.Len(t, s3, 1)
assert.True(t, s3.Has(3))
if len(s3) != 1 {
t.Errorf("expected len 1, got %d", len(s3))
}
if !s3.Has(3) {
t.Errorf("expected s3.Has(3) to be true")
}

delete(s, 7)
assert.Len(t, s, 1)
assert.True(t, s.Has(3))
assert.False(t, s.Has(7))
assert.True(t, s.Equal(s3))
assert.False(t, s.Equal(s2))
if len(s) != 1 {
t.Errorf("expected len 1, got %d", len(s))
}
if !s.Has(3) {
t.Errorf("expected s.Has(3) to be true")
}
if s.Has(7) {
t.Errorf("expected s.Has(7) to be false")
}
if !s.Equal(s3) {
t.Errorf("expected s.Equal(s3) to be true")
}
if s.Equal(s2) {
t.Errorf("expected s.Equal(s2) to be false")
}
s4 := SetWith(-3)
assert.False(t, s.Equal(s4))
if s.Equal(s4) {
t.Errorf("expected s.Equal(s4) to be false")
}
}
18 changes: 13 additions & 5 deletions shapeinference/convolve_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package shapeinference

import (
"strings"
"testing"

"github.com/gomlx/stablehlo/types/shapes"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestConvolve(t *testing.T) {
Expand Down Expand Up @@ -238,11 +237,20 @@ func TestConvolve(t *testing.T) {
tc.outputBatch, tc.outputChannels, tc.outputSpatial,
tc.channelGroupCount, tc.batchGroupCount)
if tc.expectedError != "" {
require.ErrorContains(t, err, tc.expectedError)
if err == nil {
t.Fatalf("expected error containing %q, got nil", tc.expectedError)
}
if !strings.Contains(err.Error(), tc.expectedError) {
t.Fatalf("expected error containing %q, got %q", tc.expectedError, err.Error())
}
return
}
require.NoError(t, err)
assert.Equal(t, tc.output, output)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if !tc.output.Equal(output) {
t.Errorf("expected output %v, got %v", tc.output, output)
}
})
}
}
Loading