10 changes: 4 additions & 6 deletions runner/cmd/nexa-cli/infer.go
@@ -423,9 +423,8 @@ func inferLLM(manifest *types.ModelManifest, quant string) error {
 		PluginID: manifest.PluginId,
 		DeviceID: manifest.DeviceId,
 		Config: nexa_sdk.ModelConfig{
-			NCtx:         nctx,
-			NGpuLayers:   ngl,
-			SystemPrompt: systemPrompt, // TODO: align npu
+			NCtx:       nctx,
+			NGpuLayers: ngl,
 		},
 	})
 	spin.Stop()
@@ -595,9 +594,8 @@ func inferVLM(manifest *types.ModelManifest, quant string) error {
 		PluginID: manifest.PluginId,
 		DeviceID: manifest.DeviceId,
 		Config: nexa_sdk.ModelConfig{
-			NCtx:         nctx,
-			NGpuLayers:   ngl,
-			SystemPrompt: systemPrompt,
+			NCtx:       nctx,
+			NGpuLayers: ngl,
 		},
 	})
 	spin.Stop()
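With SystemPrompt removed from nexa_sdk.ModelConfig, the CLI no longer pins a system prompt at model-creation time; the prompt travels with the conversation instead. A minimal sketch of the new calling pattern, using the message and template types that appear in the handler changes further down (systemPrompt, userPrompt, and llm are illustrative variables, not code from this PR):

// Sketch only: the system prompt is now an ordinary "system" chat message.
messages := []nexa_sdk.LlmChatMessage{
	{Role: nexa_sdk.LLMRole("system"), Content: systemPrompt},
	{Role: nexa_sdk.LLMRole("user"), Content: userPrompt},
}
// The chat template renders the system turn like any other message.
formatted, err := llm.ApplyChatTemplate(nexa_sdk.LlmApplyChatTemplateInput{
	Messages: messages,
})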
3 changes: 0 additions & 3 deletions runner/internal/types/model.go
@@ -86,9 +86,6 @@ func (m ModelManifest) GetSize() int64 {
 type ModelParam struct {
 	NCtx       int32
 	NGpuLayers int32
-
-	// npu only
-	SystemPrompt string
 }
 
 type DownloadInfo struct {
1 change: 0 additions & 1 deletion runner/nexa-sdk/common.go
@@ -277,5 +277,4 @@ type ModelConfig struct {
 	NGpuLayers          int32
 	ChatTemplatePath    string
 	ChatTemplateContent string
-	SystemPrompt        string
 }
7 changes: 0 additions & 7 deletions runner/nexa-sdk/llm.go
@@ -82,10 +82,6 @@ func (lci LlmCreateInput) toCPtr() *C.ml_LlmCreateInput {
 	if lci.Config.ChatTemplateContent != "" {
 		cPtr.config.chat_template_content = C.CString(lci.Config.ChatTemplateContent)
 	}
-	// Add system prompt support
-	if lci.Config.SystemPrompt != "" {
-		cPtr.config.system_prompt = C.CString(lci.Config.SystemPrompt)
-	}
 
 	return cPtr
 }
@@ -114,9 +110,6 @@ func freeLlmCreateInput(cPtr *C.ml_LlmCreateInput) {
 	if cPtr.config.chat_template_content != nil {
 		C.free(unsafe.Pointer(cPtr.config.chat_template_content))
 	}
-	if cPtr.config.system_prompt != nil {
-		C.free(unsafe.Pointer(cPtr.config.system_prompt))
-	}
 
 	// Free the main structure
 	C.free(unsafe.Pointer(cPtr))
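The deleted lines follow the standard cgo ownership rule: every C.CString allocates on the C heap, invisible to Go's garbage collector, so it must be paired with a C.free in the matching cleanup function — which is why the removal touches both toCPtr and freeLlmCreateInput (and their VLM counterparts below). A self-contained illustration of the pattern:

package main

/*
#include <stdlib.h>
*/
import "C"

import "unsafe"

func main() {
	// C.CString copies the Go string onto the C heap; Go's GC never sees it.
	s := C.CString("system prompt")
	// Without this free, the allocation leaks — hence the paired nil-check
	// and C.free in the SDK's free* functions.
	defer C.free(unsafe.Pointer(s))
}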
7 changes: 0 additions & 7 deletions runner/nexa-sdk/vlm.go
@@ -76,10 +76,6 @@ func (vci VlmCreateInput) toCPtr() *C.ml_VlmCreateInput {
 	if vci.Config.ChatTemplateContent != "" {
 		cPtr.config.chat_template_content = C.CString(vci.Config.ChatTemplateContent)
 	}
-	// Add system prompt support
-	if vci.Config.SystemPrompt != "" {
-		cPtr.config.system_prompt = C.CString(vci.Config.SystemPrompt)
-	}
 
 	return cPtr
 }
@@ -110,9 +106,6 @@ func freeVlmCreateInput(cPtr *C.ml_VlmCreateInput) {
 	if cPtr.config.chat_template_content != nil {
 		C.free(unsafe.Pointer(cPtr.config.chat_template_content))
 	}
-	if cPtr.config.system_prompt != nil {
-		C.free(unsafe.Pointer(cPtr.config.system_prompt))
-	}
 
 	// Free the main structure
 	C.free(unsafe.Pointer(cPtr))
4 changes: 2 additions & 2 deletions runner/server/docs/swagger.yaml
@@ -881,7 +881,7 @@ components:
       stream_format:
         type: string
         enum: [wav, sse]
-        description: "sse" returns Server-Sent Events stream; otherwise returns binary WAV
+        description: '"sse" returns Server-Sent Events stream; otherwise returns binary WAV'
       speed:
         type: number
         format: float
@@ -905,7 +905,7 @@
         description: The audio file to transcribe. Omit for warm up (returns null).
       stream:
         type: string
-        description: "true" is not supported and returns 400
+        description: '"true" is not supported and returns 400'
       language:
        type: string
         description: The language of the input audio
65 changes: 12 additions & 53 deletions runner/server/handler/chat.go
@@ -80,6 +80,14 @@ func defaultChatCompletionRequest() ChatCompletionRequest {
 	}
 }
 
+func onlySystemMessage(param ChatCompletionRequest) bool {
+	if len(param.Messages) != 1 {
+		return false
+	}
+	r := param.Messages[0].GetRole()
+	return r != nil && *r == "system"
+}
+
 func ChatCompletions(c *gin.Context) {
 	param := defaultChatCompletionRequest()
 	if err := c.ShouldBindJSON(&param); err != nil {
@@ -117,11 +125,8 @@ func ChatCompletions(c *gin.Context) {
 }
 
 func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
-	// Build message list for LLM template
-	var systemPrompt string
 	messages := make([]nexa_sdk.LlmChatMessage, 0, len(param.Messages))
 	for _, msg := range param.Messages {
-		// tool call message
 		if toolCalls := msg.GetToolCalls(); len(toolCalls) > 0 {
 			for _, tc := range toolCalls {
 				messages = append(messages, nexa_sdk.LlmChatMessage{
@@ -133,7 +138,6 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
 			continue
 		}
 
-		// tool call response message
 		if toolResp := msg.GetToolCallID(); toolResp != nil {
 			messages = append(messages, nexa_sdk.LlmChatMessage{
 				Role: nexa_sdk.LLMRole(*msg.GetRole()),
@@ -144,21 +148,13 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
 
 		switch content := msg.GetContent().AsAny().(type) {
 		case *string:
-			// NOTE: patch for npu
-			if *msg.GetRole() == "system" {
-				systemPrompt += *content
-			}
 			messages = append(messages, nexa_sdk.LlmChatMessage{
 				Role:    nexa_sdk.LLMRole(*msg.GetRole()),
 				Content: *content,
 			})
 
 		case *[]openai.ChatCompletionContentPartTextParam:
 			for _, ct := range *content {
-				// NOTE: patch for npu
-				if *msg.GetRole() == "system" {
-					systemPrompt += ct.Text
-				}
 				messages = append(messages, nexa_sdk.LlmChatMessage{
 					Role:    nexa_sdk.LLMRole(*msg.GetRole()),
 					Content: ct.Text,
@@ -168,10 +164,6 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
 			for _, ct := range *content {
 				switch *ct.GetType() {
 				case "text":
-					// NOTE: patch for npu
-					if *msg.GetRole() == "system" {
-						systemPrompt += *ct.GetText()
-					}
 					messages = append(messages, nexa_sdk.LlmChatMessage{
 						Role:    nexa_sdk.LLMRole(*msg.GetRole()),
 						Content: *ct.GetText(),
@@ -186,10 +178,6 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
 			for _, ct := range *content {
 				switch *ct.GetType() {
 				case "text":
-					// NOTE: patch for npu
-					if *msg.GetRole() == "system" {
-						systemPrompt += *ct.GetText()
-					}
 					messages = append(messages, nexa_sdk.LlmChatMessage{
 						Role:    nexa_sdk.LLMRole(*msg.GetRole()),
 						Content: *ct.GetText(),
@@ -218,10 +206,9 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
 
 	samplerConfig := parseSamplerConfig(param)
 
-	// Get LLM instance
 	p, err := service.KeepAliveGet[nexa_sdk.LLM](
 		string(param.Model),
-		types.ModelParam{NCtx: param.NCtx, NGpuLayers: param.Ngl, SystemPrompt: systemPrompt},
+		types.ModelParam{NCtx: param.NCtx, NGpuLayers: param.Ngl},
 		c.GetHeader("Nexa-KeepCache") != "true",
 	)
 	if errors.Is(err, os.ErrNotExist) {
@@ -231,13 +218,11 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
 		c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error(), "code": nexa_sdk.SDKErrorCode(err)})
 		return
 	}
-	// Empty request for warm up
-	if len(param.Messages) == 0 || (systemPrompt != "" && len(param.Messages) <= 1) {
+	if len(param.Messages) == 0 || onlySystemMessage(param) {
 		c.JSON(http.StatusOK, nil)
 		return
 	}
 
-	// Format prompt using chat template
 	formatted, err := p.ApplyChatTemplate(nexa_sdk.LlmApplyChatTemplateInput{
 		Messages: messages,
 		Tools:    tools,
@@ -421,11 +406,8 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
 }
 
 func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
-	// Build message list for VLM template
-	var systemPrompt string
 	messages := make([]nexa_sdk.VlmChatMessage, 0, len(param.Messages))
 	for _, msg := range param.Messages {
-		// tool call message
 		if toolCalls := msg.GetToolCalls(); len(toolCalls) > 0 {
 			contents := make([]nexa_sdk.VlmContent, 0, len(toolCalls))
 			for _, tc := range toolCalls {
@@ -442,7 +424,6 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
 			continue
 		}
 
-		// tool call response message
 		if toolResp := msg.GetToolCallID(); toolResp != nil {
 			messages = append(messages, nexa_sdk.VlmChatMessage{
 				Role: nexa_sdk.VlmRole(*msg.GetRole()),
@@ -456,9 +437,6 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
 
 		switch content := msg.GetContent().AsAny().(type) {
 		case *string:
-			if *msg.GetRole() == "system" {
-				systemPrompt += *content
-			}
 			messages = append(messages, nexa_sdk.VlmChatMessage{
 				Role: nexa_sdk.VlmRole(*msg.GetRole()),
 				Contents: []nexa_sdk.VlmContent{
@@ -468,32 +446,22 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
 
 		case *[]openai.ChatCompletionContentPartTextParam:
 			contents := make([]nexa_sdk.VlmContent, 0, len(*content))
-
 			for _, ct := range *content {
-				if *msg.GetRole() == "system" {
-					systemPrompt += ct.Text
-				}
 				contents = append(contents, nexa_sdk.VlmContent{
 					Type: nexa_sdk.VlmContentTypeText,
 					Text: ct.Text,
 				})
 			}
-
 			messages = append(messages, nexa_sdk.VlmChatMessage{
 				Role:     nexa_sdk.VlmRole(*msg.GetRole()),
 				Contents: contents,
 			})
 
 		case *[]openai.ChatCompletionContentPartUnionParam:
 			contents := make([]nexa_sdk.VlmContent, 0, len(*content))
-
 			for _, ct := range *content {
 				switch *ct.GetType() {
 				case "text":
-					// NOTE: patch for npu
-					if *msg.GetRole() == "system" {
-						systemPrompt += *ct.GetText()
-					}
 					contents = append(contents, nexa_sdk.VlmContent{
 						Type: nexa_sdk.VlmContentTypeText,
 						Text: *ct.GetText(),
@@ -528,22 +496,16 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
 					return
 				}
 			}
-
 			messages = append(messages, nexa_sdk.VlmChatMessage{
 				Role:     nexa_sdk.VlmRole(*msg.GetRole()),
 				Contents: contents,
 			})
 
 		case *[]openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion:
 			contents := make([]nexa_sdk.VlmContent, 0, len(*content))
-
 			for _, ct := range *content {
 				switch *ct.GetType() {
 				case "text":
-					// NOTE: patch for npu
-					if *msg.GetRole() == "system" {
-						systemPrompt += *ct.GetText()
-					}
 					contents = append(contents, nexa_sdk.VlmContent{
 						Type: nexa_sdk.VlmContentTypeText,
 						Text: *ct.GetText(),
@@ -577,10 +539,9 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
 
 	samplerConfig := parseSamplerConfig(param)
 
-	// Get VLM instance
 	p, err := service.KeepAliveGet[nexa_sdk.VLM](
 		string(param.Model),
-		types.ModelParam{NCtx: param.NCtx, NGpuLayers: param.Ngl, SystemPrompt: systemPrompt},
+		types.ModelParam{NCtx: param.NCtx, NGpuLayers: param.Ngl},
 		c.GetHeader("Nexa-KeepCache") != "true",
 	)
 	if errors.Is(err, os.ErrNotExist) {
@@ -590,9 +551,7 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
 		c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error(), "code": nexa_sdk.SDKErrorCode(err)})
 		return
 	}
-
-	// Empty request for warm up, just reset model state
-	if len(param.Messages) == 0 || (systemPrompt != "" && len(param.Messages) <= 1) {
+	if len(param.Messages) == 0 || onlySystemMessage(param) {
 		c.JSON(http.StatusOK, nil)
 		return
 	}
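The new onlySystemMessage helper replaces the accumulated systemPrompt string in the warm-up check: a request now counts as a warm-up when it has no messages at all, or exactly one message whose role is "system"; either way the handler loads the model and returns 200 with a null body. A self-contained sketch of the predicate's behavior (simplified types — the real ones come from the server's request structs):

package main

import "fmt"

type message struct{ Role string }

// Mirrors the helper above with plain types.
func onlySystemMessage(msgs []message) bool {
	return len(msgs) == 1 && msgs[0].Role == "system"
}

func main() {
	fmt.Println(onlySystemMessage([]message{{Role: "system"}}))                 // true: warm-up
	fmt.Println(onlySystemMessage([]message{{Role: "system"}, {Role: "user"}})) // false: real request
	fmt.Println(onlySystemMessage(nil))                                         // false: covered by the len == 0 branch
}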
10 changes: 4 additions & 6 deletions runner/server/service/keepalive.go
@@ -169,9 +169,8 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any,
 			ModelName: manifest.ModelName,
 			ModelPath: modelfile,
 			Config: nexa_sdk.ModelConfig{
-				NCtx:         param.NCtx,
-				NGpuLayers:   param.NGpuLayers,
-				SystemPrompt: param.SystemPrompt,
+				NCtx:       param.NCtx,
+				NGpuLayers: param.NGpuLayers,
 			},
 			PluginID: manifest.PluginId,
 			DeviceID: manifest.DeviceId,
@@ -191,9 +190,8 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any,
 			MmprojPath:    mmproj,
 			TokenizerPath: tokenizer,
 			Config: nexa_sdk.ModelConfig{
-				NCtx:         param.NCtx,
-				NGpuLayers:   param.NGpuLayers,
-				SystemPrompt: param.SystemPrompt,
+				NCtx:       param.NCtx,
+				NGpuLayers: param.NGpuLayers,
 			},
 			PluginID: manifest.PluginId,
 			DeviceID: manifest.DeviceId,
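Taken together, a client can still pre-load a model by sending an empty or system-only message list; nothing about the warm-up contract changes except that the system prompt is no longer cached on the model instance. A client-side sketch — the port, route, and model name are assumptions for illustration, not taken from this PR:

package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	// Hypothetical endpoint and model name; adjust to your deployment.
	body := []byte(`{"model":"some-model","messages":[{"role":"system","content":"You are helpful."}]}`)
	resp, err := http.Post("http://localhost:8080/v1/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status) // expect "200 OK" with a null body: the model is now warm
}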