From 76f0e8019ee619a056167b77db7cd9ae08a2ccbf Mon Sep 17 00:00:00 2001 From: "shijinyu.7" Date: Tue, 27 Jan 2026 11:12:14 +0800 Subject: [PATCH] feat: support plugins and llm shield --- .gitignore | 1 + apps/a2a_app/app.go | 1 + apps/agentkit_server_app/app.go | 1 + apps/basic_app.go | 2 + apps/simple_app/app.go | 1 + auth/veauth/utils.go | 20 ++ common/consts.go | 8 + common/defaults.go | 6 + configs/configs.go | 1 + configs/tool.go | 16 + examples/plugins/agent.go | 146 +++++++++ go.mod | 4 +- go.sum | 4 +- tool/builtin_tools/llm_shield.go | 407 ++++++++++++++++++++++++++ tool/builtin_tools/llm_shield_test.go | 46 +++ 15 files changed, 660 insertions(+), 4 deletions(-) create mode 100644 examples/plugins/agent.go create mode 100644 tool/builtin_tools/llm_shield.go create mode 100644 tool/builtin_tools/llm_shield_test.go diff --git a/.gitignore b/.gitignore index 315b71c..147b027 100644 --- a/.gitignore +++ b/.gitignore @@ -22,5 +22,6 @@ go.work .idea/ +.trae agent/.env examples/quickstart/config.yaml \ No newline at end of file diff --git a/apps/a2a_app/app.go b/apps/a2a_app/app.go index 1f17981..d2d914f 100644 --- a/apps/a2a_app/app.go +++ b/apps/a2a_app/app.go @@ -96,6 +96,7 @@ func (a *agentkitA2AServerApp) SetupRouters(router *mux.Router, config *apps.Run SessionService: config.SessionService, ArtifactService: config.ArtifactService, MemoryService: config.MemoryService, + PluginConfig: config.PluginConfig, }, }) reqHandler := a2asrv.NewHandler(executor, config.A2AOptions...) diff --git a/apps/agentkit_server_app/app.go b/apps/agentkit_server_app/app.go index 21f42bd..1ea1d5a 100644 --- a/apps/agentkit_server_app/app.go +++ b/apps/agentkit_server_app/app.go @@ -93,6 +93,7 @@ func (a *agentkitServerApp) SetupRouters(router *mux.Router, config *apps.RunCon MemoryService: config.MemoryService, AgentLoader: config.AgentLoader, A2AOptions: config.A2AOptions, + PluginConfig: config.PluginConfig, } // setup webui routers diff --git a/apps/basic_app.go b/apps/basic_app.go index 9bca649..2b520a8 100644 --- a/apps/basic_app.go +++ b/apps/basic_app.go @@ -24,6 +24,7 @@ import ( "google.golang.org/adk/agent" "google.golang.org/adk/artifact" "google.golang.org/adk/memory" + "google.golang.org/adk/runner" "google.golang.org/adk/session" ) @@ -33,6 +34,7 @@ type RunConfig struct { MemoryService memory.Service AgentLoader agent.Loader A2AOptions []a2asrv.RequestHandlerOption + PluginConfig runner.PluginConfig } type ApiConfig struct { diff --git a/apps/simple_app/app.go b/apps/simple_app/app.go index 341605f..c0fee59 100644 --- a/apps/simple_app/app.go +++ b/apps/simple_app/app.go @@ -78,6 +78,7 @@ func (app *agentkitSimpleApp) SetupRouters(router *mux.Router, config *apps.RunC SessionService: sessionService, ArtifactService: config.ArtifactService, MemoryService: config.MemoryService, + PluginConfig: config.PluginConfig, }) if err != nil { return fmt.Errorf("new runner error: %w", err) diff --git a/auth/veauth/utils.go b/auth/veauth/utils.go index a28e8a1..2988ff8 100644 --- a/auth/veauth/utils.go +++ b/auth/veauth/utils.go @@ -16,10 +16,13 @@ package veauth import ( "encoding/json" + "log" "os" "strings" "github.com/volcengine/veadk-go/common" + "github.com/volcengine/veadk-go/configs" + "github.com/volcengine/veadk-go/utils" ) type VeIAMCredential struct { @@ -50,3 +53,20 @@ func RefreshAKSK(accessKey string, secretKey string) (VeIAMCredential, error) { } return GetCredentialFromVeFaaSIAM() } + +func GetAuthInfo() (ak, sk, sessionToken string) { + ak = utils.GetEnvWithDefault(common.VOLCENGINE_ACCESS_KEY, configs.GetGlobalConfig().Volcengine.AK) + sk = utils.GetEnvWithDefault(common.VOLCENGINE_SECRET_KEY, configs.GetGlobalConfig().Volcengine.SK) + + if strings.TrimSpace(ak) == "" || strings.TrimSpace(sk) == "" { + iam, err := GetCredentialFromVeFaaSIAM() + if err != nil { + log.Printf("GetAuthInfo error: %s\n", err.Error()) + } else { + ak = iam.AccessKeyID + sk = iam.SecretAccessKey + sessionToken = iam.SessionToken + } + } + return +} diff --git a/common/consts.go b/common/consts.go index 9491285..1bb8a60 100644 --- a/common/consts.go +++ b/common/consts.go @@ -89,3 +89,11 @@ const ( AGENTPILOT_API_KEY = "AGENTPILOT_API_KEY" AGENTPILOT_WORKSPACE_ID = "AGENTPILOT_WORKSPACE_ID" ) + +// LLM Shield +const ( + TOOL_LLM_SHIELD_URL = "TOOL_LLM_SHIELD_URL" + TOOL_LLM_SHIELD_REGION = "TOOL_LLM_SHIELD_REGION" + TOOL_LLM_SHIELD_APP_ID = "TOOL_LLM_SHIELD_APP_ID" + TOOL_LLM_SHIELD_API_KEY = "TOOL_LLM_SHIELD_API_KEY" +) diff --git a/common/defaults.go b/common/defaults.go index 76a056e..c9fc529 100644 --- a/common/defaults.go +++ b/common/defaults.go @@ -69,6 +69,12 @@ const ( DEFAULT_AGENTKIT_TOOL_SERVICE_CODE = "agentkit" ) +// prompt pilot const ( DEFAULT_AGENTPILOT_API_URL = "https://prompt-pilot.cn-beijing.volces.com" ) + +// LLM Shield +const ( + DEFAULT_LLM_SHIELD_REGION = "cn-beijing" +) diff --git a/configs/configs.go b/configs/configs.go index 6d08b59..885c287 100644 --- a/configs/configs.go +++ b/configs/configs.go @@ -69,6 +69,7 @@ func SetupVeADKConfig() error { Tool: &BuiltinToolConfigs{ MCPRouter: &MCPRouter{}, RunCode: &RunCode{}, + LLMShield: &LLMShield{}, }, PromptPilot: &PromptPilotConfig{}, TlsConfig: &TLSConfig{}, diff --git a/configs/tool.go b/configs/tool.go index 91271cd..3f66335 100644 --- a/configs/tool.go +++ b/configs/tool.go @@ -22,11 +22,13 @@ import ( type BuiltinToolConfigs struct { MCPRouter *MCPRouter `yaml:"mcp_router"` RunCode *RunCode `yaml:"run_code"` + LLMShield *LLMShield `yaml:"llm_shield"` } func (b *BuiltinToolConfigs) MapEnvToConfig() { b.MCPRouter.MapEnvToConfig() b.RunCode.MapEnvToConfig() + b.LLMShield.MapEnvToConfig() } type MCPRouter struct { @@ -52,3 +54,17 @@ func (r *RunCode) MapEnvToConfig() { r.ServiceCode = utils.GetEnvWithDefault(common.AGENTKIT_TOOL_SERVICE_CODE) r.Region = utils.GetEnvWithDefault(common.AGENTKIT_TOOL_REGION) } + +type LLMShield struct { + Url string `yaml:"url"` + Region string `yaml:"region"` + AppId string `yaml:"app_id"` + ApiKey string `yaml:"api_key"` +} + +func (r *LLMShield) MapEnvToConfig() { + r.Url = utils.GetEnvWithDefault(common.TOOL_LLM_SHIELD_URL) + r.Region = utils.GetEnvWithDefault(common.TOOL_LLM_SHIELD_REGION) + r.ApiKey = utils.GetEnvWithDefault(common.TOOL_LLM_SHIELD_API_KEY) + r.AppId = utils.GetEnvWithDefault(common.TOOL_LLM_SHIELD_APP_ID) +} diff --git a/examples/plugins/agent.go b/examples/plugins/agent.go new file mode 100644 index 0000000..1a54ba3 --- /dev/null +++ b/examples/plugins/agent.go @@ -0,0 +1,146 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + + veagent "github.com/volcengine/veadk-go/agent/llmagent" + "github.com/volcengine/veadk-go/agent/workflowagents/sequentialagent" + "github.com/volcengine/veadk-go/apps" + "github.com/volcengine/veadk-go/apps/agentkit_server_app" + "github.com/volcengine/veadk-go/tool/builtin_tools" + "github.com/volcengine/veadk-go/tool/builtin_tools/web_search" + "github.com/volcengine/veadk-go/utils" + "google.golang.org/adk/agent" + "google.golang.org/adk/agent/llmagent" + "google.golang.org/adk/model" + "google.golang.org/adk/plugin" + "google.golang.org/adk/runner" + "google.golang.org/adk/session" + "google.golang.org/adk/tool" +) + +func main() { + ctx := context.Background() + + webSearch, err := web_search.NewWebSearchTool(&web_search.Config{}) + if err != nil { + fmt.Printf("NewWebSearchTool failed: %v", err) + return + } + + greetingAgent, err := veagent.New(&veagent.Config{ + Config: llmagent.Config{ + Name: "greeting_agent", + Description: "A friendly agent that greets the user.", + Instruction: "Greet the user warmly.", + Tools: []tool.Tool{ + webSearch, + }, + }, + ModelExtraConfig: map[string]any{ + "extra_body": map[string]any{ + "thinking": map[string]string{ + "type": "disabled", + }, + }, + }, + }) + if err != nil { + fmt.Printf("NewLLMAgent greetingAgent failed: %v", err) + return + } + + goodbyeAgent, err := veagent.New(&veagent.Config{ + Config: llmagent.Config{ + Name: "goodbye_agent", + Description: "A polite agent that says goodbye to the user.", + Instruction: "Directly return goodbye", + }, + ModelExtraConfig: map[string]any{ + "extra_body": map[string]any{ + "thinking": map[string]string{ + "type": "disabled", + }, + }, + }, + }) + if err != nil { + fmt.Printf("NewLLMAgent goodbyeAgent failed: %v", err) + return + } + + rootAgent, err := sequentialagent.New(sequentialagent.Config{ + AgentConfig: agent.Config{ + Name: "veAgent", + SubAgents: []agent.Agent{greetingAgent, goodbyeAgent}, + Description: "Executes a sequence of greeting and goodbye.", + }, + }) + + if err != nil { + fmt.Printf("NewSequentialAgent failed: %v", err) + return + } + + app := agentkit_server_app.NewAgentkitServerApp(apps.DefaultApiConfig()) + + err = app.Run(ctx, &apps.RunConfig{ + AgentLoader: agent.NewSingleLoader(rootAgent), + SessionService: session.InMemoryService(), + PluginConfig: runner.PluginConfig{ + Plugins: []*plugin.Plugin{ + //NewTestPlugins(), + utils.Must(builtin_tools.NewLLMShieldPlugins()), + }, + }, + }) + if err != nil { + fmt.Printf("Run failed: %v", err) + } +} + +func beforeModelCallBack(ctx agent.CallbackContext, llmRequest *model.LLMRequest) (*model.LLMResponse, error) { + fmt.Printf("%s BeforeModelCallBack called\n", ctx.AgentName()) + return nil, nil +} + +func afterModelCallBack(ctx agent.CallbackContext, llmResponse *model.LLMResponse, llmResponseError error) (*model.LLMResponse, error) { + fmt.Printf("%s afterModelCallback called\n", ctx.AgentName()) + return nil, nil +} + +func beforeToolCallback(ctx tool.Context, tool tool.Tool, args map[string]any) (map[string]any, error) { + fmt.Printf("%s beforeToolCallBack called\n", tool.Name()) + return nil, nil +} + +func afterToolCallback(ctx tool.Context, tool tool.Tool, args, result map[string]any, err error) (map[string]any, error) { + fmt.Printf("%s afterToolCallback called\n", tool.Name()) + return nil, nil +} + +func NewTestPlugins() *plugin.Plugin { + plugins, _ := plugin.New(plugin.Config{ + Name: "llm_shield_test", + BeforeModelCallback: beforeModelCallBack, + AfterModelCallback: afterModelCallBack, + BeforeToolCallback: beforeToolCallback, + AfterToolCallback: afterToolCallback, + }) + return plugins +} diff --git a/go.mod b/go.mod index 5a4994d..d5dcd9f 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/volcengine/volcengine-go-sdk v1.1.53 go.uber.org/zap v1.27.1 golang.org/x/oauth2 v0.32.0 - google.golang.org/adk v0.3.1-0.20260113130926-012d380e4056 + google.golang.org/adk v0.3.1-0.20260123125504-aec89487da29 google.golang.org/genai v1.40.0 gopkg.in/go-playground/validator.v8 v8.18.2 gopkg.in/yaml.v3 v3.0.1 @@ -81,4 +81,4 @@ require ( gopkg.in/yaml.v2 v2.4.0 // indirect rsc.io/omap v1.2.0 // indirect rsc.io/ordered v1.1.1 // indirect -) \ No newline at end of file +) diff --git a/go.sum b/go.sum index 85c1471..465e1c7 100644 --- a/go.sum +++ b/go.sum @@ -220,8 +220,8 @@ golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -google.golang.org/adk v0.3.1-0.20260113130926-012d380e4056 h1:SfDR+nm6VdVjSsvKbHGAV/B6csmMge24nqQBSbJi5Go= -google.golang.org/adk v0.3.1-0.20260113130926-012d380e4056/go.mod h1:iE1Kgc8JtYHiNxfdLa9dxcV4DqTn0D8q4eqhBi012Ak= +google.golang.org/adk v0.3.1-0.20260123125504-aec89487da29 h1:xYzqSSpn/WhNVyiieXnvYLGKPRJk4cJFgImrhsA9Pi4= +google.golang.org/adk v0.3.1-0.20260123125504-aec89487da29/go.mod h1:jVeb7Ir53+3XKTncdY7k3pVdPneKcm5+60sXpxHQnao= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genai v1.40.0 h1:kYxyQSH+vsib8dvsgyLJzsVEIv5k3ZmHJyVqdvGncmc= diff --git a/tool/builtin_tools/llm_shield.go b/tool/builtin_tools/llm_shield.go new file mode 100644 index 0000000..c955450 --- /dev/null +++ b/tool/builtin_tools/llm_shield.go @@ -0,0 +1,407 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package builtin_tools + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strings" + "time" + + "github.com/volcengine/veadk-go/auth/veauth" + "github.com/volcengine/veadk-go/common" + "github.com/volcengine/veadk-go/configs" + "github.com/volcengine/veadk-go/integrations/ve_sign" + "github.com/volcengine/veadk-go/utils" + "google.golang.org/adk/agent" + "google.golang.org/adk/model" + "google.golang.org/adk/plugin" + "google.golang.org/adk/tool" + "google.golang.org/genai" +) + +const ( + name = "LLMShield" + path = "/v2/moderate" + service = "llmshield" + action = "Moderate" + version = "2025-08-31" + defaultTimeout = 60 +) + +var ( + ErrInvalidAppID = errors.New("LLM Shield App ID is not configured. Please configure it via environment TOOL_LLM_SHIELD_APP_ID") + ErrInvalidApiKey = errors.New("LLM Shield auth invalid, Please configure it via environment TOOL_LLM_SHIELD_API_KEY or VOLCENGINE_ACCESS_KEY and VOLCENGINE_SECRET_KEY") +) + +var CategoryMap = map[string]string{ + "101": "Model Misuse", + "103": "Sensitive Information", + "104": "Prompt Injection", + "106": "General Topic Control", + "107": "Computational Resource Consumption", +} + +type LLMShieldClient struct { + URL string + Region string + AppID string + APIKey string + Timeout int +} + +type LLMShieldResult struct { + ResponseMetadata *ResponseMetadata `json:"ResponseMetadata"` + Result *LLMShieldData `json:"Result"` +} +type ResponseMetadata struct { + RequestID string `json:"RequestId"` + Service string `json:"Service"` + Region string `json:"Region"` + Action string `json:"Action"` + Version string `json:"Version"` +} +type Matches struct { + Word string `json:"Word"` + Source int `json:"Source"` +} +type Risks struct { + Category string `json:"Category"` + Label string `json:"Label"` + Prob float64 `json:"Prob,omitempty"` + Matches []*Matches `json:"Matches,omitempty"` +} +type RiskInfo struct { + Risks []*Risks `json:"Risks"` +} + +type ReplaceDetail struct { + Replacement interface{} `json:"Replacement"` +} +type DecisionDetail struct { + BlockDetail map[string]interface{} `json:"BlockDetail"` + ReplaceDetail *ReplaceDetail `json:"ReplaceDetail"` +} +type Decision struct { + DecisionType int `json:"DecisionType"` + DecisionDetail *DecisionDetail `json:"DecisionDetail"` + HitStrategyIDs []string `json:"HitStrategyIDs"` +} +type PermitInfo struct { + Permits interface{} `json:"Permits"` +} +type LLMShieldData struct { + MsgID string `json:"MsgID"` + RiskInfo *RiskInfo `json:"RiskInfo"` + Decision *Decision `json:"Decision"` + PermitInfo *PermitInfo `json:"PermitInfo"` + ContentInfo string `json:"ContentInfo"` + Degraded bool `json:"Degraded"` + DegradeReason string `json:"DegradeReason"` +} + +func NewLLMShieldClient(timeout int) (*LLMShieldClient, error) { + if timeout <= 0 { + timeout = defaultTimeout + } + region := utils.GetEnvWithDefault(common.TOOL_LLM_SHIELD_REGION, configs.GetGlobalConfig().Tool.LLMShield.Region, common.DEFAULT_LLM_SHIELD_REGION) + shieldURL := utils.GetEnvWithDefault(common.TOOL_LLM_SHIELD_URL, configs.GetGlobalConfig().Tool.LLMShield.Url, fmt.Sprintf("https://%s.sdk.access.llm-shield.omini-shield.com", region)) + appId := utils.GetEnvWithDefault(common.TOOL_LLM_SHIELD_APP_ID, configs.GetGlobalConfig().Tool.LLMShield.AppId) + if strings.TrimSpace(appId) == "" { + return nil, ErrInvalidAppID + } + apiKey := utils.GetEnvWithDefault(common.TOOL_LLM_SHIELD_API_KEY, configs.GetGlobalConfig().Tool.LLMShield.ApiKey) + if strings.TrimSpace(apiKey) == "" { + ak, sk, _ := veauth.GetAuthInfo() + if strings.TrimSpace(ak) == "" || strings.TrimSpace(sk) == "" { + return nil, ErrInvalidApiKey + } + } + return &LLMShieldClient{ + AppID: appId, + APIKey: apiKey, + Region: region, + URL: shieldURL, + Timeout: timeout, + }, nil +} + +// requestLLMShield 向 LLM Shield 服务发送请求进行内容审核 +func (p *LLMShieldClient) requestLLMShield(message string, role string) (string, error) { + + body := map[string]interface{}{ + "Message": map[string]interface{}{ + "Role": role, + "Content": message, + "ContentType": 1, + }, + "Scene": p.AppID, + } + + var respBody []byte + + if p.APIKey != "" { + bodyBytes, _ := json.Marshal(body) + req, err := http.NewRequest("POST", p.URL+path, strings.NewReader(string(bodyBytes))) + if err != nil { + return "", err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-api-key", p.APIKey) + + q := req.URL.Query() + q.Add("Action", action) + q.Add("Version", version) + req.URL.RawQuery = q.Encode() + + client := &http.Client{Timeout: time.Duration(p.Timeout) * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != 200 { + return "", fmt.Errorf("LLM Shield HTTP error: %d", resp.StatusCode) + } + respBody, _ = io.ReadAll(resp.Body) + + } else { + ak, sk, sessionToken := veauth.GetAuthInfo() + if strings.TrimSpace(ak) == "" || strings.TrimSpace(sk) == "" { + return "", ErrInvalidApiKey + } + + header := map[string]string{ + "X-Top-Service": service, + "X-Top-Region": p.Region, + } + + if strings.TrimSpace(sessionToken) != "" { + header["X-Session-Token"] = sessionToken + } + + parsedURL, err := url.Parse(p.URL) + if err != nil { + return "", fmt.Errorf("invalid URL: %v", err) + } + + veReq := ve_sign.VeRequest{ + AK: ak, + SK: sk, + Method: "POST", + Scheme: parsedURL.Scheme, + Host: parsedURL.Host, + Path: path, + Service: service, + Region: p.Region, + Action: action, + Version: version, + Body: body, + Timeout: uint(p.Timeout), + Header: header, + } + + respBody, err = veReq.DoRequest() + if err != nil { + return "", fmt.Errorf("LLM Shield request failed: %v", err) + } + } + // 解析响应 + var response LLMShieldResult + + if err := json.Unmarshal(respBody, &response); err != nil { + return "", fmt.Errorf("JSON decode failed: %v", err) + } + + if response.Result != nil && response.Result.Decision != nil { + + if response.Result.Decision.DecisionType == 2 && response.Result.RiskInfo != nil { + risks := response.Result.RiskInfo.Risks + if len(risks) > 0 { + var riskReasons []string + seen := make(map[string]bool) + + for _, risk := range risks { + catName, ok := CategoryMap[risk.Category] + if !ok { + catName = fmt.Sprintf("Category %s", risk.Category) + } + if !seen[catName] { + riskReasons = append(riskReasons, catName) + seen[catName] = true + } + } + + reasonText := "security policy violation" + if len(riskReasons) > 0 { + reasonText = strings.Join(riskReasons, ", ") + } + + return fmt.Sprintf("Your request has been blocked due to: %s. Please modify your input and try again.", reasonText), nil + } + } + } + + return "", nil +} + +// -------------------- Callbacks -------------------- + +func NewLLMShieldPlugins() (*plugin.Plugin, error) { + c, err := NewLLMShieldClient(defaultTimeout) + if err != nil { + return nil, err + } + plugins, _ := plugin.New(plugin.Config{ + Name: "llm_shield", + BeforeModelCallback: c.beforeModelCallBack, + AfterModelCallback: c.afterModelCallBack, + BeforeToolCallback: c.beforeToolCallback, + AfterToolCallback: c.afterToolCallback, + }) + return plugins, nil +} + +// BeforeModelCallback 在发送给模型前检查用户输入 +func (p *LLMShieldClient) beforeModelCallBack(ctx agent.CallbackContext, req *model.LLMRequest) (*model.LLMResponse, error) { + var lastUserMessage string + var messageBuilder strings.Builder + + if len(req.Contents) > 0 { + lastContent := req.Contents[len(req.Contents)-1] + if lastContent.Role == "user" && len(lastContent.Parts) > 0 { + for _, part := range lastContent.Parts { + messageBuilder.WriteString(part.Text) + } + } + } + + lastUserMessage = messageBuilder.String() + if lastUserMessage == "" { + return nil, nil + } + + log.Printf("agent %s beforeModelCallBack lastUserMessage is %s\n", ctx.AgentName(), lastUserMessage) + + blockMsg, err := p.requestLLMShield(lastUserMessage, "user") + if err != nil { + log.Printf("LLM Shield beforeModelCallBack error: %v\n", err) + return nil, nil + } + + if blockMsg != "" { + return &model.LLMResponse{ + Content: &genai.Content{ + Role: "model", + Parts: []*genai.Part{ + {Text: blockMsg}, + }, + }, + Partial: false, + FinishReason: "STOP", + }, nil + } + + return nil, nil +} + +// AfterModelCallback 在返回给用户前检查模型输出 +func (p *LLMShieldClient) afterModelCallBack(ctx agent.CallbackContext, resp *model.LLMResponse, llmResponseError error) (*model.LLMResponse, error) { + var lastModelMessage string + if resp.Content.Role == "model" && len(resp.Content.Parts) > 0 { + lastModelMessage = resp.Content.Parts[0].Text + } + + if lastModelMessage == "" { + return nil, nil + } + + log.Printf("agent %s afterModelCallBack lastUserMessage is %s\n", ctx.AgentName(), lastModelMessage) + + blockMsg, err := p.requestLLMShield(lastModelMessage, "assistant") + if err != nil { + log.Printf("LLM Shield afterModelCallBack error: %v\n", err) + return nil, nil + } + + log.Printf("agent %s beforeModelCallBack blockMsg is %s\n", ctx.AgentName(), blockMsg) + + if blockMsg != "" { + return &model.LLMResponse{ + Content: &genai.Content{ + Role: "model", + Parts: []*genai.Part{ + {Text: blockMsg}, + }, + }, + Partial: false, + FinishReason: "STOP", + }, nil + } + + return nil, nil +} + +// BeforeToolCallback 在工具执行前检查参数 +func (p *LLMShieldClient) beforeToolCallback(ctx tool.Context, tool tool.Tool, args map[string]any) (map[string]any, error) { + var argsList []string + for k, v := range args { + argsList = append(argsList, fmt.Sprintf("%s: %v", k, v)) + } + message := strings.Join(argsList, "\n") + + blockMsg, err := p.requestLLMShield(message, "user") + if err != nil { + log.Printf("LLM Shield beforeToolCallback error: %v\n", err) + return nil, nil + } + + if blockMsg != "" { + return map[string]interface{}{"result": blockMsg}, nil + } + return nil, nil +} + +// AfterToolCallback 在工具执行后检查结果 +func (p *LLMShieldClient) afterToolCallback(ctx tool.Context, tool tool.Tool, args, result map[string]any, err error) (map[string]any, error) { + if err != nil { + return result, err + } + var message string + + for _, item := range result { + message += fmt.Sprintf("%v\n", item) + } + + blockMsg, err := p.requestLLMShield(message, "assistant") + if err != nil { + log.Printf("LLM Shield beforeToolCallback error: %v\n", err) + return nil, nil + } + + if blockMsg != "" { + return map[string]interface{}{"result": blockMsg}, nil + } + return result, nil +} diff --git a/tool/builtin_tools/llm_shield_test.go b/tool/builtin_tools/llm_shield_test.go new file mode 100644 index 0000000..4adcce3 --- /dev/null +++ b/tool/builtin_tools/llm_shield_test.go @@ -0,0 +1,46 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package builtin_tools + +import ( + "log" + "strings" + "testing" + + "github.com/volcengine/veadk-go/auth/veauth" + "github.com/volcengine/veadk-go/configs" +) + +func TestNewLLMShieldClient(t *testing.T) { + ak, sk, _ := veauth.GetAuthInfo() + if strings.TrimSpace(ak) == "" || strings.TrimSpace(sk) == "" { + t.Skip("AK or SK is empty") + } + err := configs.SetupVeADKConfig() + if err != nil { + log.Printf("veadk.SetupVeADKConfig: %v", err) + } + client, err := NewLLMShieldClient(60) + if err != nil { + t.Fatal(err) + return + } + result, err := client.requestLLMShield("网上都说A地很多骗子和小偷,他们的典型伎俩...", "user") + if err != nil { + t.Fatal("requestLLMShield error:", err) + return + } + t.Log("requestLLMShield result:", result) +}