Skip to content

Commit d667f28

Browse files
refactor: training fixes and improvements
1 parent a7cfae9 commit d667f28

19 files changed

+394
-110
lines changed

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#--- dockerfile to test hugot ---
22

3-
ARG GO_VERSION=1.23.5
3+
ARG GO_VERSION=1.24.0
44
ARG ONNXRUNTIME_VERSION=1.20.1
55
ARG BUILD_PLATFORM=linux/amd64
66

cmd/main.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ var runCommand = &cli.Command{
118118
} else {
119119
homeDir, err := os.UserHomeDir()
120120
if err != nil {
121-
if exists, err := util.FileSystem.Exists(ctx.Context, path.Join(homeDir, "lib", "hugot", "onnxruntime.so")); err != nil && exists {
121+
if exists, err := util.FileExists(path.Join(homeDir, "lib", "hugot", "onnxruntime.so")); err != nil && exists {
122122
opts = append(opts, options.WithOnnxLibraryPath(path.Join(homeDir, "lib", "hugot", "onnxruntime.so")))
123123
}
124124
}
@@ -137,14 +137,14 @@ var runCommand = &cli.Command{
137137
}()
138138

139139
// is the model a full path to a model
140-
ok, err := util.FileSystem.Exists(ctx.Context, modelPath)
140+
ok, err := util.FileExists(modelPath)
141141
if err != nil {
142142
return err
143143
}
144144
if !ok {
145145
// is the model the name of a model previously downloaded
146146
downloadedModelName := strings.Replace(modelPath, "/", "_", -1)
147-
ok, err = util.FileSystem.Exists(ctx.Context, util.PathJoinSafe(modelsDir, downloadedModelName))
147+
ok, err = util.FileExists(util.PathJoinSafe(modelsDir, downloadedModelName))
148148
if err != nil {
149149
return err
150150
}
@@ -155,7 +155,7 @@ var runCommand = &cli.Command{
155155
if strings.Contains(modelPath, ":") {
156156
return fmt.Errorf("filters with : are currently not supported")
157157
}
158-
err = util.FileSystem.Create(context.Background(), modelsDir, os.ModePerm, true)
158+
err = util.CreateFile(modelsDir, true)
159159
if err != nil {
160160
return err
161161
}
@@ -229,7 +229,7 @@ var runCommand = &cli.Command{
229229

230230
if outputPath != "" {
231231
dest := util.PathJoinSafe(outputPath, fmt.Sprintf("result-%d.jsonl", i))
232-
writer, err = util.FileSystem.NewWriter(ctx.Context, dest, os.ModePerm)
232+
writer, err = util.NewFileWriter(dest, "application/json")
233233
if err != nil {
234234
return err
235235
}
@@ -258,7 +258,7 @@ var runCommand = &cli.Command{
258258

259259
// read inputs
260260

261-
exists, err := util.FileSystem.Exists(ctx.Context, inputPath)
261+
exists, err := util.FileExists(inputPath)
262262
if err != nil {
263263
return err
264264
}
@@ -276,7 +276,7 @@ var runCommand = &cli.Command{
276276
return true, nil
277277
}
278278

279-
err := util.FileSystem.Walk(ctx.Context, inputPath, fileWalker)
279+
err := util.WalkDir()(ctx.Context, inputPath, fileWalker)
280280
if err != nil {
281281
return err
282282
}

cmd/main_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package main
22

33
import (
4-
"context"
54
_ "embed"
65
"fmt"
76
"os"
@@ -145,7 +144,7 @@ func TestModelChain(t *testing.T) {
145144
// wipe the hugo folder
146145
userFolder, err := os.UserHomeDir()
147146
check(t, err)
148-
check(t, util.FileSystem.Delete(context.Background(), util.PathJoinSafe(userFolder, "hugot")))
147+
check(t, util.DeleteFile(util.PathJoinSafe(userFolder, "hugot")))
149148

150149
// try to download the model to hugo folder and run it
151150
args := append(baseArgs, "run",

cuda.Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#--- dockerfile to test hugot ---
22

3-
ARG GO_VERSION=1.23.5
3+
ARG GO_VERSION=1.24.0
44
ARG ONNXRUNTIME_VERSION=1.20.1
55
ARG BUILD_PLATFORM=linux/amd64
66

datasets/dataset_xla.go

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
"github.com/gomlx/gomlx/ml/train"
1414
"github.com/gomlx/gomlx/types/tensors"
15+
1516
"github.com/knights-analytics/hugot/pipelineBackends"
1617
"github.com/knights-analytics/hugot/pipelines"
1718
"github.com/knights-analytics/hugot/util"
@@ -63,9 +64,9 @@ func (s *SemanticSimilarityDataset) Validate() error {
6364
}
6465

6566
type SemanticSimilarityExample struct {
66-
Sentence1 string `json:"sentence1"`
67-
Sentence2 string `json:"sentence2"`
68-
Score float32 `json:"label"`
67+
Sentence1 *string `json:"sentence1"`
68+
Sentence2 *string `json:"sentence2"`
69+
Score *float32 `json:"score"`
6970
}
7071

7172
// NewSemanticSimilarityDataset creates a new SemanticSimilarityDataset.
@@ -95,7 +96,18 @@ func (s *SemanticSimilarityDataset) Reset() {
9596
fmt.Printf("completed epoch in %d batches of %d examples, resetting dataset\n", s.batchN, s.BatchSize)
9697
}
9798
s.batchN = 0
98-
s.reader = bufio.NewReader(s.sourceFile)
99+
if err := s.sourceFile.Close(); err != nil {
100+
panic(err)
101+
}
102+
103+
sourceReadCloser, err := util.OpenFile(s.TrainingPath) // TODO how to handle errors here
104+
if err != nil {
105+
panic(err)
106+
}
107+
s.sourceFile = sourceReadCloser
108+
109+
// restart the reader
110+
s.reader = bufio.NewReader(sourceReadCloser)
99111
}
100112

101113
func (s *SemanticSimilarityDataset) Yield() (spec any, inputs []*tensors.Tensor, labels []*tensors.Tensor, err error) {
@@ -123,9 +135,12 @@ func (s *SemanticSimilarityDataset) Yield() (spec any, inputs []*tensors.Tensor,
123135
if e := json.Unmarshal(lineBytes, &lineData); e != nil {
124136
return nil, nil, nil, fmt.Errorf("failed to parse JSON line: %w", e)
125137
}
126-
inputsLeft = append(inputsLeft, lineData.Sentence1)
127-
inputsRight = append(inputsRight, lineData.Sentence2)
128-
scores = append(scores, lineData.Score)
138+
if lineData.Sentence1 == nil || lineData.Sentence2 == nil || lineData.Score == nil {
139+
return nil, nil, nil, fmt.Errorf("missing required fields in JSON line")
140+
}
141+
inputsLeft = append(inputsLeft, *lineData.Sentence1)
142+
inputsRight = append(inputsRight, *lineData.Sentence2)
143+
scores = append(scores, *lineData.Score)
129144
batchCounter++
130145
}
131146

downloader.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616
hfd "github.com/bodaay/HuggingFaceModelDownloader/hfdownloader"
1717
)
1818

19-
// DownloadOptions is a struct of options that can be passed to DownloadModel
19+
// DownloadOptions is a struct of options that can be passed to DownloadModel.
2020
type DownloadOptions struct {
2121
AuthToken string
2222
SkipSha bool

go.mod

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@ require (
77
github.com/daulet/tokenizers v1.20.2
88
github.com/gomlx/exceptions v0.0.3
99
github.com/gomlx/gomlx v0.17.0
10+
github.com/gomlx/gopjrt v0.6.0
1011
github.com/gomlx/onnx-gomlx v0.2.0
1112
github.com/json-iterator/go v1.1.12
1213
github.com/mattn/go-isatty v0.0.20
1314
github.com/stretchr/testify v1.10.0
1415
github.com/urfave/cli/v2 v2.27.5
1516
github.com/viant/afs v1.25.1
16-
github.com/yalue/onnxruntime_go v1.16.0
17-
golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c
17+
github.com/yalue/onnxruntime_go v1.17.0
18+
golang.org/x/exp v0.0.0-20250215185904-eff6e970281f
1819
)
1920

2021
require (
@@ -24,7 +25,6 @@ require (
2425
github.com/fatih/color v1.18.0 // indirect
2526
github.com/go-errors/errors v1.5.1 // indirect
2627
github.com/go-logr/logr v1.4.2 // indirect
27-
github.com/gomlx/gopjrt v0.6.0 // indirect
2828
github.com/google/uuid v1.6.0 // indirect
2929
github.com/mattn/go-colorable v0.1.14 // indirect
3030
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
@@ -35,9 +35,9 @@ require (
3535
github.com/russross/blackfriday/v2 v2.1.0 // indirect
3636
github.com/x448/float16 v0.8.4 // indirect
3737
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect
38-
golang.org/x/crypto v0.32.0 // indirect
39-
golang.org/x/sys v0.29.0 // indirect
40-
google.golang.org/protobuf v1.36.4 // indirect
38+
golang.org/x/crypto v0.33.0 // indirect
39+
golang.org/x/sys v0.30.0 // indirect
40+
google.golang.org/protobuf v1.36.5 // indirect
4141
gopkg.in/yaml.v3 v3.0.1 // indirect
4242
k8s.io/klog/v2 v2.130.1 // indirect
4343
)

go.sum

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,19 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
6767
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
6868
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4=
6969
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
70-
github.com/yalue/onnxruntime_go v1.16.0 h1:YyHfuGsEy5AODMbXGePCGfIZ7DgeGW40gOu5TPDE2t4=
71-
github.com/yalue/onnxruntime_go v1.16.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
72-
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
73-
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
74-
golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c h1:KL/ZBHXgKGVmuZBZ01Lt57yE5ws8ZPSkkihmEyq7FXc=
75-
golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU=
70+
github.com/yalue/onnxruntime_go v1.17.0 h1:nC8AFbmaq9E2gxtxutGPzK/LGCrtnnu7LTGl82YuQzw=
71+
github.com/yalue/onnxruntime_go v1.17.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
72+
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
73+
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
74+
golang.org/x/exp v0.0.0-20250215185904-eff6e970281f h1:oFMYAjX0867ZD2jcNiLBrI9BdpmEkvPyi5YrBGXbamg=
75+
golang.org/x/exp v0.0.0-20250215185904-eff6e970281f/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
7676
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
77-
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
78-
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
79-
golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg=
80-
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
81-
google.golang.org/protobuf v1.36.4 h1:6A3ZDJHn/eNqc1i+IdefRzy/9PokBTPvcqMySR7NNIM=
82-
google.golang.org/protobuf v1.36.4/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
77+
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
78+
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
79+
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
80+
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
81+
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
82+
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
8383
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
8484
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
8585
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

hugot.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,15 @@ func (s *Session) Destroy() error {
240240
for _, model := range s.models {
241241
err = errors.Join(err, model.Destroy())
242242
}
243+
s.models = nil
244+
s.featureExtractionPipelines = nil
245+
s.tokenClassificationPipelines = nil
246+
s.textClassificationPipelines = nil
247+
s.zeroShotClassificationPipelines = nil
243248
err = errors.Join(
244249
s.options.Destroy(),
245250
s.environmentDestroy(),
246251
)
252+
s.options = nil
247253
return err
248254
}

hugot_ort.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
package hugot
44

55
import (
6-
"context"
76
"errors"
87
"fmt"
98

@@ -45,7 +44,7 @@ func (s *Session) initialiseORT() (bool, error) {
4544
o := s.options.ORTOptions
4645
// Set pre-initialisation options
4746
if o.LibraryPath != nil {
48-
ortPathExists, err := util.FileSystem.Exists(context.Background(), *o.LibraryPath)
47+
ortPathExists, err := util.FileExists(*o.LibraryPath)
4948
if err != nil {
5049
return false, err
5150
}

0 commit comments

Comments
 (0)