Skip to content

Commit 78628f6

Browse files
feat: delete pipelines function
1 parent 6474f10 commit 78628f6

File tree

5 files changed

+119
-1
lines changed

5 files changed

+119
-1
lines changed

hugot.go

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ type Session struct {
2222
}
2323

2424
func newSession(runtime string, opts ...options.WithOption) (*Session, error) {
25-
2625
parsedOptions := options.Defaults()
2726
parsedOptions.Runtime = runtime
2827
// Collect options into a struct, so they can be applied in the correct order later
@@ -173,6 +172,8 @@ func InitializePipeline[T pipelineBackends.Pipeline](p T, pipelineConfig pipelin
173172
default:
174173
return pipeline, name, fmt.Errorf("not implemented")
175174
}
175+
176+
model.Pipelines[name] = pipeline
176177
return pipeline, name, nil
177178
}
178179

@@ -209,6 +210,59 @@ func GetPipeline[T pipelineBackends.Pipeline](s *Session, name string) (T, error
209210
}
210211
}
211212

213+
func ClosePipeline[T pipelineBackends.Pipeline](s *Session, name string) error {
214+
var pipeline T
215+
switch any(pipeline).(type) {
216+
case *pipelines.TokenClassificationPipeline:
217+
p, ok := s.tokenClassificationPipelines[name]
218+
if ok {
219+
model := p.Model
220+
delete(s.tokenClassificationPipelines, name)
221+
delete(model.Pipelines, name)
222+
if len(model.Pipelines) == 0 {
223+
delete(s.models, model.Path)
224+
return model.Destroy()
225+
}
226+
}
227+
case *pipelines.TextClassificationPipeline:
228+
p, ok := s.textClassificationPipelines[name]
229+
if ok {
230+
model := p.Model
231+
delete(s.textClassificationPipelines, name)
232+
delete(model.Pipelines, name)
233+
if len(model.Pipelines) == 0 {
234+
delete(s.models, model.Path)
235+
return model.Destroy()
236+
}
237+
}
238+
case *pipelines.FeatureExtractionPipeline:
239+
p, ok := s.featureExtractionPipelines[name]
240+
if ok {
241+
model := p.Model
242+
delete(s.featureExtractionPipelines, name)
243+
delete(model.Pipelines, name)
244+
if len(model.Pipelines) == 0 {
245+
delete(s.models, model.Path)
246+
return model.Destroy()
247+
}
248+
}
249+
case *pipelines.ZeroShotClassificationPipeline:
250+
p, ok := s.zeroShotClassificationPipelines[name]
251+
if ok {
252+
model := p.Model
253+
delete(s.zeroShotClassificationPipelines, name)
254+
delete(model.Pipelines, name)
255+
if len(model.Pipelines) == 0 {
256+
delete(s.models, model.Path)
257+
return model.Destroy()
258+
}
259+
}
260+
default:
261+
return errors.New("pipeline type not supported")
262+
}
263+
return nil
264+
}
265+
212266
type pipelineNotFoundError struct {
213267
pipelineName string
214268
}

hugot_ort_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,17 @@ func TestNoSameNamePipelineORT(t *testing.T) {
241241
noSameNamePipeline(t, session)
242242
}
243243

244+
func TestClosePipelineORT(t *testing.T) {
245+
opts := []options.WithOption{options.WithOnnxLibraryPath(onnxRuntimeSharedLibrary)}
246+
session, err := NewORTSession(opts...)
247+
check(t, err)
248+
defer func(session *Session) {
249+
destroyErr := session.Destroy()
250+
check(t, destroyErr)
251+
}(session)
252+
destroyPipelines(t, session)
253+
}
254+
244255
// Thread safety
245256

246257
func TestThreadSafetyORT(t *testing.T) {

hugot_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,47 @@ func noSameNamePipeline(t *testing.T, session *Session) {
827827
assert.Error(t, err3)
828828
}
829829

830+
// destroy pipelines
831+
832+
func destroyPipelines(t *testing.T, session *Session) {
833+
t.Helper()
834+
835+
modelPath := "./models/KnightsAnalytics_distilbert-NER"
836+
configSimple := TokenClassificationConfig{
837+
ModelPath: modelPath,
838+
Name: "testClosePipeline",
839+
Options: []TokenClassificationOption{
840+
pipelines.WithSimpleAggregation(),
841+
pipelines.WithIgnoreLabels([]string{"O"}),
842+
},
843+
}
844+
_, err2 := NewPipeline(session, configSimple)
845+
if err2 != nil {
846+
t.FailNow()
847+
}
848+
849+
if len(session.models) != 1 {
850+
t.Fatal("Session should have 1 model")
851+
}
852+
853+
for _, model := range session.models {
854+
if _, ok := model.Pipelines["testClosePipeline"]; !ok {
855+
t.Fatal("Pipeline alias was not added to the model")
856+
}
857+
}
858+
859+
if err := ClosePipeline[*pipelines.TokenClassificationPipeline](session, "testClosePipeline"); err != nil {
860+
t.Fatal(err)
861+
}
862+
863+
if len(session.models) != 0 {
864+
t.Fatal("Session should have 0 models")
865+
}
866+
if len(session.tokenClassificationPipelines) != 0 {
867+
t.Fatal("Session should have 0 token classification pipelines")
868+
}
869+
}
870+
830871
// Thread safety
831872

832873
func threadSafety(t *testing.T, session *Session) {

hugot_xla_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,16 @@ func TestNoSameNamePipelineXLA(t *testing.T) {
209209
noSameNamePipeline(t, session)
210210
}
211211

212+
func TestDestroyPipelineXLA(t *testing.T) {
213+
session, err := NewXLASession()
214+
check(t, err)
215+
defer func(session *Session) {
216+
destroyErr := session.Destroy()
217+
check(t, destroyErr)
218+
}(session)
219+
destroyPipelines(t, session)
220+
}
221+
212222
// Thread safety
213223

214224
func TestThreadSafetyXLA(t *testing.T) {

pipelineBackends/model.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ type Model struct {
1616
InputsMeta []InputOutputInfo
1717
OutputsMeta []InputOutputInfo
1818
Destroy func() error
19+
Pipelines map[string]Pipeline
1920
}
2021

2122
func ReshapeOutput(input *[]float32, meta InputOutputInfo, paddingMask [][]bool, sequenceLength int) OutputArray {
@@ -88,6 +89,7 @@ func LoadModel(path string, onnxFilename string, options *options.Options) (*Mod
8889
model := &Model{
8990
Path: path,
9091
OnnxFilename: onnxFilename,
92+
Pipelines: make(map[string]Pipeline),
9193
}
9294

9395
err := LoadOnnxModelBytes(model)

0 commit comments

Comments
 (0)