
Commit b98e16c

Merge pull request #1036 from NexaAI/refactor/mengsheng/npu-systemp-prompt
refactor: remove deprecated system prompt patch for npu
2 parents 9ee6a09 + de95faf commit b98e16c

8 files changed: +25 −85 lines changed


runner/cmd/nexa-cli/infer.go

Lines changed: 4 additions & 6 deletions
@@ -423,9 +423,8 @@ func inferLLM(manifest *types.ModelManifest, quant string) error {
 		PluginID: manifest.PluginId,
 		DeviceID: manifest.DeviceId,
 		Config: nexa_sdk.ModelConfig{
-			NCtx:         nctx,
-			NGpuLayers:   ngl,
-			SystemPrompt: systemPrompt, // TODO: align npu
+			NCtx:       nctx,
+			NGpuLayers: ngl,
 		},
 	})
 	spin.Stop()
@@ -595,9 +594,8 @@ func inferVLM(manifest *types.ModelManifest, quant string) error {
 		PluginID: manifest.PluginId,
 		DeviceID: manifest.DeviceId,
 		Config: nexa_sdk.ModelConfig{
-			NCtx:         nctx,
-			NGpuLayers:   ngl,
-			SystemPrompt: systemPrompt,
+			NCtx:       nctx,
+			NGpuLayers: ngl,
 		},
 	})
 	spin.Stop()
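After this commit, both inferLLM and inferVLM pass only the context size and GPU layer count at create time; the system prompt no longer rides along in ModelConfig and instead reaches the model as a regular chat message. A minimal sketch of the resulting config construction, assuming the nctx and ngl variables visible in the hunk (everything else is illustrative, not copied from infer.go):

// Sketch only: create-time config after the removal.
cfg := nexa_sdk.ModelConfig{
	NCtx:       nctx, // requested context window
	NGpuLayers: ngl,  // layers to offload to the GPU
}
// The former SystemPrompt field is gone; system instructions are now sent as a
// normal "system" chat message instead of a create-time option.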

runner/internal/types/model.go

Lines changed: 0 additions & 3 deletions
@@ -86,9 +86,6 @@ func (m ModelManifest) GetSize() int64 {
 type ModelParam struct {
 	NCtx       int32
 	NGpuLayers int32
-
-	// npu only
-	SystemPrompt string
 }
 
 type DownloadInfo struct {

runner/nexa-sdk/common.go

Lines changed: 0 additions & 1 deletion
@@ -277,5 +277,4 @@ type ModelConfig struct {
 	NGpuLayers          int32
 	ChatTemplatePath    string
 	ChatTemplateContent string
-	SystemPrompt        string
 }
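For reference, the tail of ModelConfig in runner/nexa-sdk/common.go now reads as sketched below. The hunk only shows the last few fields, so anything declared earlier in the struct is omitted here; NCtx is included because call sites in this same commit set Config.NCtx, but its exact position in the struct is an assumption.

// Sketch of the trimmed struct; fields not visible in this diff are omitted.
type ModelConfig struct {
	NCtx                int32
	NGpuLayers          int32
	ChatTemplatePath    string
	ChatTemplateContent string
}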

runner/nexa-sdk/llm.go

Lines changed: 0 additions & 7 deletions
@@ -82,10 +82,6 @@ func (lci LlmCreateInput) toCPtr() *C.ml_LlmCreateInput {
 	if lci.Config.ChatTemplateContent != "" {
 		cPtr.config.chat_template_content = C.CString(lci.Config.ChatTemplateContent)
 	}
-	// Add system prompt support
-	if lci.Config.SystemPrompt != "" {
-		cPtr.config.system_prompt = C.CString(lci.Config.SystemPrompt)
-	}
 
 	return cPtr
 }
@@ -114,9 +110,6 @@ func freeLlmCreateInput(cPtr *C.ml_LlmCreateInput) {
 	if cPtr.config.chat_template_content != nil {
 		C.free(unsafe.Pointer(cPtr.config.chat_template_content))
 	}
-	if cPtr.config.system_prompt != nil {
-		C.free(unsafe.Pointer(cPtr.config.system_prompt))
-	}
 
 	// Free the main structure
 	C.free(unsafe.Pointer(cPtr))
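The deleted lines followed the usual cgo ownership rule: every string copied into C memory with C.CString must be released exactly once with C.free, which is why system_prompt appeared in both toCPtr and freeLlmCreateInput. A self-contained sketch of that pairing (the helper below is illustrative and not part of this repository):

package main

// #include <stdlib.h>
import "C"

import "unsafe"

// withCString copies s into C-allocated memory, hands it to use, and frees it
// afterwards. C.CString allocations are invisible to the Go garbage collector,
// so skipping the matching C.free would leak memory.
func withCString(s string, use func(cs *C.char)) {
	cs := C.CString(s)
	defer C.free(unsafe.Pointer(cs))
	use(cs)
}

func main() {
	withCString("hello from Go", func(cs *C.char) {
		_ = cs // hand cs to a C API here
	})
}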

runner/nexa-sdk/vlm.go

Lines changed: 0 additions & 7 deletions
@@ -76,10 +76,6 @@ func (vci VlmCreateInput) toCPtr() *C.ml_VlmCreateInput {
 	if vci.Config.ChatTemplateContent != "" {
 		cPtr.config.chat_template_content = C.CString(vci.Config.ChatTemplateContent)
 	}
-	// Add system prompt support
-	if vci.Config.SystemPrompt != "" {
-		cPtr.config.system_prompt = C.CString(vci.Config.SystemPrompt)
-	}
 
 	return cPtr
 }
@@ -110,9 +106,6 @@ func freeVlmCreateInput(cPtr *C.ml_VlmCreateInput) {
 	if cPtr.config.chat_template_content != nil {
 		C.free(unsafe.Pointer(cPtr.config.chat_template_content))
 	}
-	if cPtr.config.system_prompt != nil {
-		C.free(unsafe.Pointer(cPtr.config.system_prompt))
-	}
 
 	// Free the main structure
 	C.free(unsafe.Pointer(cPtr))

runner/server/docs/swagger.yaml

Lines changed: 2 additions & 2 deletions
@@ -881,7 +881,7 @@ components:
       stream_format:
         type: string
         enum: [wav, sse]
-        description: "sse" returns Server-Sent Events stream; otherwise returns binary WAV
+        description: '"sse" returns Server-Sent Events stream; otherwise returns binary WAV'
       speed:
         type: number
         format: float
@@ -905,7 +905,7 @@ components:
         description: The audio file to transcribe. Omit for warm up (returns null).
       stream:
         type: string
-        description: "true" is not supported and returns 400
+        description: '"true" is not supported and returns 400'
       language:
         type: string
         description: The language of the input audio

runner/server/handler/chat.go

Lines changed: 15 additions & 53 deletions
@@ -80,6 +80,17 @@ func defaultChatCompletionRequest() ChatCompletionRequest {
 	}
 }
 
+func isWarmupRequest(param ChatCompletionRequest) bool {
+	if len(param.Messages) == 0 {
+		return true
+	}
+	if len(param.Messages) != 1 {
+		return false
+	}
+	r := param.Messages[0].GetRole()
+	return r != nil && *r == "system"
+}
+
 func ChatCompletions(c *gin.Context) {
 	param := defaultChatCompletionRequest()
 	if err := c.ShouldBindJSON(&param); err != nil {
@@ -117,11 +128,8 @@ func ChatCompletions(c *gin.Context) {
 }
 
 func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
-	// Build message list for LLM template
-	var systemPrompt string
 	messages := make([]nexa_sdk.LlmChatMessage, 0, len(param.Messages))
 	for _, msg := range param.Messages {
-		// tool call message
 		if toolCalls := msg.GetToolCalls(); len(toolCalls) > 0 {
 			for _, tc := range toolCalls {
 				messages = append(messages, nexa_sdk.LlmChatMessage{
@@ -133,7 +141,6 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
 			continue
 		}
 
-		// tool call response message
 		if toolResp := msg.GetToolCallID(); toolResp != nil {
 			messages = append(messages, nexa_sdk.LlmChatMessage{
 				Role: nexa_sdk.LLMRole(*msg.GetRole()),
@@ -144,21 +151,13 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
 
 		switch content := msg.GetContent().AsAny().(type) {
 		case *string:
-			// NOTE: patch for npu
-			if *msg.GetRole() == "system" {
-				systemPrompt += *content
-			}
 			messages = append(messages, nexa_sdk.LlmChatMessage{
 				Role:    nexa_sdk.LLMRole(*msg.GetRole()),
 				Content: *content,
 			})
 
 		case *[]openai.ChatCompletionContentPartTextParam:
 			for _, ct := range *content {
-				// NOTE: patch for npu
-				if *msg.GetRole() == "system" {
-					systemPrompt += ct.Text
-				}
 				messages = append(messages, nexa_sdk.LlmChatMessage{
 					Role:    nexa_sdk.LLMRole(*msg.GetRole()),
 					Content: ct.Text,
@@ -168,10 +167,6 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
 			for _, ct := range *content {
 				switch *ct.GetType() {
 				case "text":
-					// NOTE: patch for npu
-					if *msg.GetRole() == "system" {
-						systemPrompt += *ct.GetText()
-					}
 					messages = append(messages, nexa_sdk.LlmChatMessage{
 						Role:    nexa_sdk.LLMRole(*msg.GetRole()),
 						Content: *ct.GetText(),
@@ -186,10 +181,6 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
 			for _, ct := range *content {
 				switch *ct.GetType() {
 				case "text":
-					// NOTE: patch for npu
-					if *msg.GetRole() == "system" {
-						systemPrompt += *ct.GetText()
-					}
 					messages = append(messages, nexa_sdk.LlmChatMessage{
 						Role:    nexa_sdk.LLMRole(*msg.GetRole()),
 						Content: *ct.GetText(),
@@ -218,10 +209,9 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
 
 	samplerConfig := parseSamplerConfig(param)
 
-	// Get LLM instance
 	p, err := service.KeepAliveGet[nexa_sdk.LLM](
 		string(param.Model),
-		types.ModelParam{NCtx: param.NCtx, NGpuLayers: param.Ngl, SystemPrompt: systemPrompt},
+		types.ModelParam{NCtx: param.NCtx, NGpuLayers: param.Ngl},
 		c.GetHeader("Nexa-KeepCache") != "true",
 	)
 	if errors.Is(err, os.ErrNotExist) {
@@ -231,13 +221,11 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
 		c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error(), "code": nexa_sdk.SDKErrorCode(err)})
 		return
 	}
-	// Empty request for warm up
-	if len(param.Messages) == 0 || (systemPrompt != "" && len(param.Messages) <= 1) {
+	if isWarmupRequest(param) {
 		c.JSON(http.StatusOK, nil)
 		return
 	}
 
-	// Format prompt using chat template
 	formatted, err := p.ApplyChatTemplate(nexa_sdk.LlmApplyChatTemplateInput{
 		Messages: messages,
 		Tools:    tools,
@@ -421,11 +409,8 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
 }
 
 func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
-	// Build message list for VLM template
-	var systemPrompt string
 	messages := make([]nexa_sdk.VlmChatMessage, 0, len(param.Messages))
 	for _, msg := range param.Messages {
-		// tool call message
 		if toolCalls := msg.GetToolCalls(); len(toolCalls) > 0 {
 			contents := make([]nexa_sdk.VlmContent, 0, len(toolCalls))
 			for _, tc := range toolCalls {
@@ -442,7 +427,6 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
 			continue
 		}
 
-		// tool call response message
 		if toolResp := msg.GetToolCallID(); toolResp != nil {
 			messages = append(messages, nexa_sdk.VlmChatMessage{
 				Role: nexa_sdk.VlmRole(*msg.GetRole()),
@@ -456,9 +440,6 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
 
 		switch content := msg.GetContent().AsAny().(type) {
 		case *string:
-			if *msg.GetRole() == "system" {
-				systemPrompt += *content
-			}
 			messages = append(messages, nexa_sdk.VlmChatMessage{
 				Role: nexa_sdk.VlmRole(*msg.GetRole()),
 				Contents: []nexa_sdk.VlmContent{
@@ -468,32 +449,22 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
 
 		case *[]openai.ChatCompletionContentPartTextParam:
 			contents := make([]nexa_sdk.VlmContent, 0, len(*content))
-
 			for _, ct := range *content {
-				if *msg.GetRole() == "system" {
-					systemPrompt += ct.Text
-				}
 				contents = append(contents, nexa_sdk.VlmContent{
 					Type: nexa_sdk.VlmContentTypeText,
 					Text: ct.Text,
 				})
 			}
-
 			messages = append(messages, nexa_sdk.VlmChatMessage{
 				Role:     nexa_sdk.VlmRole(*msg.GetRole()),
 				Contents: contents,
 			})
 
 		case *[]openai.ChatCompletionContentPartUnionParam:
 			contents := make([]nexa_sdk.VlmContent, 0, len(*content))
-
 			for _, ct := range *content {
 				switch *ct.GetType() {
 				case "text":
-					// NOTE: patch for npu
-					if *msg.GetRole() == "system" {
-						systemPrompt += *ct.GetText()
-					}
 					contents = append(contents, nexa_sdk.VlmContent{
 						Type: nexa_sdk.VlmContentTypeText,
 						Text: *ct.GetText(),
@@ -528,22 +499,16 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
 					return
 				}
 			}
-
 			messages = append(messages, nexa_sdk.VlmChatMessage{
 				Role:     nexa_sdk.VlmRole(*msg.GetRole()),
 				Contents: contents,
 			})
 
 		case *[]openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion:
 			contents := make([]nexa_sdk.VlmContent, 0, len(*content))
-
 			for _, ct := range *content {
 				switch *ct.GetType() {
 				case "text":
-					// NOTE: patch for npu
-					if *msg.GetRole() == "system" {
-						systemPrompt += *ct.GetText()
-					}
 					contents = append(contents, nexa_sdk.VlmContent{
 						Type: nexa_sdk.VlmContentTypeText,
 						Text: *ct.GetText(),
@@ -577,10 +542,9 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
 
 	samplerConfig := parseSamplerConfig(param)
 
-	// Get VLM instance
 	p, err := service.KeepAliveGet[nexa_sdk.VLM](
 		string(param.Model),
-		types.ModelParam{NCtx: param.NCtx, NGpuLayers: param.Ngl, SystemPrompt: systemPrompt},
+		types.ModelParam{NCtx: param.NCtx, NGpuLayers: param.Ngl},
 		c.GetHeader("Nexa-KeepCache") != "true",
 	)
 	if errors.Is(err, os.ErrNotExist) {
@@ -590,9 +554,7 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
 		c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error(), "code": nexa_sdk.SDKErrorCode(err)})
 		return
 	}
-
-	// Empty request for warm up, just reset model state
-	if len(param.Messages) == 0 || (systemPrompt != "" && len(param.Messages) <= 1) {
+	if isWarmupRequest(param) {
 		c.JSON(http.StatusOK, nil)
 		return
 	}
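The new isWarmupRequest helper centralizes what both handlers previously derived from the accumulated systemPrompt string: a request is a warm-up if it carries no messages, or exactly one message whose role is "system". A hedged, self-contained restatement of that rule with a simplified message type (the real handler works on the OpenAI union types shown in the diff above):

package main

import "fmt"

// msg is a stand-in for the OpenAI message union used by the real handler.
type msg struct{ role string }

// isWarmup restates the rule from isWarmupRequest: an empty request, or a
// single system-only message, means "load the model but generate nothing".
func isWarmup(messages []msg) bool {
	if len(messages) == 0 {
		return true
	}
	if len(messages) != 1 {
		return false
	}
	return messages[0].role == "system"
}

func main() {
	fmt.Println(isWarmup(nil))                                     // true: empty body
	fmt.Println(isWarmup([]msg{{role: "system"}}))                 // true: lone system message
	fmt.Println(isWarmup([]msg{{role: "system"}, {role: "user"}})) // false: real conversation
}

Both chatCompletionsLLM and chatCompletionsVLM answer such requests with 200 and a null body, matching the old systemPrompt-based check they replace.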

runner/server/service/keepalive.go

Lines changed: 4 additions & 6 deletions
@@ -169,9 +169,8 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any,
 		ModelName: manifest.ModelName,
 		ModelPath: modelfile,
 		Config: nexa_sdk.ModelConfig{
-			NCtx:         param.NCtx,
-			NGpuLayers:   param.NGpuLayers,
-			SystemPrompt: param.SystemPrompt,
+			NCtx:       param.NCtx,
+			NGpuLayers: param.NGpuLayers,
 		},
 		PluginID: manifest.PluginId,
 		DeviceID: manifest.DeviceId,
@@ -191,9 +190,8 @@
 		MmprojPath:    mmproj,
 		TokenizerPath: tokenizer,
 		Config: nexa_sdk.ModelConfig{
-			NCtx:         param.NCtx,
-			NGpuLayers:   param.NGpuLayers,
-			SystemPrompt: param.SystemPrompt,
+			NCtx:       param.NCtx,
+			NGpuLayers: param.NGpuLayers,
 		},
 		PluginID: manifest.PluginId,
 		DeviceID: manifest.DeviceId,
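Both branches of keepAliveGet now build nexa_sdk.ModelConfig from types.ModelParam in the same way. A hypothetical helper restating that mapping (it does not exist in the repo; local stand-in types are used so the sketch compiles on its own):

package main

import "fmt"

// Stand-ins for types.ModelParam and nexa_sdk.ModelConfig from this repo.
type modelParam struct {
	NCtx       int32
	NGpuLayers int32
}

type modelConfig struct {
	NCtx       int32
	NGpuLayers int32
}

// configFromParam restates the mapping both keepAliveGet branches perform after
// this commit: only context size and GPU layer count are forwarded.
func configFromParam(p modelParam) modelConfig {
	return modelConfig{
		NCtx:       p.NCtx,
		NGpuLayers: p.NGpuLayers,
	}
}

func main() {
	fmt.Printf("%+v\n", configFromParam(modelParam{NCtx: 4096, NGpuLayers: 99}))
}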
