diff --git a/common/consts.go b/common/consts.go index d380b0d..9491285 100644 --- a/common/consts.go +++ b/common/consts.go @@ -82,3 +82,10 @@ const ( DATABASE_MEM0_API_KEY = "DATABASE_MEM0_API_KEY" DATABASE_MEM0_REGION = "DATABASE_MEM0_REGION" ) + +// Prompt pilot +const ( + AGENTPILOT_API_URL = "AGENTPILOT_API_URL" + AGENTPILOT_API_KEY = "AGENTPILOT_API_KEY" + AGENTPILOT_WORKSPACE_ID = "AGENTPILOT_WORKSPACE_ID" +) diff --git a/common/defaults.go b/common/defaults.go index 06657d3..76a056e 100644 --- a/common/defaults.go +++ b/common/defaults.go @@ -68,3 +68,7 @@ const ( DEFAULT_AGENTKIT_TOOL_REGION = "cn-beijing" DEFAULT_AGENTKIT_TOOL_SERVICE_CODE = "agentkit" ) + +const ( + DEFAULT_AGENTPILOT_API_URL = "https://prompt-pilot.cn-beijing.volces.com" +) diff --git a/configs/configs.go b/configs/configs.go index ba28f0e..6d08b59 100644 --- a/configs/configs.go +++ b/configs/configs.go @@ -83,6 +83,7 @@ func SetupVeADKConfig() error { } globalConfig.Model.MapEnvToConfig() globalConfig.Tool.MapEnvToConfig() + globalConfig.PromptPilot.MapEnvToConfig() globalConfig.LOGGING.MapEnvToConfig() globalConfig.Database.MapEnvToConfig() globalConfig.Volcengine.MapEnvToConfig() diff --git a/configs/prompt_pilot.go b/configs/prompt_pilot.go index 96bf14f..01b558a 100644 --- a/configs/prompt_pilot.go +++ b/configs/prompt_pilot.go @@ -14,6 +14,19 @@ package configs +import ( + "github.com/volcengine/veadk-go/common" + "github.com/volcengine/veadk-go/utils" +) + type PromptPilotConfig struct { - // 根据实际字段补充 + Url string `yaml:"url"` + ApiKey string `yaml:"api_key"` + WorkspaceId string `yaml:"workspace_id"` +} + +func (v *PromptPilotConfig) MapEnvToConfig() { + v.Url = utils.GetEnvWithDefault(common.AGENTPILOT_API_URL) + v.ApiKey = utils.GetEnvWithDefault(common.AGENTPILOT_API_KEY) + v.WorkspaceId = utils.GetEnvWithDefault(common.AGENTPILOT_WORKSPACE_ID) } diff --git a/go.mod b/go.mod index f50ece9..5a4994d 100644 --- a/go.mod +++ b/go.mod @@ -81,4 +81,4 @@ require ( gopkg.in/yaml.v2 v2.4.0 // indirect rsc.io/omap v1.2.0 // indirect rsc.io/ordered v1.1.1 // indirect -) +) \ No newline at end of file diff --git a/integrations/ve_prompt_pilot/prompt_pilot.go b/integrations/ve_prompt_pilot/prompt_pilot.go new file mode 100644 index 0000000..0dce7b8 --- /dev/null +++ b/integrations/ve_prompt_pilot/prompt_pilot.go @@ -0,0 +1,267 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ve_prompt_pilot + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "iter" + "log" + "net/http" + "strings" + "time" + + "github.com/google/uuid" + "github.com/volcengine/veadk-go/common" + "github.com/volcengine/veadk-go/configs" + "github.com/volcengine/veadk-go/prompts" + "github.com/volcengine/veadk-go/utils" +) + +const ( + defaultOptimizeModel = "doubao-seed-1.6-251015" + defaultHttpTimeout = 120 +) + +var ( + ErrUrlValidationFailed = errors.New("AGENTPILOT_API_URL environment variable is not set") + ErrApiKeyValidationFailed = errors.New("AGENTPILOT_API_KEY environment variable is not set") + ErrWorkspaceIdValidationFailed = errors.New("AGENTPILOT_WORKSPACE_ID environment variable is not set") +) + +// VePromptPilot handles prompt optimization interactions. +type VePromptPilot struct { + url string + apiKey string + workspaceID string + httpClient *http.Client +} + +// New creates a new VePromptPilot instance. +func New(opts ...func(*VePromptPilot)) *VePromptPilot { + p := &VePromptPilot{ + url: fmt.Sprintf("%s/agent-pilot?Version=2024-01-01&Action=GeneratePromptStream", utils.GetEnvWithDefault(common.AGENTPILOT_API_URL, configs.GetGlobalConfig().PromptPilot.Url, common.DEFAULT_AGENTPILOT_API_URL)), + apiKey: utils.GetEnvWithDefault(common.AGENTPILOT_API_KEY, configs.GetGlobalConfig().PromptPilot.ApiKey), + workspaceID: utils.GetEnvWithDefault(common.AGENTPILOT_WORKSPACE_ID, configs.GetGlobalConfig().PromptPilot.WorkspaceId), + httpClient: &http.Client{ + Timeout: time.Second * defaultHttpTimeout, + }, + } + + for _, opt := range opts { + opt(p) + } + return p +} + +// WithUrl sets the url for the pilot. +func WithUrl(url string) func(*VePromptPilot) { + return func(p *VePromptPilot) { + p.url = url + } +} + +// WithAPIKey sets the API key for the pilot. +func WithAPIKey(apiKey string) func(*VePromptPilot) { + return func(p *VePromptPilot) { + p.apiKey = apiKey + } +} + +// WithWorkspaceID sets the workspace ID for the pilot. +func WithWorkspaceID(workspaceID string) func(*VePromptPilot) { + return func(p *VePromptPilot) { + p.workspaceID = workspaceID + } +} + +// WithHTTPClient sets the HTTP client for the pilot. +func WithHTTPClient(client *http.Client) func(*VePromptPilot) { + return func(p *VePromptPilot) { + p.httpClient = client + } +} + +// generatePromptRequest represents the JSON body for the API request. +type generatePromptRequest struct { + RequestID string `json:"request_id"` + WorkspaceID string `json:"workspace_id"` + TaskType string `json:"task_type"` + Rule string `json:"rule"` + CurrentPrompt string `json:"current_prompt,omitempty"` + ModelName string `json:"model_name"` + Temperature float64 `json:"temperature"` + TopP float64 `json:"top_p"` +} + +func (p *VePromptPilot) Valid() error { + if p.url == "" { + return ErrUrlValidationFailed + } + if p.apiKey == "" { + return ErrApiKeyValidationFailed + } + if p.workspaceID == "" { + return ErrWorkspaceIdValidationFailed + } + return nil +} + +// Optimize optimizes the prompts for the given agents using the specified feedback and model. +func (p *VePromptPilot) Optimize(agentInfo *prompts.AgentInfo, feedback string, modelName string) (string, error) { + if err := p.Valid(); err != nil { + return "", err + } + + if modelName == "" { + modelName = defaultOptimizeModel + } + var finalPrompt string + var taskDescription string + var err error + + if feedback == "" { + log.Println("Optimizing prompt without feedback.") + taskDescription, err = prompts.RenderPromptWithTemplate(agentInfo) + } else { + log.Printf("Optimizing prompt with feedback: %s\n", feedback) + taskDescription, err = prompts.RenderPromptFeedbackWithTemplate(agentInfo, feedback) + } + + if err != nil { + return "", fmt.Errorf("rendering optimization task description: %w", err) + } + + //TaskType Enum + //"DEFAULT" # single turn task + //"MULTIMODAL" # visual reasoning single turn task + //"DIALOG" # multi turn dialog + reqBody := &generatePromptRequest{ + RequestID: uuid.New().String(), + WorkspaceID: p.workspaceID, + TaskType: "DIALOG", + Rule: taskDescription, + CurrentPrompt: agentInfo.Instruction, + ModelName: modelName, + Temperature: 1.0, + TopP: 0.7, + } + + var builder strings.Builder + var usageTotal int + for event, err := range p.generateStream(context.Background(), reqBody) { + if err != nil { + return "", fmt.Errorf("generateStream error: %w", err) + } + if event.Event == "message" { + builder.WriteString(event.Data.Content) + } else if event.Event == "usage" { + usageTotal = event.Data.Usage.TotalTokens + } else { + eventStr, _ := json.Marshal(event) + log.Printf("Unexpected event: %s\n", string(eventStr)) + } + } + + finalPrompt = strings.ReplaceAll(builder.String(), "\\n", "\n") + + log.Printf("Optimized prompt is -----\n%s\n-----\n", finalPrompt) + + if usageTotal > 0 { + log.Printf("Token usage: %d", usageTotal) + } else { + log.Println("[Warn]No usage data.") + } + + return finalPrompt, nil +} + +func (p *VePromptPilot) sendRequest(ctx context.Context, reqBody *generatePromptRequest) (*http.Response, error) { + jsonBody, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", p.url, bytes.NewBuffer(jsonBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+p.apiKey) + req.Header.Set("Content-Type", "application/json") + + httpResp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to execute request: %w", err) + } + if httpResp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(httpResp.Body) + if err = httpResp.Body.Close(); err != nil { + return nil, fmt.Errorf("API failed to close response body: %w", err) + } + return nil, fmt.Errorf("API error (status %d): %s", httpResp.StatusCode, string(body)) + } + + return httpResp, nil +} + +func (p *VePromptPilot) generateStream(ctx context.Context, req *generatePromptRequest) iter.Seq2[*GeneratePromptStreamResponseChunk, error] { + return func(yield func(*GeneratePromptStreamResponseChunk, error) bool) { + httpResp, err := p.sendRequest(ctx, req) + if err != nil { + yield(nil, err) + return + } + defer func() { + _ = httpResp.Body.Close() + }() + + scanner := bufio.NewScanner(httpResp.Body) + + var promptChunk *GeneratePromptStreamResponseChunk + for scanner.Scan() { + line := scanner.Text() + decodedLine := strings.TrimSpace(line) + promptChunk = parseEventStreamLine(decodedLine, promptChunk) + if promptChunk != nil { + hasContent := promptChunk.Data != nil && promptChunk.Data.Content != "" + hasUsage := promptChunk.Data != nil && promptChunk.Data.Usage != nil + hasError := promptChunk.Data != nil && promptChunk.Data.Error != "" + + if hasContent || hasUsage { + yieldData := promptChunk + promptChunk = nil + yield(yieldData, nil) + continue + } else if hasError { + yield(nil, fmt.Errorf("prompt pilot generate error: %s", promptChunk.Data.Error)) + continue + } else { + continue + } + } + } + + if err := scanner.Err(); err != nil { + yield(nil, fmt.Errorf("stream error: %w", err)) + return + } + } +} diff --git a/integrations/ve_prompt_pilot/prompt_pilot_test.go b/integrations/ve_prompt_pilot/prompt_pilot_test.go new file mode 100644 index 0000000..47c1fc6 --- /dev/null +++ b/integrations/ve_prompt_pilot/prompt_pilot_test.go @@ -0,0 +1,161 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ve_prompt_pilot + +import ( + "bytes" + "fmt" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/volcengine/veadk-go/common" + "github.com/volcengine/veadk-go/prompts" + "github.com/volcengine/veadk-go/utils" +) + +func TestNew(t *testing.T) { + vePromptPilot := New() + if utils.GetEnvWithDefault(common.AGENTPILOT_WORKSPACE_ID) == "" || utils.GetEnvWithDefault(common.AGENTPILOT_API_KEY) == "" { + t.Skip() + } + prompt, err := vePromptPilot.Optimize(&prompts.AgentInfo{ + Name: "weather_agent", + Model: defaultOptimizeModel, + Instruction: "你是一个MBTI人格分析大师,负责根据用户提供的个人信息分析用户的MBTI人格。", + }, + "", defaultOptimizeModel) + if err != nil { + fmt.Println("error:", err) + return + } + + fmt.Println(prompt) +} + +type mockRoundTripper struct { + roundTripFunc func(req *http.Request) (*http.Response, error) +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return m.roundTripFunc(req) +} + +func TestVePromptPilot_Optimize_Mock(t *testing.T) { + agentInfo := &prompts.AgentInfo{ + Name: "test_agent", + Instruction: "Initial instruction", + } + + t.Run("Success", func(t *testing.T) { + mockRespBody := `event: message +data: "Optimized " +event: message +data: "instruction" +event: usage +data: {"total_tokens": 50} +` + client := &http.Client{ + Transport: &mockRoundTripper{ + roundTripFunc: func(req *http.Request) (*http.Response, error) { + assert.Equal(t, "POST", req.Method) + assert.Contains(t, req.URL.String(), "/agent-pilot") + assert.Equal(t, "Bearer test-api-key", req.Header.Get("Authorization")) + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(mockRespBody)), + Header: make(http.Header), + }, nil + }, + }, + } + + pilot := New( + WithUrl("http://mock-url/agent-pilot"), + WithAPIKey("test-api-key"), + WithWorkspaceID("test-workspace"), + WithHTTPClient(client), + ) + + prompt, err := pilot.Optimize(agentInfo, "Make it better", "test-model") + assert.NoError(t, err) + assert.Equal(t, "Optimized instruction", prompt) + }) + + t.Run("APIError", func(t *testing.T) { + client := &http.Client{ + Transport: &mockRoundTripper{ + roundTripFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadRequest, + Body: io.NopCloser(bytes.NewBufferString("Bad Request")), + Header: make(http.Header), + }, nil + }, + }, + } + + pilot := New( + WithUrl("http://mock-url"), + WithAPIKey("test-api-key"), + WithWorkspaceID("test-workspace"), + WithHTTPClient(client), + ) + + prompt, err := pilot.Optimize(agentInfo, "", "") + assert.Error(t, err) + assert.Contains(t, err.Error(), "API error (status 400)") + assert.Empty(t, prompt) + }) + + t.Run("StreamError", func(t *testing.T) { + mockRespBody := `event: error +data: Something went wrong +` + client := &http.Client{ + Transport: &mockRoundTripper{ + roundTripFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(mockRespBody)), + Header: make(http.Header), + }, nil + }, + }, + } + + pilot := New( + WithUrl("http://mock-url"), + WithAPIKey("test-api-key"), + WithWorkspaceID("test-workspace"), + WithHTTPClient(client), + ) + + prompt, err := pilot.Optimize(agentInfo, "", "") + assert.Error(t, err) + assert.Contains(t, err.Error(), "prompt pilot generate error: Something went wrong") + assert.Empty(t, prompt) + }) + + t.Run("ValidationError", func(t *testing.T) { + pilot := New( + WithUrl(""), // Invalid URL + ) + _, err := pilot.Optimize(agentInfo, "", "") + assert.Equal(t, ErrUrlValidationFailed, err) + }) +} diff --git a/integrations/ve_prompt_pilot/utils.go b/integrations/ve_prompt_pilot/utils.go new file mode 100644 index 0000000..02ee3fe --- /dev/null +++ b/integrations/ve_prompt_pilot/utils.go @@ -0,0 +1,94 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ve_prompt_pilot + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" +) + +type Usage struct { + TotalTokens int `json:"total_tokens"` +} + +type GeneratePromptChunk struct { + Content string `json:"content,omitempty"` + Usage *Usage `json:"usage,omitempty"` + Error string `json:"error,omitempty"` +} + +type GeneratePromptStreamResponseChunk struct { + Event string `json:"event"` + Data *GeneratePromptChunk `json:"data,omitempty"` +} + +var ( + dataMessageRegex = regexp.MustCompile(`^data: "(?P.*)"$`) + dataGenericRegex = regexp.MustCompile(`^data: (?P.*)$`) + eventRegex = regexp.MustCompile(`^event: (?P[^:]+)$`) +) + +func parseEventStreamLine(line string, promptChunk *GeneratePromptStreamResponseChunk) *GeneratePromptStreamResponseChunk { + if promptChunk != nil && promptChunk.Event == "message" && promptChunk.Data.Content == "" { + if strings.HasPrefix(line, "data: ") { + match := dataMessageRegex.FindStringSubmatch(line) + if len(match) > 1 { + content := match[1] + var decodedContent string + jsonStr := fmt.Sprintf(`"%s"`, content) + if err := json.Unmarshal([]byte(jsonStr), &decodedContent); err == nil { + promptChunk.Data.Content = decodedContent + return promptChunk + } + } + } + } else if promptChunk != nil && promptChunk.Event == "usage" && promptChunk.Data.Usage == nil { + if strings.HasPrefix(line, "data: ") { + match := dataGenericRegex.FindStringSubmatch(line) + if len(match) > 1 { + dataStr := match[1] + var usage *Usage + // usage 是 JSON 对象 + if err := json.Unmarshal([]byte(dataStr), &usage); err == nil { + promptChunk.Data.Usage = usage + return promptChunk + } + } + } + } else if promptChunk != nil && promptChunk.Event == "error" && promptChunk.Data.Error == "" { + if strings.HasPrefix(line, "data: ") { + match := dataGenericRegex.FindStringSubmatch(line) + if len(match) > 1 { + // error 直接作为字符串处理 + promptChunk.Data.Error = match[1] + return promptChunk + } + } + } else { + // 检查是否是新事件的开始 + if strings.HasPrefix(line, "event:") { + match := eventRegex.FindStringSubmatch(line) + if len(match) > 1 { + return &GeneratePromptStreamResponseChunk{ + Event: strings.TrimSpace(match[1]), + Data: &GeneratePromptChunk{}, + } + } + } + } + return nil +} diff --git a/integrations/ve_prompt_pilot/utils_test.go b/integrations/ve_prompt_pilot/utils_test.go new file mode 100644 index 0000000..f27bd01 --- /dev/null +++ b/integrations/ve_prompt_pilot/utils_test.go @@ -0,0 +1,259 @@ +package ve_prompt_pilot + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseEventStreamLine_MockData(t *testing.T) { + // Mock data provided by user + mockLines := []string{ + "event: message", + `data: "# Role \nYou are an MBTI personality analysis master,"`, + "", + "event: message", + `data: " specializing in analyzing users' MBTI personality types based on the personal"`, + "", + "event: message", + `data: " information they provide."`, + "", + "event: message", + `data: " \n\n# Task Requirements \n## Core Workflow \n1."`, + "", + "event: message", + `data: " **Collect Information**: First,"`, + "", + "event: message", + `data: " listen carefully to the personal information users share (such as daily"`, + "", + "event: message", + `data: " behavior,"`, + "", + "event: message", + `data: " decision-making habits,"`, + "", + "event: message", + `data: " social preferences,"`, + "", + "event: message", + `data: " work/study styles,"`, + "", + "event: message", + `data: " emotional responses,"`, + "", + "event: message", + `data: " etc.). If the information provided is insufficient to make an accurate"`, + "", + "event: message", + `data: " judgment,"`, + "", + "event: message", + `data: " proactively ask targeted follow-up questions (e.g., \"Can you tell me how"`, + "", + "event: message", + `data: " you usually spend your weekends?\" or \"When making an important decision"`, + "", + "event: message", + `data: ","`, + "", + "event: message", + `data: " do you prefer to rely on logical analysis or intuitive feelings?"`, + "", + "event: message", + `data: "\"). \n2."`, + "", + "event: message", + `data: " **Analyze Personality**: Based on the collected information,"`, + "", + "event: message", + `data: " combine the core dimensions of MBTI (Extraversion/Introversion,"`, + "", + "event: message", + `data: " Sensing/Intuition,"`, + "", + "event: message", + `data: " Thinking/Feeling,"`, + "", + "event: message", + `data: " Judging/Perceiving) to conduct a professional and rational analysis."`, + "", + "event: message", + `data: " \n3."`, + "", + "event: message", + `data: " **Present Results**: Clearly state the inferred MBTI type and explain"`, + "", + "event: message", + `data: " the reasoning in simple,"`, + "", + "event: message", + `data: " easy-to-understand language—linking the analysis directly to the specific"`, + "", + "event: message", + `data: " details the user provided."`, + "", + "event: message", + `data: " \n\n## Communication Style \n- Maintain a professional yet approachable"`, + "", + "event: message", + `data: " tone;"`, + "", + "event: message", + `data: " avoid using overly obscure psychological jargon."`, + "", + "event: message", + `data: " \n- Ensure the analysis is objective and evidence-based,"`, + "", + "event: message", + `data: " not relying on subjective assumptions beyond the user’s stated information"`, + "", + "event: message", + `data: "."`, + "", + "event: message", + `data: " \n\n# Output Guidelines \nEach analysis should include: \n1."`, + "", + "event: message", + `data: " The inferred MBTI type."`, + "", + "event: message", + `data: " \n2."`, + "", + "event: message", + `data: " A breakdown of how the user’s behaviors align with each dimension of the"`, + "", + "event: message", + `data: " type."`, + "", + "event: message", + `data: " \n3."`, + "", + "event: message", + `data: " A brief,"`, + "", + "event: message", + `data: " relatable summary of the type’s typical traits to help the user understand"`, + "", + "event: message", + `data: " better."`, + "", + "event: message", + `data: " \n\nExample: If a user says,"`, + "", + "event: message", + `data: " \"I love planning every detail of my trips in advance,"`, + "", + "event: message", + `data: " prefer working alone on projects,"`, + "", + "event: message", + `data: " and often make decisions based on whether they feel fair to others,\" your"`, + "", + "event: message", + `data: " analysis might include: \n- Inferred type: ISFJ \n- Reasoning: \"Prefer"`, + "", + "event: message", + `data: " planning details\" aligns with Judging;"`, + "", + "event: message", + `data: " \"work alone\" suggests Introversion;"`, + "", + "event: message", + `data: " \"focus on fairness to others\" reflects Feeling;"`, + "", + "event: message", + `data: " \"attention to specific trip details\" indicates Sensing."`, + "", + "event: message", + `data: " \n- Summary: ISFJs are often caring,"`, + "", + "event: message", + `data: " detail-oriented,"`, + "", + "event: message", + `data: " and value stability in their lives."`, + "", + "event: usage", + `data: {"total_tokens": 3807}`, + "", + "event: usage", + `data: {"total_tokens": 3807}`, + "", + } + + var currentChunk *GeneratePromptStreamResponseChunk + var fullContent strings.Builder + var lastUsage *Usage + + for _, line := range mockLines { + result := parseEventStreamLine(line, currentChunk) + if result != nil { + currentChunk = result + + // If we have content, append it + if currentChunk.Event == "message" && currentChunk.Data != nil && currentChunk.Data.Content != "" { + fullContent.WriteString(currentChunk.Data.Content) + // Reset content to avoid double counting if we process the same chunk object again (though parseEventStreamLine creates new chunks for events) + // Actually, parseEventStreamLine updates the *same* chunk object when parsing data. + // However, since we're iterating line by line, and the mock data has event -> data -> empty -> event pattern. + // Each "event: message" creates a NEW chunk. + // Then "data: ..." fills it. + // So we should capture the content when it's filled. + } + + if currentChunk.Event == "usage" && currentChunk.Data != nil && currentChunk.Data.Usage != nil { + lastUsage = currentChunk.Data.Usage + } + } + } + + // Verify the assembled content contains expected parts + expectedParts := []string{ + "# Role", + "MBTI personality analysis master", + "Task Requirements", + "Collect Information", + "Analyze Personality", + "Present Results", + "Communication Style", + "Output Guidelines", + "Example: If a user says", + "ISFJ", + } + + gotContent := fullContent.String() + for _, part := range expectedParts { + assert.Contains(t, gotContent, part, "Content should contain: "+part) + } + + // Verify usage + assert.NotNil(t, lastUsage) + if lastUsage != nil { + assert.Equal(t, 3807, lastUsage.TotalTokens) + } +} + +func TestParseEventStreamLine_Error(t *testing.T) { + mockLines := []string{ + "event: error", + "data: Something went wrong", + "", + } + + var currentChunk *GeneratePromptStreamResponseChunk + var errorMsg string + + for _, line := range mockLines { + result := parseEventStreamLine(line, currentChunk) + if result != nil { + currentChunk = result + if currentChunk.Event == "error" && currentChunk.Data != nil && currentChunk.Data.Error != "" { + errorMsg = currentChunk.Data.Error + } + } + } + + assert.Equal(t, "Something went wrong", errorMsg) +} diff --git a/prompts/prompt_optimization.go b/prompts/prompt_optimization.go new file mode 100644 index 0000000..527b2b5 --- /dev/null +++ b/prompts/prompt_optimization.go @@ -0,0 +1,120 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prompts + +import ( + "bytes" + "text/template" +) + +const prompt = `Please help me to optimize the following agent prompt: +{{.OriginalPrompt}} + + +The following information is your references: + +name: {{.Agent.Name}} +model: {{.Agent.Model}} +description: {{.Agent.Description}} + + + +{{range .Tools}} + +name: {{.Name}} +type: {{.Type}} +description: {{.Description}} +arguments: {{.Arguments}} + +{{end}} + + +Please note that in your optimized prompt: +- the above referenced information is not necessary. For example, the tools list of agent is not necessary in the optimized prompt, because it maybe too long. You should use the tool information to optimize the original prompt rather than simply add tool list in prompt. +- The max length of optimized prompt should be less 4096 tokens. +` + +const promptWithFeedback = `After you optimization, my current prompt is: +{{.Prompt}} + +I did some evaluations with the optimized prompt, and the feedback is: {{.Feedback}} + +Please continue to optimize the prompt based on the feedback. +` + +type AgentInfo struct { + Name string + Model string + Description string + Instruction string + Tools []*ToolInfo +} + +// ToolInfo 结构体定义 +type ToolInfo struct { + Name string + Type string + Description string + Arguments string +} + +func RenderPromptFeedbackWithTemplate(agent *AgentInfo, feedback string) (string, error) { + tmpl, err := template.New("promptWithFeedback").Parse(promptWithFeedback) + if err != nil { + return "", err + } + + context := map[string]interface{}{ + "Prompt": agent.Instruction, + "Feedback": feedback, + } + + // 执行模板渲染 + var buf bytes.Buffer + err = tmpl.Execute(&buf, context) + if err != nil { + return "", err + } + + return buf.String(), nil +} + +func RenderPromptWithTemplate(agent *AgentInfo) (string, error) { + // 解析模板 + tmpl, err := template.New("prompt").Parse(prompt) + if err != nil { + return "", err + } + + // 准备上下文数据 + context := map[string]interface{}{ + "OriginalPrompt": agent.Instruction, + "Agent": map[string]string{ + "Name": agent.Name, + "Model": agent.Model, + "Description": agent.Description, + }, + "Tools": agent.Tools, + } + + // 执行模板渲染 + var buf bytes.Buffer + err = tmpl.Execute(&buf, context) + if err != nil { + return "", err + } + + return buf.String(), nil +}