Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions common/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,10 @@ const (
DATABASE_MEM0_API_KEY = "DATABASE_MEM0_API_KEY"
DATABASE_MEM0_REGION = "DATABASE_MEM0_REGION"
)

// Prompt pilot
const (
AGENTPILOT_API_URL = "AGENTPILOT_API_URL"
AGENTPILOT_API_KEY = "AGENTPILOT_API_KEY"
AGENTPILOT_WORKSPACE_ID = "AGENTPILOT_WORKSPACE_ID"
)
4 changes: 4 additions & 0 deletions common/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,7 @@ const (
DEFAULT_AGENTKIT_TOOL_REGION = "cn-beijing"
DEFAULT_AGENTKIT_TOOL_SERVICE_CODE = "agentkit"
)

const (
DEFAULT_AGENTPILOT_API_URL = "https://prompt-pilot.cn-beijing.volces.com"
)
1 change: 1 addition & 0 deletions configs/configs.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ func SetupVeADKConfig() error {
}
globalConfig.Model.MapEnvToConfig()
globalConfig.Tool.MapEnvToConfig()
globalConfig.PromptPilot.MapEnvToConfig()
globalConfig.LOGGING.MapEnvToConfig()
globalConfig.Database.MapEnvToConfig()
globalConfig.Volcengine.MapEnvToConfig()
Expand Down
15 changes: 14 additions & 1 deletion configs/prompt_pilot.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@

package configs

import (
"github.com/volcengine/veadk-go/common"
"github.com/volcengine/veadk-go/utils"
)

type PromptPilotConfig struct {
// 根据实际字段补充
Url string `yaml:"url"`
ApiKey string `yaml:"api_key"`
WorkspaceId string `yaml:"workspace_id"`
}

func (v *PromptPilotConfig) MapEnvToConfig() {
v.Url = utils.GetEnvWithDefault(common.AGENTPILOT_API_URL)
v.ApiKey = utils.GetEnvWithDefault(common.AGENTPILOT_API_KEY)
v.WorkspaceId = utils.GetEnvWithDefault(common.AGENTPILOT_WORKSPACE_ID)
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,4 @@ require (
gopkg.in/yaml.v2 v2.4.0 // indirect
rsc.io/omap v1.2.0 // indirect
rsc.io/ordered v1.1.1 // indirect
)
)
267 changes: 267 additions & 0 deletions integrations/ve_prompt_pilot/prompt_pilot.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ve_prompt_pilot

import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"iter"
"log"
"net/http"
"strings"
"time"

"github.com/google/uuid"
"github.com/volcengine/veadk-go/common"
"github.com/volcengine/veadk-go/configs"
"github.com/volcengine/veadk-go/prompts"
"github.com/volcengine/veadk-go/utils"
)

const (
defaultOptimizeModel = "doubao-seed-1.6-251015"
defaultHttpTimeout = 120
)

var (
ErrUrlValidationFailed = errors.New("AGENTPILOT_API_URL environment variable is not set")
ErrApiKeyValidationFailed = errors.New("AGENTPILOT_API_KEY environment variable is not set")
ErrWorkspaceIdValidationFailed = errors.New("AGENTPILOT_WORKSPACE_ID environment variable is not set")
)

// VePromptPilot handles prompt optimization interactions.
type VePromptPilot struct {
url string
apiKey string
workspaceID string
httpClient *http.Client
}

// New creates a new VePromptPilot instance.
func New(opts ...func(*VePromptPilot)) *VePromptPilot {
p := &VePromptPilot{
url: fmt.Sprintf("%s/agent-pilot?Version=2024-01-01&Action=GeneratePromptStream", utils.GetEnvWithDefault(common.AGENTPILOT_API_URL, configs.GetGlobalConfig().PromptPilot.Url, common.DEFAULT_AGENTPILOT_API_URL)),
apiKey: utils.GetEnvWithDefault(common.AGENTPILOT_API_KEY, configs.GetGlobalConfig().PromptPilot.ApiKey),
workspaceID: utils.GetEnvWithDefault(common.AGENTPILOT_WORKSPACE_ID, configs.GetGlobalConfig().PromptPilot.WorkspaceId),
httpClient: &http.Client{
Timeout: time.Second * defaultHttpTimeout,
},
}

for _, opt := range opts {
opt(p)
}
return p
}

// WithUrl sets the url for the pilot.
func WithUrl(url string) func(*VePromptPilot) {
return func(p *VePromptPilot) {
p.url = url
}
}

// WithAPIKey sets the API key for the pilot.
func WithAPIKey(apiKey string) func(*VePromptPilot) {
return func(p *VePromptPilot) {
p.apiKey = apiKey
}
}

// WithWorkspaceID sets the workspace ID for the pilot.
func WithWorkspaceID(workspaceID string) func(*VePromptPilot) {
return func(p *VePromptPilot) {
p.workspaceID = workspaceID
}
}

// WithHTTPClient sets the HTTP client for the pilot.
func WithHTTPClient(client *http.Client) func(*VePromptPilot) {
return func(p *VePromptPilot) {
p.httpClient = client
}
}

// generatePromptRequest represents the JSON body for the API request.
type generatePromptRequest struct {
RequestID string `json:"request_id"`
WorkspaceID string `json:"workspace_id"`
TaskType string `json:"task_type"`
Rule string `json:"rule"`
CurrentPrompt string `json:"current_prompt,omitempty"`
ModelName string `json:"model_name"`
Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p"`
}

func (p *VePromptPilot) Valid() error {
if p.url == "" {
return ErrUrlValidationFailed
}
if p.apiKey == "" {
return ErrApiKeyValidationFailed
}
if p.workspaceID == "" {
return ErrWorkspaceIdValidationFailed
}
return nil
}

// Optimize optimizes the prompts for the given agents using the specified feedback and model.
func (p *VePromptPilot) Optimize(agentInfo *prompts.AgentInfo, feedback string, modelName string) (string, error) {
if err := p.Valid(); err != nil {
return "", err
}

if modelName == "" {
modelName = defaultOptimizeModel
}
var finalPrompt string
var taskDescription string
var err error

if feedback == "" {
log.Println("Optimizing prompt without feedback.")
taskDescription, err = prompts.RenderPromptWithTemplate(agentInfo)
} else {
log.Printf("Optimizing prompt with feedback: %s\n", feedback)
taskDescription, err = prompts.RenderPromptFeedbackWithTemplate(agentInfo, feedback)
}

if err != nil {
return "", fmt.Errorf("rendering optimization task description: %w", err)
}

//TaskType Enum
//"DEFAULT" # single turn task
//"MULTIMODAL" # visual reasoning single turn task
//"DIALOG" # multi turn dialog
reqBody := &generatePromptRequest{
RequestID: uuid.New().String(),
WorkspaceID: p.workspaceID,
TaskType: "DIALOG",
Rule: taskDescription,
CurrentPrompt: agentInfo.Instruction,
ModelName: modelName,
Temperature: 1.0,
TopP: 0.7,
}

var builder strings.Builder
var usageTotal int
for event, err := range p.generateStream(context.Background(), reqBody) {
if err != nil {
return "", fmt.Errorf("generateStream error: %w", err)
}
if event.Event == "message" {
builder.WriteString(event.Data.Content)
} else if event.Event == "usage" {
usageTotal = event.Data.Usage.TotalTokens
} else {
eventStr, _ := json.Marshal(event)
log.Printf("Unexpected event: %s\n", string(eventStr))
}
}

finalPrompt = strings.ReplaceAll(builder.String(), "\\n", "\n")

log.Printf("Optimized prompt is -----\n%s\n-----\n", finalPrompt)

if usageTotal > 0 {
log.Printf("Token usage: %d", usageTotal)
} else {
log.Println("[Warn]No usage data.")
}

return finalPrompt, nil
}

func (p *VePromptPilot) sendRequest(ctx context.Context, reqBody *generatePromptRequest) (*http.Response, error) {
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}

req, err := http.NewRequestWithContext(ctx, "POST", p.url, bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}

req.Header.Set("Authorization", "Bearer "+p.apiKey)
req.Header.Set("Content-Type", "application/json")

httpResp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute request: %w", err)
}
if httpResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(httpResp.Body)
if err = httpResp.Body.Close(); err != nil {
return nil, fmt.Errorf("API failed to close response body: %w", err)
}
return nil, fmt.Errorf("API error (status %d): %s", httpResp.StatusCode, string(body))
}

return httpResp, nil
}

func (p *VePromptPilot) generateStream(ctx context.Context, req *generatePromptRequest) iter.Seq2[*GeneratePromptStreamResponseChunk, error] {
return func(yield func(*GeneratePromptStreamResponseChunk, error) bool) {
httpResp, err := p.sendRequest(ctx, req)
if err != nil {
yield(nil, err)
return
}
defer func() {
_ = httpResp.Body.Close()
}()

scanner := bufio.NewScanner(httpResp.Body)

var promptChunk *GeneratePromptStreamResponseChunk
for scanner.Scan() {
line := scanner.Text()
decodedLine := strings.TrimSpace(line)
promptChunk = parseEventStreamLine(decodedLine, promptChunk)
if promptChunk != nil {
hasContent := promptChunk.Data != nil && promptChunk.Data.Content != ""
hasUsage := promptChunk.Data != nil && promptChunk.Data.Usage != nil
hasError := promptChunk.Data != nil && promptChunk.Data.Error != ""

if hasContent || hasUsage {
yieldData := promptChunk
promptChunk = nil
yield(yieldData, nil)
continue
} else if hasError {
yield(nil, fmt.Errorf("prompt pilot generate error: %s", promptChunk.Data.Error))
continue
} else {
continue
}
}
}

if err := scanner.Err(); err != nil {
yield(nil, fmt.Errorf("stream error: %w", err))
return
}
}
}
Loading