diff --git a/go.mod b/go.mod index 2ccfa06..50befda 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ toolchain go1.24.6 require ( github.com/gomlx/gopjrt v0.10.0-rc0 - github.com/janpfeifer/must v0.2.0 github.com/pkg/errors v0.9.1 github.com/x448/float16 v0.8.4 k8s.io/klog/v2 v2.130.1 diff --git a/go.sum b/go.sum index 50009f0..de165ce 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,6 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/janpfeifer/go-benchmarks v0.1.1 h1:gLLy07/JrOKSnMWeUxSnjTdhkglgmrNR2IBDnR4kRqw= github.com/janpfeifer/go-benchmarks v0.1.1/go.mod h1:5AagXCOUzevvmYFQalcgoa4oWPyH1IkZNckolGWfiSM= -github.com/janpfeifer/must v0.2.0 h1:yWy1CE5gtk1i2ICBvqAcMMXrCMqil9CJPkc7x81fRdQ= -github.com/janpfeifer/must v0.2.0/go.mod h1:S6c5Yg/YSMR43cJw4zhIq7HFMci90a7kPY9XA4c8UIs= github.com/pascaldekloe/name v1.0.0 h1:n7LKFgHixETzxpRv2R77YgPUFo85QHGZKrdaYm7eY5U= github.com/pascaldekloe/name v1.0.0/go.mod h1:Z//MfYJnH4jVpQ9wkclwu2I2MkHmXTlT9wR5UZScttM= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/internal/cmd/ops_generator/binary_ops.go b/internal/cmd/ops_generator/binary_ops.go index c80cbf1..7cd46ea 100644 --- a/internal/cmd/ops_generator/binary_ops.go +++ b/internal/cmd/ops_generator/binary_ops.go @@ -9,7 +9,6 @@ import ( "github.com/gomlx/stablehlo/internal/utils" "github.com/gomlx/stablehlo/shapeinference" - "github.com/janpfeifer/must" ) const ( @@ -51,11 +50,11 @@ func GenerateBinaryOps() { } fileName := binaryOpsFile - f := must.M1(os.Create(fileName)) - must.M(binaryOpsTemplate.Execute(f, data)) - must.M(f.Close()) + f := must1(os.Create(fileName)) + must(binaryOpsTemplate.Execute(f, data)) + must(f.Close()) cmd := exec.Command("gofmt", "-w", fileName) - must.M(cmd.Run()) - fmt.Printf("✅ Successfully generated %s\n", path.Join(must.M1(os.Getwd()), fileName)) + must(cmd.Run()) + fmt.Printf("✅ Successfully generated %s\n", path.Join(must1(os.Getwd()), fileName)) } diff --git a/internal/cmd/ops_generator/main.go b/internal/cmd/ops_generator/main.go index 32c025c..2af8349 100644 --- a/internal/cmd/ops_generator/main.go +++ b/internal/cmd/ops_generator/main.go @@ -1,6 +1,19 @@ package main +import "log" + func main() { GenerateBinaryOps() GenerateUnaryOps() } + +func must(err error) { + if err != nil { + log.Fatalf("Failed: %+v", err) + } +} + +func must1[T any](value T, err error) T { + must(err) + return value +} diff --git a/internal/cmd/ops_generator/unary_ops.go b/internal/cmd/ops_generator/unary_ops.go index 161f2a6..b3e56d3 100644 --- a/internal/cmd/ops_generator/unary_ops.go +++ b/internal/cmd/ops_generator/unary_ops.go @@ -9,7 +9,6 @@ import ( "github.com/gomlx/stablehlo/internal/utils" "github.com/gomlx/stablehlo/shapeinference" - "github.com/janpfeifer/must" ) const ( @@ -52,11 +51,11 @@ func GenerateUnaryOps() { } fileName := unaryOpsFile - f := must.M1(os.Create(fileName)) - must.M(unaryOpsTemplate.Execute(f, data)) - must.M(f.Close()) + f := must1(os.Create(fileName)) + must(unaryOpsTemplate.Execute(f, data)) + must(f.Close()) cmd := exec.Command("gofmt", "-w", fileName) - must.M(cmd.Run()) - fmt.Printf("✅ Successfully generated %s\n", path.Join(must.M1(os.Getwd()), fileName)) + must(cmd.Run()) + fmt.Printf("✅ Successfully generated %s\n", path.Join(must1(os.Getwd()), fileName)) }