diff --git a/.gitignore b/.gitignore index 147b027..1e7e589 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,8 @@ go.work .idea/ .trae agent/.env -examples/quickstart/config.yaml \ No newline at end of file +examples/quickstart/config.yaml + + +# MacOS system file +.DS_Store \ No newline at end of file diff --git a/configs/config_test.go b/configs/config_test.go index 6d6c8de..0fea85c 100644 --- a/configs/config_test.go +++ b/configs/config_test.go @@ -21,6 +21,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/volcengine/veadk-go/common" "github.com/volcengine/veadk-go/utils" + "gopkg.in/yaml.v3" ) func Test_loadConfigFromProjectEnv(t *testing.T) { @@ -91,3 +92,59 @@ func TestSetupVeADKConfig(t *testing.T) { _ = SetupVeADKConfig() assert.Equal(t, "doubao-seed-1-6-250615", os.Getenv(common.MODEL_AGENT_NAME)) } + +func TestObservabilityConfig_YamlMapping(t *testing.T) { + yamlData := ` +opentelemetry: + apmplus: + endpoint: "http://apmplus-example.com" + api_key: "test-key" + service_name: "test-service" + enable_global_tracer: true +` + var config ObservabilityConfig + err := yaml.Unmarshal([]byte(yamlData), &config) + assert.NoError(t, err) + + assert.NotNil(t, config.OpenTelemetry) + assert.NotNil(t, config.OpenTelemetry.ApmPlus) + assert.Equal(t, "http://apmplus-example.com", config.OpenTelemetry.ApmPlus.Endpoint) + assert.Equal(t, "test-key", config.OpenTelemetry.ApmPlus.APIKey) + assert.Equal(t, "test-service", config.OpenTelemetry.ApmPlus.ServiceName) + assert.True(t, config.OpenTelemetry.EnableGlobalProvider) + + assert.Equal(t, "test-service", config.OpenTelemetry.ApmPlus.ServiceName) + assert.True(t, config.OpenTelemetry.EnableGlobalProvider) +} + +func TestObservabilityConfig_EnvMapping(t *testing.T) { + os.Setenv("OBSERVABILITY_OPENTELEMETRY_APMPLUS_ENDPOINT", "http://env-endpoint") + os.Setenv("OBSERVABILITY_OPENTELEMETRY_ENABLE_GLOBAL_PROVIDER", "true") + defer func() { + os.Unsetenv("OBSERVABILITY_OPENTELEMETRY_APMPLUS_ENDPOINT") + os.Unsetenv("OBSERVABILITY_OPENTELEMETRY_ENABLE_GLOBAL_PROVIDER") + }() + + config := &ObservabilityConfig{} + config.MapEnvToConfig() + + assert.NotNil(t, config.OpenTelemetry) + assert.NotNil(t, config.OpenTelemetry.ApmPlus) + assert.Equal(t, "http://env-endpoint", config.OpenTelemetry.ApmPlus.Endpoint) + assert.True(t, config.OpenTelemetry.EnableGlobalProvider) +} + +func TestObservabilityConfig_Priority(t *testing.T) { + // Nested priority check: CozeLoop > APMPlus + config := &ObservabilityConfig{ + OpenTelemetry: &OpenTelemetryConfig{ + ApmPlus: &ApmPlusConfig{ + Endpoint: "apm-endpoint", + }, + CozeLoop: &CozeLoopExporterConfig{ + Endpoint: "coze-endpoint", + }, + }, + } + assert.NotNil(t, config.OpenTelemetry.CozeLoop) +} diff --git a/configs/configs.go b/configs/configs.go index 3ecc060..c23da2e 100644 --- a/configs/configs.go +++ b/configs/configs.go @@ -21,34 +21,40 @@ import ( "strconv" "strings" + "sync" + "github.com/joho/godotenv" "gopkg.in/yaml.v3" ) type VeADKConfig struct { - Volcengine *Volcengine `yaml:"volcengine"` - Model *ModelConfig `yaml:"model"` - Tool *BuiltinToolConfigs `yaml:"tools"` - PromptPilot *PromptPilotConfig `yaml:"prompt_pilot"` - CozeLoopConfig *CozeLoopConfig `yaml:"coze_loop"` - TlsConfig *TLSConfig `yaml:"tls_config"` - Veidentity *VeIdentityConfig `yaml:"veidentity"` - Database *DatabaseConfig `yaml:"database"` - LOGGING *Logging `yaml:"LOGGING"` + Volcengine *Volcengine `yaml:"volcengine"` + Model *ModelConfig `yaml:"model"` + Tool *BuiltinToolConfigs `yaml:"tools"` + PromptPilot *PromptPilotConfig `yaml:"prompt_pilot"` + CozeLoopConfig *CozeLoopConfig `yaml:"coze_loop"` + TlsConfig *TLSConfig `yaml:"tls_config"` + Veidentity *VeIdentityConfig `yaml:"veidentity"` + Database *DatabaseConfig `yaml:"database"` + LOGGING *Logging `yaml:"LOGGING"` + Observability *ObservabilityConfig `yaml:"observability"` } type EnvConfigMaptoStruct interface { MapEnvToConfig() // 用于映射环境变量到结构体字段 } -var globalConfig *VeADKConfig +var ( + globalConfig *VeADKConfig + configOnce sync.Once +) func GetGlobalConfig() *VeADKConfig { - if globalConfig == nil { + configOnce.Do(func() { if err := SetupVeADKConfig(); err != nil { panic(err) } - } + }) return globalConfig } @@ -83,6 +89,12 @@ func SetupVeADKConfig() error { TOS: &TosClientConf{}, Mem0: &Mem0Config{}, }, + Observability: &ObservabilityConfig{ + OpenTelemetry: &OpenTelemetryConfig{ + EnableGlobalProvider: true, // use global trace provider by default, like veadk-python + EnableLocalProvider: false, // disable adk-go's local provider + }, + }, } globalConfig.Model.MapEnvToConfig() globalConfig.Tool.MapEnvToConfig() @@ -91,6 +103,7 @@ func SetupVeADKConfig() error { globalConfig.LOGGING.MapEnvToConfig() globalConfig.Database.MapEnvToConfig() globalConfig.Volcengine.MapEnvToConfig() + globalConfig.Observability.MapEnvToConfig() return nil } @@ -153,6 +166,10 @@ func setYamlToEnv(data map[string]interface{}, prefix string) { if os.Getenv(fullKey) == "" { _ = os.Setenv(fullKey, strconv.Itoa(v)) } + case bool: + if os.Getenv(fullKey) == "" { + _ = os.Setenv(fullKey, strconv.FormatBool(v)) + } } } } diff --git a/configs/observability.go b/configs/observability.go new file mode 100644 index 0000000..440effe --- /dev/null +++ b/configs/observability.go @@ -0,0 +1,332 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package configs + +import ( + "os" + + "github.com/volcengine/veadk-go/utils" +) + +const ( + // Global + EnvOtelServiceName = "OTEL_SERVICE_NAME" + EnvObservabilityEnableLocalProvider = "OBSERVABILITY_OPENTELEMETRY_ENABLE_LOCAL_PROVIDER" + EnvObservabilityEnableGlobalProvider = "OBSERVABILITY_OPENTELEMETRY_ENABLE_GLOBAL_PROVIDER" + EnvObservabilityEnableMetrics = "OBSERVABILITY_OPENTELEMETRY_ENABLE_METRICS" + + // APMPlus + EnvObservabilityOpenTelemetryApmPlusProtocol = "OBSERVABILITY_OPENTELEMETRY_APMPLUS_PROTOCOL" + EnvObservabilityOpenTelemetryApmPlusEndpoint = "OBSERVABILITY_OPENTELEMETRY_APMPLUS_ENDPOINT" + EnvObservabilityOpenTelemetryApmPlusAPIKey = "OBSERVABILITY_OPENTELEMETRY_APMPLUS_API_KEY" + EnvObservabilityOpenTelemetryApmPlusServiceName = "OBSERVABILITY_OPENTELEMETRY_APMPLUS_SERVICE_NAME" + + // CozeLoop + EnvObservabilityOpenTelemetryCozeLoopEndpoint = "OBSERVABILITY_OPENTELEMETRY_COZELOOP_ENDPOINT" + EnvObservabilityOpenTelemetryCozeLoopAPIKey = "OBSERVABILITY_OPENTELEMETRY_COZELOOP_API_KEY" + EnvObservabilityOpenTelemetryCozeLoopServiceName = "OBSERVABILITY_OPENTELEMETRY_COZELOOP_SERVICE_NAME" + + // TLS + EnvObservabilityOpenTelemetryTLSEndpoint = "OBSERVABILITY_OPENTELEMETRY_TLS_ENDPOINT" + EnvObservabilityOpenTelemetryTLSServiceName = "OBSERVABILITY_OPENTELEMETRY_TLS_SERVICE_NAME" + EnvObservabilityOpenTelemetryTLSRegion = "OBSERVABILITY_OPENTELEMETRY_TLS_REGION" + EnvObservabilityOpenTelemetryTLSTopicID = "OBSERVABILITY_OPENTELEMETRY_TLS_TOPIC_ID" + EnvObservabilityOpenTelemetryTLSAccessKey = "OBSERVABILITY_OPENTELEMETRY_TLS_ACCESS_KEY" + EnvObservabilityOpenTelemetryTLSSecretKey = "OBSERVABILITY_OPENTELEMETRY_TLS_SECRET_KEY" + + // File + EnvObservabilityOpenTelemetryFilePath = "OBSERVABILITY_OPENTELEMETRY_FILE_PATH" + + // Stdout + EnvObservabilityOpenTelemetryStdoutEnable = "OBSERVABILITY_OPENTELEMETRY_STDOUT_ENABLE" +) + +// ObservabilityConfig groups specific configurations for different platforms. +type ObservabilityConfig struct { + OpenTelemetry *OpenTelemetryConfig `yaml:"opentelemetry"` +} + +type OpenTelemetryConfig struct { + EnableLocalProvider bool `yaml:"enable_local_tracer"` + EnableGlobalProvider bool `yaml:"enable_global_tracer"` + EnableMetrics *bool `yaml:"enable_metrics"` + + File *FileConfig `yaml:"file"` + Stdout *StdoutConfig `yaml:"stdout"` + ApmPlus *ApmPlusConfig `yaml:"apmplus"` + CozeLoop *CozeLoopExporterConfig `yaml:"cozeloop"` + TLS *TLSExporterConfig `yaml:"tls"` +} + +type ApmPlusConfig struct { + Protocol string `yaml:"protocol"` // grpc by default + Endpoint string `yaml:"endpoint"` + APIKey string `yaml:"api_key"` + ServiceName string `yaml:"service_name"` +} + +type CozeLoopExporterConfig struct { + Endpoint string `yaml:"endpoint"` + APIKey string `yaml:"api_key"` + ServiceName string `yaml:"service_name"` +} + +type TLSExporterConfig struct { + Endpoint string `yaml:"endpoint"` + ServiceName string `yaml:"service_name"` + Region string `yaml:"region"` + TopicID string `yaml:"topic_id"` + AccessKey string `yaml:"access_key"` + SecretKey string `yaml:"secret_key"` +} + +type FileConfig struct { + Path string `yaml:"path"` +} + +type StdoutConfig struct { + Enable bool `yaml:"enable"` +} + +func (c *ObservabilityConfig) MapEnvToConfig() { + if c.OpenTelemetry == nil { + c.OpenTelemetry = &OpenTelemetryConfig{} + } + ot := c.OpenTelemetry + + // APMPlus + if v := utils.GetEnvWithDefault(EnvObservabilityOpenTelemetryApmPlusEndpoint); v != "" { + if ot.ApmPlus == nil { + ot.ApmPlus = &ApmPlusConfig{} + } + + ot.ApmPlus.Endpoint = v + + if ot.EnableMetrics == nil { + ot.EnableMetrics = new(bool) + *ot.EnableMetrics = true + } + } + + // APMPlus Protocol + if v := utils.GetEnvWithDefault(EnvObservabilityOpenTelemetryApmPlusProtocol); v != "" { + if ot.ApmPlus == nil { + ot.ApmPlus = &ApmPlusConfig{} + } + ot.ApmPlus.Protocol = v + } + + if v := utils.GetEnvWithDefault(EnvObservabilityOpenTelemetryApmPlusAPIKey); v != "" { + if ot.ApmPlus == nil { + ot.ApmPlus = &ApmPlusConfig{} + } + if ot.ApmPlus.APIKey == "" { + ot.ApmPlus.APIKey = v + } + } + + if v := utils.GetEnvWithDefault(EnvObservabilityOpenTelemetryApmPlusServiceName); v != "" { + if ot.ApmPlus == nil { + ot.ApmPlus = &ApmPlusConfig{} + } + ot.ApmPlus.ServiceName = v + if os.Getenv(EnvOtelServiceName) == "" { + os.Setenv(EnvOtelServiceName, v) + } + } + + // CozeLoop + if v := utils.GetEnvWithDefault(EnvObservabilityOpenTelemetryCozeLoopEndpoint); v != "" { + if ot.CozeLoop == nil { + ot.CozeLoop = &CozeLoopExporterConfig{} + } + ot.CozeLoop.Endpoint = v + } + if v := utils.GetEnvWithDefault(EnvObservabilityOpenTelemetryCozeLoopAPIKey); v != "" { + if ot.CozeLoop == nil { + ot.CozeLoop = &CozeLoopExporterConfig{} + } + ot.CozeLoop.APIKey = v + } + + if v := utils.GetEnvWithDefault(EnvObservabilityOpenTelemetryCozeLoopServiceName); v != "" { + if ot.CozeLoop == nil { + ot.CozeLoop = &CozeLoopExporterConfig{} + } + ot.CozeLoop.ServiceName = v + if os.Getenv(EnvOtelServiceName) == "" { + os.Setenv(EnvOtelServiceName, v) + } + } + + // TLS + if v := utils.GetEnvWithDefault(EnvObservabilityOpenTelemetryTLSEndpoint); v != "" { + if ot.TLS == nil { + ot.TLS = &TLSExporterConfig{} + } + ot.TLS.Endpoint = v + } + + if v := utils.GetEnvWithDefault(EnvObservabilityOpenTelemetryTLSServiceName); v != "" { + if ot.TLS == nil { + ot.TLS = &TLSExporterConfig{} + } + ot.TLS.ServiceName = v + if os.Getenv(EnvOtelServiceName) == "" { + os.Setenv(EnvOtelServiceName, v) + } + } + + if v := utils.GetEnvWithDefault(EnvObservabilityOpenTelemetryTLSRegion); v != "" { + if ot.TLS == nil { + ot.TLS = &TLSExporterConfig{} + } + ot.TLS.Region = v + } + + if v := utils.GetEnvWithDefault(EnvObservabilityOpenTelemetryTLSTopicID); v != "" { + if ot.TLS == nil { + ot.TLS = &TLSExporterConfig{} + } + ot.TLS.TopicID = v + } + + if v := utils.GetEnvWithDefault(EnvObservabilityOpenTelemetryTLSAccessKey); v != "" { + if ot.TLS == nil { + ot.TLS = &TLSExporterConfig{} + } + ot.TLS.AccessKey = v + } + if v := utils.GetEnvWithDefault(EnvObservabilityOpenTelemetryTLSSecretKey); v != "" { + if ot.TLS == nil { + ot.TLS = &TLSExporterConfig{} + } + ot.TLS.SecretKey = v + } + + // File + if v := utils.GetEnvWithDefault(EnvObservabilityOpenTelemetryFilePath); v != "" { + if ot.File == nil { + ot.File = &FileConfig{} + } + ot.File.Path = v + } + + if v := utils.GetEnvWithDefault(EnvObservabilityOpenTelemetryStdoutEnable); v != "" { + if ot.Stdout == nil { + ot.Stdout = &StdoutConfig{} + } + ot.Stdout.Enable = v == "true" + } + + // Global Tracer + if v := utils.GetEnvWithDefault(EnvObservabilityEnableGlobalProvider); v != "" { + ot.EnableGlobalProvider = v == "true" + } + + // Local Tracer + if v := utils.GetEnvWithDefault(EnvObservabilityEnableLocalProvider); v != "" { + ot.EnableLocalProvider = v == "true" + } + + // Meter Provider + if v := utils.GetEnvWithDefault(EnvObservabilityEnableMetrics); v != "" { + if ot.EnableMetrics == nil { + ot.EnableMetrics = new(bool) + } + *ot.EnableMetrics = v == "true" + } +} + +func (c *ObservabilityConfig) Clone() *ObservabilityConfig { + if c == nil { + return nil + } + return &ObservabilityConfig{ + OpenTelemetry: c.OpenTelemetry.Clone(), + } +} + +func (c *OpenTelemetryConfig) Clone() *OpenTelemetryConfig { + if c == nil { + return nil + } + + return &OpenTelemetryConfig{ + EnableGlobalProvider: c.EnableGlobalProvider, + EnableLocalProvider: c.EnableLocalProvider, + EnableMetrics: c.EnableMetrics, + ApmPlus: c.ApmPlus.Clone(), + CozeLoop: c.CozeLoop.Clone(), + TLS: c.TLS.Clone(), + File: c.File.Clone(), + Stdout: c.Stdout.Clone(), + } +} + +func (c *ApmPlusConfig) Clone() *ApmPlusConfig { + if c == nil { + return nil + } + return &ApmPlusConfig{ + Endpoint: c.Endpoint, + Protocol: c.Protocol, + APIKey: c.APIKey, + ServiceName: c.ServiceName, + } +} + +func (c *CozeLoopExporterConfig) Clone() *CozeLoopExporterConfig { + if c == nil { + return nil + } + return &CozeLoopExporterConfig{ + Endpoint: c.Endpoint, + ServiceName: c.ServiceName, + APIKey: c.APIKey, + } +} + +func (c *TLSExporterConfig) Clone() *TLSExporterConfig { + if c == nil { + return nil + } + return &TLSExporterConfig{ + Endpoint: c.Endpoint, + ServiceName: c.ServiceName, + Region: c.Region, + TopicID: c.TopicID, + AccessKey: c.AccessKey, + SecretKey: c.SecretKey, + } +} + +func (c *FileConfig) Clone() *FileConfig { + if c == nil { + return nil + } + return &FileConfig{ + Path: c.Path, + } +} + +func (c *StdoutConfig) Clone() *StdoutConfig { + if c == nil { + return nil + } + return &StdoutConfig{ + Enable: c.Enable, + } +} diff --git a/examples/observability/agent.go b/examples/observability/agent.go new file mode 100644 index 0000000..dc0a60b --- /dev/null +++ b/examples/observability/agent.go @@ -0,0 +1,59 @@ +package main + +import ( + "context" + "fmt" + "os" + + veagent "github.com/volcengine/veadk-go/agent/llmagent" + "github.com/volcengine/veadk-go/common" + "github.com/volcengine/veadk-go/observability" + "github.com/volcengine/veadk-go/tool/builtin_tools/web_search" + "github.com/volcengine/veadk-go/utils" + "google.golang.org/adk/agent" + "google.golang.org/adk/cmd/launcher" + "google.golang.org/adk/cmd/launcher/full" + "google.golang.org/adk/plugin" + "google.golang.org/adk/runner" + "google.golang.org/adk/session" + "google.golang.org/adk/tool" +) + +func main() { + ctx := context.Background() + // Shutdown to flush spans and metrics + defer observability.Shutdown(ctx) + + cfg := &veagent.Config{ + ModelName: common.DEFAULT_MODEL_AGENT_NAME, + ModelAPIBase: common.DEFAULT_MODEL_AGENT_API_BASE, + ModelAPIKey: utils.GetEnvWithDefault(common.MODEL_AGENT_API_KEY), + } + + webSearch, err := web_search.NewWebSearchTool(&web_search.Config{}) + if err != nil { + fmt.Printf("NewLLMAgent failed: %v", err) + return + } + + cfg.Tools = []tool.Tool{webSearch} + + a, err := veagent.New(cfg) + if err != nil { + fmt.Printf("NewLLMAgent failed: %v", err) + return + } + + config := &launcher.Config{ + AgentLoader: agent.NewSingleLoader(a), + SessionService: session.InMemoryService(), + PluginConfig: runner.PluginConfig{ + Plugins: []*plugin.Plugin{observability.NewPlugin()}, + }, + } + + l := full.NewLauncher() + if err = l.Execute(ctx, config, os.Args[1:]); err != nil { + fmt.Printf("Run failed: %v\n\n%s", err, l.CommandLineSyntax()) + } +} diff --git a/examples/observability/config.yaml b/examples/observability/config.yaml new file mode 100644 index 0000000..e69de29 diff --git a/go.mod b/go.mod index 3a494bd..d71a783 100644 --- a/go.mod +++ b/go.mod @@ -14,9 +14,23 @@ require ( github.com/stretchr/testify v1.11.1 github.com/volcengine/ve-tos-golang-sdk/v2 v2.7.26 github.com/volcengine/volcengine-go-sdk v1.1.53 + go.opentelemetry.io/otel v1.39.0 + go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.15.0 + go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.15.0 + go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.39.0 + go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.39.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.39.0 + go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.39.0 + go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.39.0 + go.opentelemetry.io/otel/metric v1.39.0 + go.opentelemetry.io/otel/sdk v1.39.0 + go.opentelemetry.io/otel/sdk/log v0.15.0 + go.opentelemetry.io/otel/sdk/metric v1.39.0 + go.opentelemetry.io/otel/trace v1.39.0 go.uber.org/zap v1.27.1 golang.org/x/oauth2 v0.32.0 - google.golang.org/adk v0.3.1-0.20260128143420-39c421031f6c + google.golang.org/adk v0.4.1-0.20260130112425-78c856d8d703 google.golang.org/genai v1.40.0 gopkg.in/go-playground/validator.v8 v8.18.2 gopkg.in/yaml.v3 v3.0.1 @@ -32,6 +46,8 @@ require ( github.com/bluele/gcache v0.0.2 // indirect github.com/bytedance/sonic v1.14.2 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect + github.com/cenkalti/backoff/v5 v5.0.3 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/coze-dev/cozeloop-go/spec v0.1.4-0.20250829072213-3812ddbfb735 // indirect github.com/davecgh/go-spew v1.1.1 // indirect @@ -47,6 +63,7 @@ require ( github.com/googleapis/gax-go/v2 v2.15.0 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect github.com/gorilla/websocket v1.5.3 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.6.0 // indirect @@ -75,21 +92,20 @@ require ( go.mongodb.org/mongo-driver v1.17.6 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect - go.opentelemetry.io/otel v1.38.0 // indirect - go.opentelemetry.io/otel/metric v1.38.0 // indirect - go.opentelemetry.io/otel/sdk v1.38.0 // indirect - go.opentelemetry.io/otel/trace v1.38.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 // indirect + go.opentelemetry.io/otel/log v0.14.0 // indirect + go.opentelemetry.io/proto/otlp v1.9.0 // indirect go.uber.org/multierr v1.10.0 // indirect golang.org/x/arch v0.11.0 // indirect golang.org/x/crypto v0.45.0 // indirect golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect golang.org/x/net v0.47.0 // indirect golang.org/x/sync v0.18.0 // indirect - golang.org/x/sys v0.38.0 // indirect + golang.org/x/sys v0.39.0 // indirect golang.org/x/text v0.31.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20251014184007-4626949a642f // indirect - google.golang.org/grpc v1.76.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect + google.golang.org/grpc v1.77.0 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/go-playground/assert.v1 v1.2.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index 42e14cb..a76d45f 100644 --- a/go.sum +++ b/go.sum @@ -21,7 +21,11 @@ github.com/bytedance/sonic v1.14.2 h1:k1twIoe97C1DtYUo+fZQy865IuHia4PR5RPiuGPPII github.com/bytedance/sonic v1.14.2/go.mod h1:T80iDELeHiHKSc0C9tubFygiuXoGzrkjKzX2quAx980= github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2NYzevs+o= github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= +github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= @@ -96,6 +100,8 @@ github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -195,16 +201,42 @@ go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg= -go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= -go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= -go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= -go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= -go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= -go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg= -go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= -go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= -go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= -go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= +go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= +go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= +go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.15.0 h1:W+m0g+/6v3pa5PgVf2xoFMi5YtNR06WtS7ve5pcvLtM= +go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.15.0/go.mod h1:JM31r0GGZ/GU94mX8hN4D8v6e40aFlUECSQ48HaLgHM= +go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.15.0 h1:EKpiGphOYq3CYnIe2eX9ftUkyU+Y8Dtte8OaWyHJ4+I= +go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.15.0/go.mod h1:nWFP7C+T8TygkTjJ7mAyEaFaE7wNfms3nV/vexZ6qt0= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.39.0 h1:cEf8jF6WbuGQWUVcqgyWtTR0kOOAWY1DYZ+UhvdmQPw= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.39.0/go.mod h1:k1lzV5n5U3HkGvTCJHraTAGJ7MqsgL1wrGwTj1Isfiw= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.39.0 h1:nKP4Z2ejtHn3yShBb+2KawiXgpn8In5cT7aO2wXuOTE= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.39.0/go.mod h1:NwjeBbNigsO4Aj9WgM0C+cKIrxsZUaRmZUO7A8I7u8o= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 h1:f0cb2XPmrqn4XMy9PNliTgRKJgS5WcL/u0/WRYGz4t0= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0/go.mod h1:vnakAaFckOMiMtOIhFI2MNH4FYrZzXCYxmb1LlhoGz8= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0 h1:in9O8ESIOlwJAEGTkkf34DesGRAc/Pn8qJ7k3r/42LM= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0/go.mod h1:Rp0EXBm5tfnv0WL+ARyO/PHBEaEAT8UUHQ6AGJcSq6c= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.39.0 h1:Ckwye2FpXkYgiHX7fyVrN1uA/UYd9ounqqTuSNAv0k4= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.39.0/go.mod h1:teIFJh5pW2y+AN7riv6IBPX2DuesS3HgP39mwOspKwU= +go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.39.0 h1:5gn2urDL/FBnK8OkCfD1j3/ER79rUuTYmCvlXBKeYL8= +go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.39.0/go.mod h1:0fBG6ZJxhqByfFZDwSwpZGzJU671HkwpWaNe2t4VUPI= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.39.0 h1:8UPA4IbVZxpsD76ihGOQiFml99GPAEZLohDXvqHdi6U= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.39.0/go.mod h1:MZ1T/+51uIVKlRzGw1Fo46KEWThjlCBZKl2LzY5nv4g= +go.opentelemetry.io/otel/log v0.14.0 h1:2rzJ+pOAZ8qmZ3DDHg73NEKzSZkhkGIua9gXtxNGgrM= +go.opentelemetry.io/otel/log v0.14.0/go.mod h1:5jRG92fEAgx0SU/vFPxmJvhIuDU9E1SUnEQrMlJpOno= +go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= +go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= +go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= +go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= +go.opentelemetry.io/otel/sdk/log v0.15.0 h1:WgMEHOUt5gjJE93yqfqJOkRflApNif84kxoHWS9VVHE= +go.opentelemetry.io/otel/sdk/log v0.15.0/go.mod h1:qDC/FlKQCXfH5hokGsNg9aUBGMJQsrUyeOiW5u+dKBQ= +go.opentelemetry.io/otel/sdk/log/logtest v0.14.0 h1:Ijbtz+JKXl8T2MngiwqBlPaHqc4YCaP/i13Qrow6gAM= +go.opentelemetry.io/otel/sdk/log/logtest v0.14.0/go.mod h1:dCU8aEL6q+L9cYTqcVOk8rM9Tp8WdnHOPLiBgp0SGOA= +go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= +go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= +go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= +go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= +go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A= +go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= @@ -241,8 +273,8 @@ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= -golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= @@ -258,8 +290,8 @@ golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -google.golang.org/adk v0.3.1-0.20260128143420-39c421031f6c h1:2nJELHOSVHL2GbjHyqMoMsiuGfDsyaD7i4kUene765U= -google.golang.org/adk v0.3.1-0.20260128143420-39c421031f6c/go.mod h1:jVeb7Ir53+3XKTncdY7k3pVdPneKcm5+60sXpxHQnao= +google.golang.org/adk v0.4.1-0.20260130112425-78c856d8d703 h1:BuwPyBd36S0fXKmhbua8uyvv6PFFneYahVx9CRunBW4= +google.golang.org/adk v0.4.1-0.20260130112425-78c856d8d703/go.mod h1:jVeb7Ir53+3XKTncdY7k3pVdPneKcm5+60sXpxHQnao= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genai v1.40.0 h1:kYxyQSH+vsib8dvsgyLJzsVEIv5k3ZmHJyVqdvGncmc= @@ -267,15 +299,15 @@ google.golang.org/genai v1.40.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5g google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f h1:OiFuztEyBivVKDvguQJYWq1yDcfAHIID/FVrPR4oiI0= -google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f/go.mod h1:kprOiu9Tr0JYyD6DORrc4Hfyk3RFXqkQ3ctHEum3ZbM= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251014184007-4626949a642f h1:1FTH6cpXFsENbPR5Bu8NQddPSaUUE6NA2XdZdDSAJK4= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251014184007-4626949a642f/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls= +google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.76.0 h1:UnVkv1+uMLYXoIz6o7chp59WfQUYA2ex/BXQ9rHZu7A= -google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94ULZ6c= +google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM= +google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/model/openai.go b/model/openai.go index ada6788..502508e 100644 --- a/model/openai.go +++ b/model/openai.go @@ -99,6 +99,7 @@ func (m *openAIModel) GenerateContent(ctx context.Context, req *model.LLMRequest if stream { return m.generateStream(ctx, openaiReq) } + return m.generate(ctx, openaiReq) } @@ -111,10 +112,15 @@ type openAIRequest struct { TopP *float64 `json:"top_p,omitempty"` Stop []string `json:"stop,omitempty"` Stream bool `json:"stream,omitempty"` + StreamOptions *streamOptions `json:"stream_options,omitempty"` ResponseFormat *responseFormat `json:"response_format,omitempty"` ExtraBody map[string]any } +type streamOptions struct { + IncludeUsage bool `json:"include_usage,omitempty"` +} + func (r openAIRequest) MarshalJSON() ([]byte, error) { topLevel := make(map[string]interface{}) @@ -141,6 +147,9 @@ func (r openAIRequest) MarshalJSON() ([]byte, error) { if r.ResponseFormat != nil { topLevel["response_format"] = r.ResponseFormat } + if r.StreamOptions != nil { + topLevel["stream_options"] = r.StreamOptions + } if r.ExtraBody != nil { for k, v := range r.ExtraBody { @@ -204,7 +213,9 @@ type choice struct { type usage struct { PromptTokens int `json:"prompt_tokens"` + InputTokens int `json:"input_tokens"` // Ark-compatible field CompletionTokens int `json:"completion_tokens"` + OutputTokens int `json:"output_tokens"` // Ark-compatible field TotalTokens int `json:"total_tokens"` PromptTokensDetails *promptTokensDetails `json:"prompt_tokens_details,omitempty"` } @@ -268,6 +279,8 @@ func (m *openAIModel) convertOpenAIRequest(req *model.LLMRequest) (*openAIReques } } + openaiReq.StreamOptions = &streamOptions{IncludeUsage: true} + return openaiReq, nil } @@ -529,10 +542,16 @@ func (m *openAIModel) generateStream(ctx context.Context, openaiReq *openAIReque }() scanner := bufio.NewScanner(httpResp.Body) + // Set a larger buffer for the scanner to handle long SSE lines + const maxScannerBuffer = 1 * 1024 * 1024 // 1MB + scanner.Buffer(make([]byte, 64*1024), maxScannerBuffer) + var textBuffer strings.Builder var reasoningBuffer strings.Builder var toolCalls []toolCall - var usage *usage + var finalUsage usage + var usageFound bool + var finishedReason string for scanner.Scan() { line := scanner.Text() @@ -550,11 +569,19 @@ func (m *openAIModel) generateStream(ctx context.Context, openaiReq *openAIReque continue } + if chunk.Usage != nil { + finalUsage = *chunk.Usage + usageFound = true + } + if len(chunk.Choices) == 0 { continue } choice := chunk.Choices[0] + if choice.FinishReason != "" { + finishedReason = choice.FinishReason + } delta := choice.Delta if delta == nil { continue @@ -597,8 +624,8 @@ func (m *openAIModel) generateStream(ctx context.Context, openaiReq *openAIReque } if len(delta.ToolCalls) > 0 { - for idx, tc := range delta.ToolCalls { - targetIdx := idx + for _, tc := range delta.ToolCalls { + targetIdx := 0 if tc.Index != nil { targetIdx = *tc.Index } @@ -614,19 +641,11 @@ func (m *openAIModel) generateStream(ctx context.Context, openaiReq *openAIReque if tc.Function.Name != "" { toolCalls[targetIdx].Function.Name += tc.Function.Name } - toolCalls[targetIdx].Function.Arguments += tc.Function.Arguments + if tc.Function.Arguments != "" { + toolCalls[targetIdx].Function.Arguments += tc.Function.Arguments + } } } - - if chunk.Usage != nil { - usage = chunk.Usage - } - - if choice.FinishReason != "" { - finalResp := m.buildFinalResponse(textBuffer.String(), reasoningBuffer.String(), toolCalls, usage, choice.FinishReason) - yield(finalResp, nil) - return - } } if err := scanner.Err(); err != nil { @@ -634,8 +653,15 @@ func (m *openAIModel) generateStream(ctx context.Context, openaiReq *openAIReque return } - if textBuffer.Len() > 0 || len(toolCalls) > 0 { - finalResp := m.buildFinalResponse(textBuffer.String(), reasoningBuffer.String(), toolCalls, usage, "stop") + if textBuffer.Len() > 0 || len(toolCalls) > 0 || finishedReason != "" || usageFound { + var u *usage + if usageFound { + u = &finalUsage + } + if finishedReason == "" { + finishedReason = "stop" + } + finalResp := m.buildFinalResponse(textBuffer.String(), reasoningBuffer.String(), toolCalls, u, finishedReason) yield(finalResp, nil) } } @@ -743,6 +769,9 @@ func (m *openAIModel) convertResponse(resp *response) (*model.LLMResponse, error Role: "model", Parts: parts, }, + CustomMetadata: map[string]any{ + "response_model": resp.Model, + }, } llmResp.UsageMetadata = buildUsageMetadata(resp.Usage) @@ -785,6 +814,9 @@ func (m *openAIModel) buildFinalResponse(text string, reasoningText string, tool }, FinishReason: mapFinishReason(finishReason), UsageMetadata: buildUsageMetadata(usage), + CustomMetadata: map[string]any{ + "response_model": m.name, + }, } return llmResp @@ -794,10 +826,24 @@ func buildUsageMetadata(usage *usage) *genai.GenerateContentResponseUsageMetadat if usage == nil { return nil } + + promptTokens := usage.PromptTokens + if promptTokens == 0 { + promptTokens = usage.InputTokens + } + completionTokens := usage.CompletionTokens + if completionTokens == 0 { + completionTokens = usage.OutputTokens + } + totalTokens := usage.TotalTokens + if totalTokens == 0 && (promptTokens > 0 || completionTokens > 0) { + totalTokens = promptTokens + completionTokens + } + metadata := &genai.GenerateContentResponseUsageMetadata{ - PromptTokenCount: int32(usage.PromptTokens), - CandidatesTokenCount: int32(usage.CompletionTokens), - TotalTokenCount: int32(usage.TotalTokens), + PromptTokenCount: int32(promptTokens), + CandidatesTokenCount: int32(completionTokens), + TotalTokenCount: int32(totalTokens), } if usage.PromptTokensDetails != nil { metadata.CachedContentTokenCount = int32(usage.PromptTokensDetails.CachedTokens) diff --git a/model/openai_test.go b/model/openai_test.go index 846a9fa..1438aec 100644 --- a/model/openai_test.go +++ b/model/openai_test.go @@ -201,6 +201,9 @@ func TestModel_Generate(t *testing.T) { CandidatesTokenCount: 5, TotalTokenCount: 15, }, + CustomMetadata: map[string]any{ + "response_model": "test-model", + }, FinishReason: genai.FinishReasonStop, }, }, @@ -224,6 +227,9 @@ func TestModel_Generate(t *testing.T) { CandidatesTokenCount: 5, TotalTokenCount: 15, }, + CustomMetadata: map[string]any{ + "response_model": "test-model", + }, FinishReason: genai.FinishReasonStop, }, }, diff --git a/observability/README.md b/observability/README.md new file mode 100644 index 0000000..53aa431 --- /dev/null +++ b/observability/README.md @@ -0,0 +1,134 @@ +# VeADK Go Observability Package + +This package provides comprehensive observability features for the VeADK Go SDK, fully aligned with the [VeADK Python SDK](https://volcengine.github.io/veadk-python/observation/span-attributes/) and [OpenTelemetry GenAI Semantic Conventions](https://opentelemetry.io/docs/specs/semconv/gen-ai/). + +## Features + +- **Python ADK Alignment**: Implements the same span attributes, event structures, and naming conventions as the Python ADK +- **Multi-Platform Support**: Simultaneously export traces to APMPlus, CozeLoop, Volcano TLS, or local files/stdout +- **Automatic Attribute Enrichment**: Automatically captures and propagates `SessionID`, `UserID`, `AppName`, `InvocationID` from context, config, or environment +- **Span Hierarchy Support**: Properly tracks invocation → agent → LLM/tool execution hierarchies +- **Metrics Support**: Automated recording of token usage, operation latencies, and first token latency + +## Span Attribute Specification + +VeADK Go implements the following span attribute categories as documented in [Python ADK Span Attributes](https://volcengine.github.io/veadk-python/observation/span-attributes/): + +### Common Attributes (All Spans) +- `gen_ai.system` - Model provider (e.g., "openai", "ark") +- `gen_ai.system.version` - VeADK version +- `gen_ai.agent.name` - Agent name +- `gen_ai.app.name` / `app_name` / `app.name` - Application name +- `gen_ai.user.id` / `user.id` - User identifier +- `gen_ai.session.id` / `session.id` - Session identifier +- `gen_ai.invocation.id` / `invocation.id` - Invocation identifier +- `cozeloop.report.source` - Fixed value "veadk" +- `cozeloop.call_type` - Call type for CozeLoop +- `openinference.instrumentation.veadk` - Instrumentation version + +### LLM Span Attributes +- `gen_ai.span.kind` - "llm" +- `gen_ai.operation.name` - "chat" +- `gen_ai.request.model` - Model name +- `gen_ai.request.type` - Request type +- `gen_ai.request.max_tokens` - Max output tokens +- `gen_ai.request.temperature` - Sampling temperature +- `gen_ai.request.top_p` - Top-p parameter +- `gen_ai.usage.input_tokens` - Input token count +- `gen_ai.usage.output_tokens` - Output token count +- `gen_ai.usage.total_tokens` - Total token count +- `gen_ai.prompt` - Input messages +- `gen_ai.completion` - Output messages +- `gen_ai.messages` - Complete message events +- `gen_ai.choice` - Model choices + +### Tool Span Attributes +- `gen_ai.span.kind` - "tool" +- `gen_ai.operation.name` - "execute_tool" +- `gen_ai.tool.name` - Tool name +- `gen_ai.tool.input` / `cozeloop.input` / `gen_ai.input` - Tool input +- `gen_ai.tool.output` / `cozeloop.output` / `gen_ai.output` - Tool output + +### Workflow Span Attributes +- `gen_ai.span.kind` - "workflow" +- `gen_ai.operation.name` - "invocation" + +## Configuration + +### YAML Configuration + +Add an `observability` section to your `config.yaml`: + +```yaml +observability: + opentelemetry: + apmplus: + endpoint: "https://apmplus-cn-beijing.volces.com:4318" + api_key: "YOUR_APMPLUS_API_KEY" + service_name: "YOUR_SERVICE_NAME" +``` + +### Environment Variables + +All settings can be overridden via environment variables: + +- `OBSERVABILITY_OPENTELEMETRY_COZELOOP_API_KEY` +- `OBSERVABILITY_OPENTELEMETRY_APMPLUS_API_KEY` +- `OBSERVABILITY_OPENTELEMETRY_ENABLE_GLOBAL_PROVIDER` (default: true) +- `VEADK_USER_ID` - Set default user ID +- `VEADK_SESSION_ID` - Set default session ID +- `VEADK_APP_NAME` - Set default app name +- `VEADK_MODEL_PROVIDER` - Set model provider + + +## Usage + +### Observability Plugin + +To enable automatic trace capture (including the root `invocation` span), register the observability plugin: + +```go +import ( + "github.com/volcengine/veadk-go/observability" + "google.golang.org/adk/cmd/launcher/full" + "google.golang.org/adk/runner" + "google.golang.org/adk/plugin" +) + +func main() { + ctx := context.Background() + + config := &launcher.Config{ + AgentLoader: agent.NewSingleLoader(a), + PluginConfig: runner.PluginConfig{ + Plugins: []*plugin.Plugin{observability.NewPlugin()}, + }, + } + + l := full.NewLauncher() + if err := l.Execute(ctx, config, os.Args[1:]); err != nil { + log.Fatal(err) + } +} +``` + +## Metrics (Aligned with Python ADK) + +This package automatically records standard GenAI metrics when **APMPlus** is configured. The metrics are fully aligned with the Python ADK implementation. + +### Standard GenAI Metrics +- `gen_ai.chat.count`: Counter for number of LLM invocations. +- `gen_ai.client.token.usage`: Histogram for input/output token usage. +- `gen_ai.client.operation.duration`: Histogram for LLM operation latency. +- `gen_ai.chat_completions.exceptions`: Counter for exceptions during chat completions. + +### Streaming Metrics +- `gen_ai.chat_completions.streaming_time_to_first_token`: Time to first token. +- `gen_ai.chat_completions.streaming_time_to_generate`: Total generation time. +- `gen_ai.chat_completions.streaming_time_per_output_token`: Average time per output token. + +### APMPlus Custom Metrics +- `apmplus_span_latency`: Latency for both LLM and Tool spans. +- `apmplus_tool_token_usage`: Estimated token usage for tool inputs (type=input) and outputs (type=output), calculated as `char_len / 4`. + +> **Note**: Metrics collection is automatically enabled when APMPlus configuration involves an API Key. diff --git a/observability/README_zh.md b/observability/README_zh.md new file mode 100644 index 0000000..b7a1d7b --- /dev/null +++ b/observability/README_zh.md @@ -0,0 +1,114 @@ +# VeADK Go 可观测性包 + +本包为 VeADK Go SDK 提供全面的可观测性功能插件,与 [VeADK Python SDK](https://volcengine.github.io/veadk-python/observation/span-attributes/) 和 [OpenTelemetry GenAI 语义约定](https://opentelemetry.io/docs/specs/semconv/gen-ai/) 对齐。 + +## 功能特性 + +- **对齐 Python ADK**:实现了与 Python ADK 相同的 Span 属性、事件结构和命名约定 +- **多平台支持**:支持同时将 Trace 导出到APMPlus、以及本地文件或标准输出 +- **自动属性丰富**:自动从 Context、配置或环境变量中捕获并传播 `SessionID`、`UserID`、`AppName`、`InvocationID` +- **Span 层级支持**:正确跟踪 Invocation → Agent → LLM/Tool 执行的层级关系 +- **指标支持**:自动记录 Token 使用量、操作延迟和首 Token 延迟等指标 + +## Span 属性规范 + +VeADK Go 实现了以下 Span 属性类别,详见 [Python ADK Span 属性文档](https://volcengine.github.io/veadk-python/observation/span-attributes/): + +### 通用属性 (所有 Span) +- `gen_ai.system` - 模型提供商 (例如 "openai", "ark") +- `gen_ai.system.version` - VeADK 版本 +- `gen_ai.agent.name` - Agent 名称 +- `gen_ai.app.name` / `app_name` / `app.name` - 应用名称 +- `gen_ai.user.id` / `user.id` - 用户 ID +- `gen_ai.session.id` / `session.id` - 会话 ID +- `gen_ai.invocation.id` / `invocation.id` - 调用 ID +- `cozeloop.report.source` - 固定值 "veadk" +- `cozeloop.call_type` - CozeLoop 调用类型 +- `openinference.instrumentation.veadk` - 插桩版本 + +### LLM Span 属性 +- `gen_ai.span.kind` - "llm" +- `gen_ai.operation.name` - "chat" +- `gen_ai.request.model` - 模型名称 +- `gen_ai.request.type` - 请求类型 +- `gen_ai.request.max_tokens` - 最大输出 Token 数 +- `gen_ai.request.temperature` - 采样温度 +- `gen_ai.request.top_p` - Top-p 参数 +- `gen_ai.usage.input_tokens` - 输入 Token 数 +- `gen_ai.usage.output_tokens` - 输出 Token 数 +- `gen_ai.usage.total_tokens` - 总 Token 数 +- `gen_ai.prompt` - 输入消息 +- `gen_ai.completion` - 输出消息 +- `gen_ai.messages` - 完整消息事件 +- `gen_ai.choice` - 模型选择 + +### Tool Span 属性 +- `gen_ai.span.kind` - "tool" +- `gen_ai.operation.name` - "execute_tool" +- `gen_ai.tool.name` - 工具名称 +- `gen_ai.tool.input` / `cozeloop.input` / `gen_ai.input` - 工具输入 +- `gen_ai.tool.output` / `cozeloop.output` / `gen_ai.output` - 工具输出 + +### Workflow Span 属性 +- `gen_ai.span.kind` - "workflow" +- `gen_ai.operation.name` - "invocation" + +## 配置 + +### YAML 配置 + +在你的 `config.yaml` 中添加 `observability` 部分: + +```yaml +observability: + opentelemetry: + apmplus: + endpoint: "https://apmplus-cn-beijing.volces.com:4318" + api_key: "YOUR_APMPLUS_API_KEY" + service_name: "YOUR_SERVICE_NAME" +``` + +### 环境变量 + +所有设置均可通过环境变量覆盖: + +- `OBSERVABILITY_OPENTELEMETRY_COZELOOP_API_KEY` +- `OBSERVABILITY_OPENTELEMETRY_APMPLUS_API_KEY` +- `OBSERVABILITY_OPENTELEMETRY_ENABLE_GLOBAL_PROVIDER` (默认: false) +- `VEADK_USER_ID` - 设置默认 User ID +- `VEADK_SESSION_ID` - 设置默认 Session ID +- `VEADK_APP_NAME` - 设置默认 App Name +- `VEADK_MODEL_PROVIDER` - 设置模型提供商 +- `VEADK_CALL_TYPE` - 设置调用类型 + +## 使用方法 + + +### 可观测性插件 + +要启用自动 Trace 捕获(包括根 `invocation` Span),请注册可观测性插件: + +```go +import ( + "github.com/volcengine/veadk-go/observability" + "google.golang.org/adk/cmd/launcher/full" + "google.golang.org/adk/runner" + "google.golang.org/adk/plugin" +) + +func main() { + ctx := context.Background() + + config := &launcher.Config{ + AgentLoader: agent.NewSingleLoader(a), + PluginConfig: runner.PluginConfig{ + Plugins: []*plugin.Plugin{observability.NewPlugin()}, + }, + } + + l := full.NewLauncher() + if err := l.Execute(ctx, config, os.Args[1:]); err != nil { + log.Fatal(err) + } +} +``` diff --git a/observability/attributes.go b/observability/attributes.go new file mode 100644 index 0000000..d503b7b --- /dev/null +++ b/observability/attributes.go @@ -0,0 +1,187 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package observability + +import ( + "context" + "os" + + "github.com/volcengine/veadk-go/configs" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +// setCommonAttributes enriches the span with common attributes from context, config, or env. +func setCommonAttributes(ctx context.Context, span trace.Span) { + // 1. Fixed attributes + span.SetAttributes(attribute.String(AttrCozeloopReportSource, DefaultCozeLoopReportSource)) + + // 2. Dynamic attributes + setDynamicAttribute(span, AttrGenAISystem, GetModelProvider(ctx), FallbackModelProvider) + setDynamicAttribute(span, AttrGenAISystemVersion, Version, "", AttrInstrumentation) + setDynamicAttribute(span, AttrCozeloopCallType, GetCallType(ctx), DefaultCozeLoopCallType) + setDynamicAttribute(span, AttrGenAISessionId, GetSessionId(ctx), FallbackSessionID, AttrSessionId) + setDynamicAttribute(span, AttrGenAIUserId, GetUserId(ctx), FallbackUserID, AttrUserId) + setDynamicAttribute(span, AttrGenAIAppName, GetAppName(ctx), FallbackAppName, AttrAppNameUnderline, AttrAppNameDot) + setDynamicAttribute(span, AttrGenAIAgentName, GetAgentName(ctx), FallbackAgentName, AttrAgentName, AttrAgentNameDot) + setDynamicAttribute(span, AttrGenAIInvocationId, GetInvocationId(ctx), FallbackInvocationID, AttrInvocationId) +} + +// setDynamicAttribute sets an attribute and its aliases if the value is not empty (or falls back to a default). +func setDynamicAttribute(span trace.Span, key string, val string, fallback string, aliases ...string) { + v := val + if v == "" { + v = fallback + } + if v != "" { + span.SetAttributes(attribute.String(key, v)) + for _, alias := range aliases { + span.SetAttributes(attribute.String(alias, v)) + } + } +} + +// setLLMAttributes sets standard GenAI attributes for LLM spans. +func setLLMAttributes(span trace.Span) { + span.SetAttributes( + attribute.String(AttrGenAISpanKind, SpanKindLLM), + attribute.String(AttrGenAIOperationName, "chat"), + ) +} + +// setToolAttributes sets standard GenAI attributes for Tool spans. +func setToolAttributes(span trace.Span, name string) { + span.SetAttributes( + attribute.String(AttrGenAISpanKind, SpanKindTool), + attribute.String(AttrGenAIOperationName, "execute_tool"), + attribute.String(AttrGenAIToolName, name), + ) +} + +// setAgentAttributes sets standard GenAI attributes for Agent spans. +func setAgentAttributes(span trace.Span, name string) { + span.SetAttributes( + attribute.String(AttrGenAIAgentName, name), + attribute.String(AttrAgentName, name), // Alias: agent_name + attribute.String(AttrAgentNameDot, name), // Alias: agent.name + ) +} + +// setWorkflowAttributes sets standard GenAI attributes for Workflow/Root spans. +func setWorkflowAttributes(span trace.Span) { + span.SetAttributes( + attribute.String(AttrGenAISpanKind, SpanKindWorkflow), + attribute.String(AttrGenAIOperationName, "chain"), + ) +} + +func GetUserId(ctx context.Context) string { + return getContextString(ctx, ContextKeyUserId, EnvUserId) +} + +func GetSessionId(ctx context.Context) string { + return getContextString(ctx, ContextKeySessionId, EnvSessionId) +} + +func GetAppName(ctx context.Context) string { + return getContextString(ctx, ContextKeyAppName, EnvAppName) +} + +func GetAgentName(ctx context.Context) string { + return getContextString(ctx, ContextKeyAgentName, EnvAgentName) +} + +func GetCallType(ctx context.Context) string { + return getContextString(ctx, ContextKeyCallType, EnvCallType) +} + +func GetModelProvider(ctx context.Context) string { + return getContextString(ctx, ContextKeyModelProvider, EnvModelProvider) +} + +func GetInvocationId(ctx context.Context) string { + if val, ok := ctx.Value(ContextKeyInvocationId).(string); ok && val != "" { + return val + } + return "" +} + +// getContextString retrieves a string value from Context -> Global Config -> Environment Variable. +func getContextString(ctx context.Context, key contextKey, envVar string) string { + // 1. Try Context + if val, ok := ctx.Value(key).(string); ok && val != "" { + return val + } + + // 2. Try Global Config + if val := getFromGlobalConfig(key); val != "" { + return val + } + + // 3. Fallback to Env Var + return os.Getenv(envVar) +} + +func getFromGlobalConfig(key contextKey) string { + cfg := configs.GetGlobalConfig() + if cfg == nil { + return "" + } + + switch key { + case ContextKeyModelProvider: + if cfg.Model != nil && cfg.Model.Agent != nil { + return cfg.Model.Agent.Provider + } + case ContextKeyAppName: + if ot := cfg.Observability.OpenTelemetry; ot != nil { + if ot.CozeLoop != nil && ot.CozeLoop.ServiceName != "" { + return ot.CozeLoop.ServiceName + } + if ot.ApmPlus != nil && ot.ApmPlus.ServiceName != "" { + return ot.ApmPlus.ServiceName + } + if ot.TLS != nil && ot.TLS.ServiceName != "" { + return ot.TLS.ServiceName + } + } + } + return "" +} + +func getServiceName(cfg *configs.OpenTelemetryConfig) string { + if serviceFromEnv := os.Getenv("OTEL_SERVICE_NAME"); serviceFromEnv != "" { + return serviceFromEnv + } + + if cfg.ApmPlus != nil { + if cfg.ApmPlus.ServiceName != "" { + return cfg.ApmPlus.ServiceName + } + } + + if cfg.CozeLoop != nil { + if cfg.CozeLoop.ServiceName != "" { + return cfg.CozeLoop.ServiceName + } + } + + if cfg.TLS != nil { + if cfg.TLS.ServiceName != "" { + return cfg.TLS.ServiceName + } + } + return "" +} diff --git a/observability/attributes_test.go b/observability/attributes_test.go new file mode 100644 index 0000000..afa5605 --- /dev/null +++ b/observability/attributes_test.go @@ -0,0 +1,79 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package observability + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" +) + +// MockSpan is a minimal mock implementation of trace.Span for testing purposes. +type MockSpan struct { + trace.Span // Embed default NoopSpan to satisfy interface + Attributes map[attribute.Key]attribute.Value +} + +func NewMockSpan() *MockSpan { + return &MockSpan{ + Span: noop.Span{}, + Attributes: make(map[attribute.Key]attribute.Value), + } +} + +func (m *MockSpan) SetAttributes(kv ...attribute.KeyValue) { + for _, a := range kv { + m.Attributes[a.Key] = a.Value + } +} + +func TestEnvFallback(t *testing.T) { + os.Setenv(EnvAppName, "env-app") + defer os.Unsetenv(EnvAppName) + + ctx := context.Background() + assert.Equal(t, "env-app", GetAppName(ctx)) + +} + +func TestSetSpecificAttributes(t *testing.T) { + t.Run("LLM", func(t *testing.T) { + span := NewMockSpan() + setLLMAttributes(span) + assert.Equal(t, SpanKindLLM, span.Attributes[attribute.Key(AttrGenAISpanKind)].AsString()) + assert.Equal(t, "chat", span.Attributes[attribute.Key(AttrGenAIOperationName)].AsString()) + }) + + t.Run("Tool", func(t *testing.T) { + span := NewMockSpan() + setToolAttributes(span, "my-tool") + assert.Equal(t, SpanKindTool, span.Attributes[attribute.Key(AttrGenAISpanKind)].AsString()) + assert.Equal(t, "execute_tool", span.Attributes[attribute.Key(AttrGenAIOperationName)].AsString()) + assert.Equal(t, "my-tool", span.Attributes[attribute.Key(AttrGenAIToolName)].AsString()) + }) + + t.Run("Agent", func(t *testing.T) { + span := NewMockSpan() + setAgentAttributes(span, "my-agent") + assert.Equal(t, "my-agent", span.Attributes[attribute.Key(AttrGenAIAgentName)].AsString()) + assert.Equal(t, "my-agent", span.Attributes[attribute.Key(AttrAgentName)].AsString()) + assert.Equal(t, "my-agent", span.Attributes[attribute.Key(AttrAgentNameDot)].AsString()) + }) +} diff --git a/observability/constant.go b/observability/constant.go new file mode 100644 index 0000000..8ae1bb2 --- /dev/null +++ b/observability/constant.go @@ -0,0 +1,183 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package observability + +import ( + "runtime/debug" +) + +// +// https://volcengine.github.io/veadk-python/observation/span-attributes/ +// + +const ( + InstrumentationName = "github.com/volcengine/veadk-go" +) + +var ( + Version = getVersion() +) + +func getVersion() string { + if info, ok := debug.ReadBuildInfo(); ok { + for _, dep := range info.Deps { + if dep.Path == InstrumentationName && dep.Version != "" { + return dep.Version + } + } + // If linked as main module or not found in deps + if info.Main.Path == InstrumentationName && info.Main.Version != "" { + return info.Main.Version + } + } + return "" +} + +// Span names +const ( + SpanInvocation = "invocation" + SpanInvokeAgent = "invoke_agent" // Will be suffixed with name in code + SpanCallLLM = "call_llm" + SpanExecuteTool = "execute_tool" // Will be suffixed with name in code +) + +// Metric names +const ( + MetricNameChatCount = "gen_ai.chat.count" + MetricNameTokenUsage = "gen_ai.client.token.usage" + MetricNameOperationDuration = "gen_ai.client.operation.duration" + MetricNameExceptions = "gen_ai.chat_completions.exceptions" + MetricNameFirstTokenLatency = "gen_ai.chat_completions.streaming_time_to_first_token" + MetricNameStreamingTimeToGenerate = "gen_ai.chat_completions.streaming_time_to_generate" + MetricNameStreamingTimePerOutputToken = "gen_ai.chat_completions.streaming_time_per_output_token" + + // APMPlus specific metrics + MetricNameAPMPlusSpanLatency = "apmplus_span_latency" + MetricNameAPMPlusToolTokenUsage = "apmplus_tool_token_usage" +) + +// General attributes +const ( + AttrGenAISystem = "gen_ai.system" + AttrGenAISystemVersion = "gen_ai.system.version" + AttrGenAIAgentName = "gen_ai.agent.name" + AttrInstrumentation = "openinference.instrumentation.veadk" + AttrGenAIAppName = "gen_ai.app.name" + AttrGenAIUserId = "gen_ai.user.id" + AttrGenAISessionId = "gen_ai.session.id" + AttrGenAIInvocationId = "gen_ai.invocation.id" + AttrAgentName = "agent_name" // Alias of 'gen_ai.agent.name' for CozeLoop platform + AttrAgentNameDot = "agent.name" // Alias of 'gen_ai.agent.name' for TLS platform + AttrAppNameUnderline = "app_name" // Alias of gen_ai.app.name for CozeLoop platform + AttrAppNameDot = "app.name" // Alias of gen_ai.app.name for TLS platform + AttrUserId = "user.id" // Alias of gen_ai.user.id for CozeLoop/TLS platforms + AttrSessionId = "session.id" // Alias of gen_ai.session.id for CozeLoop/TLS platforms + AttrInvocationId = "invocation.id" // Alias of gen_ai.invocation.id for CozeLoop platform + + AttrErrorType = "error.type" + AttrCozeloopReportSource = "cozeloop.report.source" // Fixed value: veadk + AttrCozeloopCallType = "cozeloop.call_type" // CozeLoop call type + + // Environment Variable Keys for Zero-Config Attributes + EnvModelProvider = "VEADK_MODEL_PROVIDER" + EnvUserId = "VEADK_USER_ID" + EnvSessionId = "VEADK_SESSION_ID" + EnvAppName = "VEADK_APP_NAME" + EnvCallType = "VEADK_CALL_TYPE" + EnvAgentName = "VEADK_AGENT_NAME" + + // Default and fallback values + DefaultCozeLoopCallType = "None" // fixed + DefaultCozeLoopReportSource = "veadk" // fixed + FallbackAgentName = "" + FallbackAppName = "" + FallbackUserID = "" + FallbackSessionID = "" + FallbackModelProvider = "" + FallbackInvocationID = "" + + // Span Kind values (GenAI semantic conventions) + SpanKindWorkflow = "workflow" + SpanKindLLM = "llm" + SpanKindTool = "tool" +) + +// LLM attributes +const ( + AttrGenAIRequestModel = "gen_ai.request.model" + AttrGenAIRequestType = "gen_ai.request.type" + AttrGenAIRequestMaxTokens = "gen_ai.request.max_tokens" + AttrGenAIRequestTemperature = "gen_ai.request.temperature" + AttrGenAIRequestTopP = "gen_ai.request.top_p" + AttrGenAIRequestFunctions = "gen_ai.request.functions" + AttrGenAIResponseModel = "gen_ai.response.model" + AttrGenAIResponseId = "gen_ai.response.id" + AttrGenAIResponseStopReason = "gen_ai.response.stop_reason" + AttrGenAIResponseFinishReason = "gen_ai.response.finish_reason" + AttrGenAIResponseFinishReasons = "gen_ai.response.finish_reasons" + AttrGenAIIsStreaming = "gen_ai.is_streaming" + AttrGenAIPrompt = "gen_ai.prompt" + AttrGenAICompletion = "gen_ai.completion" + AttrGenAIUsageInputTokens = "gen_ai.usage.input_tokens" + AttrGenAIUsageOutputTokens = "gen_ai.usage.output_tokens" + AttrGenAIUsageTotalTokens = "gen_ai.usage.total_tokens" + AttrGenAIUsageCacheCreationInputTokens = "gen_ai.usage.cache_creation_input_tokens" + AttrGenAIUsageCacheReadInputTokens = "gen_ai.usage.cache_read_input_tokens" + AttrGenAIMessages = "gen_ai.messages" + AttrGenAIChoice = "gen_ai.choice" + AttrGenAIResponsePromptTokenCount = "gen_ai.response.prompt_token_count" + AttrGenAIResponseCandidatesTokenCount = "gen_ai.response.candidates_token_count" + AttrGenAITokenType = "gen_ai_token_type" // Metric specific: underscore + + AttrInputValue = "input.value" + AttrOutputValue = "output.value" +) + +// Event names +const ( + EventGenAIContentPrompt = "gen_ai.content.prompt" + EventGenAIContentCompletion = "gen_ai.content.completion" +) + +// Tool attributes +const ( + AttrGenAIOperationName = "gen_ai.operation.name" + AttrGenAIOperationType = "gen_ai.operation.type" + AttrGenAIToolName = "gen_ai.tool.name" + AttrGenAIToolDescription = "gen_ai.tool.description" + AttrGenAIToolInput = "gen_ai.tool.input" + AttrGenAIToolOutput = "gen_ai.tool.output" + AttrGenAISpanKind = "gen_ai.span.kind" + AttrGenAIToolCallID = "gen_ai.tool.call.id" + + // Platform specific + AttrCozeloopInput = "cozeloop.input" + AttrCozeloopOutput = "cozeloop.output" + AttrGenAIInput = "gen_ai.input" + AttrGenAIOutput = "gen_ai.output" +) + +// Context keys for storing runtime values +type contextKey string + +const ( + ContextKeySessionId contextKey = "veadk.session_id" + ContextKeyUserId contextKey = "veadk.user_id" + ContextKeyAppName contextKey = "veadk.app_name" + ContextKeyAgentName contextKey = "veadk.agent_name" + ContextKeyCallType contextKey = "veadk.call_type" + ContextKeyModelProvider contextKey = "veadk.model_provider" + ContextKeyInvocationId contextKey = "veadk.invocation_id" +) diff --git a/observability/exporter.go b/observability/exporter.go new file mode 100644 index 0000000..be9dc0e --- /dev/null +++ b/observability/exporter.go @@ -0,0 +1,323 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package observability + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/volcengine/veadk-go/configs" + "github.com/volcengine/veadk-go/log" + + "go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc" + "go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp" + "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc" + "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" + "go.opentelemetry.io/otel/exporters/stdout/stdoutmetric" + "go.opentelemetry.io/otel/exporters/stdout/stdouttrace" + olog "go.opentelemetry.io/otel/sdk/log" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/trace" +) + +const ( + OTELExporterOTLPProtocolEnvKey = "OTEL_EXPORTER_OTLP_PROTOCOL" + OTELExporterOTLPEndpointEnvKey = "OTEL_EXPORTER_OTLP_ENDPOINT" +) + +var ( + fileWriters sync.Map +) + +func createLogClient(ctx context.Context, url, protocol string, headers map[string]string) (olog.Exporter, error) { + if protocol == "" { + protocol = os.Getenv(OTELExporterOTLPProtocolEnvKey) + } + + if url == "" { + return nil, errors.New("OTEL_EXPORTER_OTLP_ENDPOINT is not set") + } + + switch { + case strings.HasPrefix(protocol, "http"): + return otlploghttp.New(ctx, otlploghttp.WithEndpointURL(url), otlploghttp.WithHeaders(headers)) + default: + return otlploggrpc.New(ctx, otlploggrpc.WithEndpointURL(url), otlploggrpc.WithHeaders(headers)) + } + +} + +func createTraceClient(ctx context.Context, url, protocol string, headers map[string]string) (trace.SpanExporter, error) { + if protocol == "" { + protocol = os.Getenv(OTELExporterOTLPProtocolEnvKey) + } + + if url == "" { + url = os.Getenv(OTELExporterOTLPEndpointEnvKey) + } + + switch { + case strings.HasPrefix(protocol, "http"): + return otlptracehttp.New(ctx, otlptracehttp.WithEndpointURL(url), otlptracehttp.WithHeaders(headers)) + default: + return otlptracegrpc.New(ctx, otlptracegrpc.WithEndpointURL(url), otlptracegrpc.WithHeaders(headers)) + } +} + +func createMetricClient(ctx context.Context, url, protocol string, headers map[string]string) (sdkmetric.Exporter, error) { + if protocol == "" { + protocol = os.Getenv(OTELExporterOTLPProtocolEnvKey) + } + + if url == "" { + url = os.Getenv(OTELExporterOTLPEndpointEnvKey) + } + + switch { + case strings.HasPrefix(protocol, "http"): + return otlpmetrichttp.New(ctx, otlpmetrichttp.WithEndpointURL(url), otlpmetrichttp.WithHeaders(headers)) + default: + return otlpmetricgrpc.New(ctx, otlpmetricgrpc.WithEndpointURL(url), otlpmetricgrpc.WithHeaders(headers)) + } +} + +func getFileWriter(path string) io.Writer { + if path == "" { + log.Warn("No path provided for file writer, using io.Discard") + return io.Discard + } + + absPath, err := filepath.Abs(path) + if err != nil { + log.Warn("Failed to resolve absolute path, using original", "path", path, "err", err) + absPath = path + } + + if fileWriter, ok := fileWriters.Load(absPath); ok { + return fileWriter.(io.Writer) + } + + // Ensure directory exists + if dir := filepath.Dir(absPath); dir != "" { + if err := os.MkdirAll(dir, 0755); err != nil { + log.Warn("Failed to create directory for exporter", "path", absPath, "err", err) + } + } + + f, err := os.OpenFile(absPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + log.Warn("Failed to open file for exporter, will use io.Discard instead", "path", absPath, "err", err) + return io.Discard + } + + writers, _ := fileWriters.LoadOrStore(absPath, f) + return writers.(io.Writer) +} + +// NewStdoutExporter creates a simple stdout exporter with pretty printing. +func NewStdoutExporter() (trace.SpanExporter, error) { + return stdouttrace.New(stdouttrace.WithPrettyPrint()) +} + +// NewCozeLoopExporter creates an OTLP HTTP exporter for CozeLoop. +func NewCozeLoopExporter(ctx context.Context, cfg *configs.CozeLoopExporterConfig) (trace.SpanExporter, error) { + endpoint := cfg.Endpoint + return createTraceClient(ctx, endpoint, "", map[string]string{ + "authorization": "Bearer " + cfg.APIKey, + "cozeloop-workspace-id": cfg.ServiceName, + }) +} + +// NewAPMPlusExporter creates an OTLP HTTP exporter for APMPlus. +func NewAPMPlusExporter(ctx context.Context, cfg *configs.ApmPlusConfig) (trace.SpanExporter, error) { + endpoint := cfg.Endpoint + protocol := cfg.Protocol + return createTraceClient(ctx, endpoint, protocol, map[string]string{ + "X-ByteAPM-AppKey": cfg.APIKey, + }) +} + +// NewTLSExporter creates an OTLP HTTP exporter for Volcano TLS. +func NewTLSExporter(ctx context.Context, cfg *configs.TLSExporterConfig) (trace.SpanExporter, error) { + endpoint := cfg.Endpoint + return createTraceClient(ctx, endpoint, "", map[string]string{ + "x-tls-otel-tracetopic": cfg.TopicID, + "x-tls-otel-ak": cfg.AccessKey, + "x-tls-otel-sk": cfg.SecretKey, + "x-tls-otel-region": cfg.Region, + }) +} + +// NewFileExporter creates a span exporter that writes traces to a file. +func NewFileExporter(ctx context.Context, cfg *configs.FileConfig) (trace.SpanExporter, error) { + f := getFileWriter(cfg.Path) + return stdouttrace.New(stdouttrace.WithWriter(f), stdouttrace.WithPrettyPrint()) +} + +// NewMultiExporter creates a span exporter that can export to multiple platforms simultaneously. +func NewMultiExporter(ctx context.Context, cfg *configs.OpenTelemetryConfig) (trace.SpanExporter, error) { + var exporters []trace.SpanExporter + if cfg.Stdout != nil && cfg.Stdout.Enable { + if exp, err := NewStdoutExporter(); err == nil { + exporters = append(exporters, exp) + log.Info("Exporting spans to Stdout") + } + } + + if cfg.File != nil && cfg.File.Path != "" { + if exp, err := NewFileExporter(ctx, cfg.File); err == nil { + exporters = append(exporters, exp) + log.Info(fmt.Sprintf("Exporting spans to File: %s", cfg.File.Path)) + } + } + + if cfg.ApmPlus != nil && cfg.ApmPlus.Endpoint != "" && cfg.ApmPlus.APIKey != "" { + if exp, err := NewAPMPlusExporter(ctx, cfg.ApmPlus); err == nil { + exporters = append(exporters, exp) + log.Info("Exporting spans to APMPlus", "endpoint", cfg.ApmPlus.Endpoint, "service_name", cfg.ApmPlus.ServiceName) + } + } + + if cfg.CozeLoop != nil && cfg.CozeLoop.Endpoint != "" && cfg.CozeLoop.APIKey != "" { + if exp, err := NewCozeLoopExporter(ctx, cfg.CozeLoop); err == nil { + exporters = append(exporters, exp) + log.Info("Exporting spans to CozeLoop", "endpoint", cfg.CozeLoop.Endpoint, "service_name", cfg.CozeLoop.ServiceName) + } + } + + if cfg.TLS != nil && cfg.TLS.Endpoint != "" && cfg.TLS.AccessKey != "" && cfg.TLS.SecretKey != "" { + if exp, err := NewTLSExporter(ctx, cfg.TLS); err == nil { + exporters = append(exporters, exp) + log.Info("Exporting spans to TLS", "endpoint", cfg.TLS.Endpoint, "service_name", cfg.TLS.ServiceName) + } + } + + log.Debug("trace data will be exported", "exporter count", len(exporters)) + + if len(exporters) == 0 { + log.Info("No exporters to export observability data") + } + + if len(exporters) == 1 { + return exporters[0], nil + } + + return &multiExporter{exporters: exporters}, nil +} + +type multiExporter struct { + exporters []trace.SpanExporter +} + +func (m *multiExporter) ExportSpans(ctx context.Context, spans []trace.ReadOnlySpan) error { + var errs []error + for _, e := range m.exporters { + if err := e.ExportSpans(ctx, spans); err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) +} + +func (m *multiExporter) Shutdown(ctx context.Context) error { + var errs []error + for _, e := range m.exporters { + if err := e.Shutdown(ctx); err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) +} + +// NewMetricReader creates one or more metric readers based on the provided configuration. +func NewMetricReader(ctx context.Context, cfg *configs.OpenTelemetryConfig) ([]sdkmetric.Reader, error) { + var readers []sdkmetric.Reader + + if cfg.Stdout != nil && cfg.Stdout.Enable { + if exp, err := stdoutmetric.New(); err == nil { + readers = append(readers, sdkmetric.NewPeriodicReader(exp)) + log.Info("Exporting metrics to Stdout") + } + } + + if cfg.File != nil && cfg.File.Path != "" { + if exp, err := NewFileMetricExporter(ctx, cfg.File); err == nil { + readers = append(readers, sdkmetric.NewPeriodicReader(exp)) + log.Info(fmt.Sprintf("Exporting metrics to File: %s", cfg.File.Path)) + } + } + + if cfg.ApmPlus != nil && cfg.ApmPlus.Endpoint != "" && cfg.ApmPlus.APIKey != "" { + if exp, err := NewAPMPlusMetricExporter(ctx, cfg.ApmPlus); err == nil { + readers = append(readers, sdkmetric.NewPeriodicReader(exp)) + log.Info("Exporting metrics to APMPlus", "endpoint", cfg.ApmPlus.Endpoint, "service_name", cfg.ApmPlus.ServiceName) + } + } + + log.Debug("metric data will be exported", "exporter count", len(readers)) + + return readers, nil +} + +// NewCozeLoopMetricExporter creates an OTLP Metric exporter for CozeLoop. +func NewCozeLoopMetricExporter(ctx context.Context, cfg *configs.CozeLoopExporterConfig) (sdkmetric.Exporter, error) { + endpoint := cfg.Endpoint + if endpoint == "" { + return nil, fmt.Errorf("CozeLoop exporter endpoint is required") + } + + return createMetricClient(ctx, endpoint, "", map[string]string{ + "authorization": "Bearer " + cfg.APIKey, + "cozeloop-workspace-id": cfg.ServiceName, + }) +} + +// NewAPMPlusMetricExporter creates an OTLP Metric exporter for APMPlus. +// Supports automatic gRPC (4317) detection. +func NewAPMPlusMetricExporter(ctx context.Context, cfg *configs.ApmPlusConfig) (sdkmetric.Exporter, error) { + endpoint := cfg.Endpoint + protocol := cfg.Protocol + return createMetricClient(ctx, endpoint, protocol, map[string]string{ + "X-ByteAPM-AppKey": cfg.APIKey, + }) + +} + +// NewTLSMetricExporter creates an OTLP Metric exporter for Volcano TLS. +func NewTLSMetricExporter(ctx context.Context, cfg *configs.TLSExporterConfig) (sdkmetric.Exporter, error) { + endpoint := cfg.Endpoint + + return createMetricClient(ctx, endpoint, "", map[string]string{ + "x-tls-otel-tracetopic": cfg.TopicID, + "x-tls-otel-ak": cfg.AccessKey, + "x-tls-otel-sk": cfg.SecretKey, + "x-tls-otel-region": cfg.Region, + }) +} + +// NewFileMetricExporter creates a metric exporter that writes metrics to a file. +func NewFileMetricExporter(ctx context.Context, cfg *configs.FileConfig) (sdkmetric.Exporter, error) { + writer := getFileWriter(cfg.Path) + + return stdoutmetric.New(stdoutmetric.WithWriter(writer), stdoutmetric.WithPrettyPrint()) +} diff --git a/observability/initialize.go b/observability/initialize.go new file mode 100644 index 0000000..81977d9 --- /dev/null +++ b/observability/initialize.go @@ -0,0 +1,263 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package observability + +import ( + "context" + "errors" + "os" + "os/signal" + "sync" + "syscall" + + "github.com/volcengine/veadk-go/configs" + "github.com/volcengine/veadk-go/log" + "google.golang.org/adk/telemetry" + + "go.opentelemetry.io/otel" + sdktrace "go.opentelemetry.io/otel/sdk/trace" +) + +var ( + initConfigOnce sync.Once +) + +// Init initializes the observability system using the global configuration. +// Users usually don't need to call this function directly unless they want to override the default global configuration. +// NewPlugin will call this function to initialize observability once. +func Init(ctx context.Context, cfg *configs.ObservabilityConfig) error { + var err error + var initialized bool + initConfigOnce.Do(func() { + handleSignals(ctx) + + // In veadk-go, config loading might depend on loggers which might depend on global tracer + // or vice versa. We ensure InitConfig is called, and then initialize based on that. + var otelCfg *configs.OpenTelemetryConfig + if cfg != nil { + otelCfg = cfg.OpenTelemetry + } + + if otelCfg == nil { + log.Info("No observability config found, observability data will not be exported") + } + + err = initWithConfig(ctx, otelCfg) + initialized = true + }) + + if initialized { + log.Info("Initializing TraceProvider and MetricsProvider based on observability config") + } + return err +} + +// Shutdown shuts down the observability system, flushing all spans and metrics. +func Shutdown(ctx context.Context) error { + log.Info("Shut down TracerProvider and MeterProvider") + var errs []error + + // 0. End all active root invocation spans to ensure they are recorded and flushed. + // This handles cases like Ctrl+C or premature exit where defer blocks might not run. + GetRegistry().EndAllInvocationSpans() + GetRegistry().Shutdown() + + // 1. Shutdown TracerProvider + tp := otel.GetTracerProvider() + if sdkTP, ok := tp.(*sdktrace.TracerProvider); ok { + if err := sdkTP.ForceFlush(ctx); err != nil { + log.Error("Failed to force flush TracerProvider", "err", err) + errs = append(errs, err) + } + + if err := sdkTP.Shutdown(ctx); err != nil { + log.Error("Failed to shutdown TracerProvider", "err", err) + errs = append(errs, err) + } + } else { + log.Info("Global TracerProvider is not an SDK TracerProvider, skipping shutdown") + } + + // 2. Shutdown local MeterProvider if exists + if localMeterProvider != nil { + if err := localMeterProvider.Shutdown(ctx); err != nil { + errs = append(errs, err) + } + } + + // 3. Shutdown global MeterProvider if exists + if globalMeterProvider != nil { + if err := globalMeterProvider.Shutdown(ctx); err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} + +// initWithConfig automatically initializes the observability system based on the provided configuration. +// It creates the appropriate exporter and calls RegisterExporter. +func initWithConfig(ctx context.Context, cfg *configs.OpenTelemetryConfig) error { + var errs []error + err := initializeTraceProvider(ctx, cfg) + if err != nil { + errs = append(errs, err) + } + + err = initializeMeterProvider(ctx, cfg) + if err != nil { + errs = append(errs, err) + } + + return errors.Join(errs...) +} + +func newVeadkExporter(exp sdktrace.SpanExporter) sdktrace.SpanExporter { + return &VeADKTranslatedExporter{SpanExporter: exp} +} + +// AddSpanExporter registers an exporter to Google ADK's local telemetry. +func AddSpanExporter(exp sdktrace.SpanExporter) { + telemetry.RegisterSpanProcessor(sdktrace.NewBatchSpanProcessor(newVeadkExporter(exp))) +} + +// AddGlobalSpanExporter registers an exporter toglobal TracerProvider. +func AddGlobalSpanExporter(exp sdktrace.SpanExporter) { + globalTP := otel.GetTracerProvider() + if sdkTP, ok := globalTP.(*sdktrace.TracerProvider); ok { + sdkTP.RegisterSpanProcessor(sdktrace.NewBatchSpanProcessor(newVeadkExporter(exp))) + } +} + +// setGlobalTracerProvider configures the global OpenTelemetry TracerProvider. +func setGlobalTracerProvider(exp sdktrace.SpanExporter, spanProcessors ...sdktrace.SpanProcessor) { + // Always wrap with VeADKTranslatedExporter to ensure ADK-internal spans are correctly mapped + translatedExp := newVeadkExporter(exp) + + // Default processors + allProcessors := append([]sdktrace.SpanProcessor{}, spanProcessors...) + + // Use BatchSpanProcessor for all exporters to ensure performance and batching. + finalProcessor := sdktrace.NewBatchSpanProcessor(translatedExp) + + // 1. Try to register with existing TracerProvider if it's an SDK TracerProvider + globalTP := otel.GetTracerProvider() + if sdkTP, ok := globalTP.(*sdktrace.TracerProvider); ok { + log.Info("Registering ADK Processors to existing global TracerProvider") + for _, sp := range allProcessors { + sdkTP.RegisterSpanProcessor(sp) + } + sdkTP.RegisterSpanProcessor(finalProcessor) + return + } + + // 2. Fallback: Create a new global TracerProvider + log.Info("Creating a new global TracerProvider") + var opts []sdktrace.TracerProviderOption + for _, sp := range allProcessors { + opts = append(opts, sdktrace.WithSpanProcessor(sp)) + } + + tp := sdktrace.NewTracerProvider( + append(opts, sdktrace.WithSpanProcessor(finalProcessor))..., + ) + + otel.SetTracerProvider(tp) +} + +func setupLocalTracer(ctx context.Context, cfg *configs.OpenTelemetryConfig) error { + if cfg == nil { + return nil + } + + exp, err := NewMultiExporter(ctx, cfg) + if err != nil { + return err + } + + AddSpanExporter(exp) + return nil +} + +func setupGlobalTracer(ctx context.Context, cfg *configs.OpenTelemetryConfig) error { + log.Info("Registering ADK Global TracerProvider") + + globalExp, err := NewMultiExporter(ctx, cfg) + if err != nil { + return err + } + + if globalExp != nil { + setGlobalTracerProvider(globalExp) + } + return nil +} + +func initializeTraceProvider(ctx context.Context, cfg *configs.OpenTelemetryConfig) error { + var errs []error + if cfg != nil && cfg.EnableLocalProvider { + err := setupLocalTracer(ctx, cfg) + if err != nil { + errs = append(errs, err) + } + } + + if cfg != nil && cfg.EnableGlobalProvider { + err := setupGlobalTracer(ctx, cfg) + if err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) +} + +func initializeMeterProvider(ctx context.Context, cfg *configs.OpenTelemetryConfig) error { + var errs []error + if cfg == nil || cfg.EnableMetrics == nil || !*cfg.EnableMetrics { + log.Info("Meter provider is not enabled") + return nil + } + + if cfg.EnableLocalProvider { + readers, err := NewMetricReader(ctx, cfg) + if err != nil { + errs = append(errs, err) + } + registerLocalMetrics(readers) + } + + if cfg.EnableGlobalProvider { + globalReaders, err := NewMetricReader(ctx, cfg) + if err != nil { + errs = append(errs, err) + } + registerGlobalMetrics(globalReaders) + } + return errors.Join(errs...) +} + +// handleSignals registers a signal handler to ensure observability data is flushed on exit. +func handleSignals(ctx context.Context) { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + go func() { + <-sigChan + + // Trigger shutdown which will flush all processors (including BatchSpanProcessor) + _ = Shutdown(ctx) + os.Exit(0) + }() +} diff --git a/observability/initialize_test.go b/observability/initialize_test.go new file mode 100644 index 0000000..89b070f --- /dev/null +++ b/observability/initialize_test.go @@ -0,0 +1,107 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package observability + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/volcengine/veadk-go/configs" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" +) + +func TestGetServiceName(t *testing.T) { + t.Run("EnvVar", func(t *testing.T) { + os.Setenv("OTEL_SERVICE_NAME", "env-service") + defer os.Unsetenv("OTEL_SERVICE_NAME") + assert.Equal(t, "env-service", getServiceName(&configs.OpenTelemetryConfig{})) + }) + + t.Run("ApmPlus", func(t *testing.T) { + cfg := &configs.OpenTelemetryConfig{ + ApmPlus: &configs.ApmPlusConfig{ServiceName: "apm-service"}, + } + assert.Equal(t, "apm-service", getServiceName(cfg)) + }) + + t.Run("CozeLoop", func(t *testing.T) { + cfg := &configs.OpenTelemetryConfig{ + CozeLoop: &configs.CozeLoopExporterConfig{ServiceName: "coze-service"}, + } + assert.Equal(t, "coze-service", getServiceName(cfg)) + }) + + t.Run("TLS", func(t *testing.T) { + cfg := &configs.OpenTelemetryConfig{ + TLS: &configs.TLSExporterConfig{ServiceName: "tls-service"}, + } + assert.Equal(t, "tls-service", getServiceName(cfg)) + }) + + t.Run("Unknown", func(t *testing.T) { + assert.Equal(t, "", getServiceName(&configs.OpenTelemetryConfig{})) + }) +} + +func TestSetGlobalTracerProvider(t *testing.T) { + // Save original provider to restore + orig := otel.GetTracerProvider() + defer otel.SetTracerProvider(orig) + + exporter := tracetest.NewInMemoryExporter() + // Just verifies no panic and provider is updated + setGlobalTracerProvider(exporter) + + // Ensure we can start a span + ctx := context.Background() + tr := otel.Tracer("test") + _, span := tr.Start(ctx, "test-span") + span.End() + + // Force flush + if tp, ok := otel.GetTracerProvider().(*trace.TracerProvider); ok { + tp.ForceFlush(ctx) + } + + spans := exporter.GetSpans() + assert.Len(t, spans, 1) +} + +func TestInitializeWithConfig(t *testing.T) { + // Nil config should be fine + err := initWithConfig(context.Background(), nil) + assert.NoError(t, err) + + // Config with disabled global provider but valid exporter + cfg := &configs.OpenTelemetryConfig{ + EnableGlobalProvider: false, + Stdout: &configs.StdoutConfig{Enable: true}, + } + err = initWithConfig(context.Background(), cfg) + assert.NoError(t, err) + + // Config with global provider enabled and stdout + cfgGlobal := &configs.OpenTelemetryConfig{ + EnableGlobalProvider: true, + Stdout: &configs.StdoutConfig{Enable: true}, + } + err = initWithConfig(context.Background(), cfgGlobal) + assert.NoError(t, err) + +} diff --git a/observability/metrics.go b/observability/metrics.go new file mode 100644 index 0000000..573e033 --- /dev/null +++ b/observability/metrics.go @@ -0,0 +1,273 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package observability + +import ( + "context" + + "sync" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" +) + +// Bucket boundaries for histograms, aligned with Python ADK +var ( + // Token usage buckets (count) + genAIClientTokenUsageBuckets = []float64{ + 1, 4, 16, 64, 256, 1024, 4096, 16384, 65536, 262144, 1048576, 4194304, 16777216, 67108864, + } + + // Operation duration buckets (seconds) + genAIClientOperationDurationBuckets = []float64{ + 0.01, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64, 1.28, 2.56, 5.12, 10.24, 20.48, 40.96, 81.92, + } + + // First token latency buckets (seconds) + genAIServerTimeToFirstTokenBuckets = []float64{ + 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, + } + + // Time per output token buckets (seconds) + genAIServerTimePerOutputTokenBuckets = []float64{ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5, + } +) + +var ( + // Slices to hold instruments from multiple providers (Global, Local, etc.) + localOnce sync.Once + globalOnce sync.Once + instrumentsMu sync.RWMutex + localMeterProvider *sdkmetric.MeterProvider + globalMeterProvider *sdkmetric.MeterProvider + + // Standard Gen AI Metrics + tokenUsageHistograms []metric.Float64Histogram + operationDurationHistograms []metric.Float64Histogram + chatCountCounters []metric.Int64Counter + exceptionsCounters []metric.Int64Counter + // streaming metrics + streamingTimeToFirstTokenHistograms []metric.Float64Histogram + streamingTimeToGenerateHistograms []metric.Float64Histogram + streamingTimePerOutputTokenHistograms []metric.Float64Histogram + + // special metrics for APMPlus + apmPlusLatencyHistograms []metric.Float64Histogram + apmPlusToolTokenUsageHistograms []metric.Float64Histogram +) + +// registerLocalMetrics initializes the metrics system with a local isolated MeterProvider. +// It does NOT overwrite the global OTel MeterProvider. +func registerLocalMetrics(readers []sdkmetric.Reader) { + localOnce.Do(func() { + options := []sdkmetric.Option{} + for _, r := range readers { + options = append(options, sdkmetric.WithReader(r)) + } + + mp := sdkmetric.NewMeterProvider(options...) + localMeterProvider = mp + initializeInstruments(mp.Meter(InstrumentationName)) + }) +} + +// registerGlobalMetrics configures the global OpenTelemetry MeterProvider with the provided readers. +// This is optional and used when you want unrelated OTel measurements to also be exported. +func registerGlobalMetrics(readers []sdkmetric.Reader) { + globalOnce.Do(func() { + options := []sdkmetric.Option{} + for _, r := range readers { + options = append(options, sdkmetric.WithReader(r)) + } + + mp := sdkmetric.NewMeterProvider(options...) + globalMeterProvider = mp + otel.SetMeterProvider(mp) + // No need to call registerMeter here, because the global proxy registered in init() + initializeInstruments(otel.GetMeterProvider().Meter(InstrumentationName)) + }) +} + +// initializeInstruments initializes the metrics instruments for the provided meter. +// This function is internal and should not be called directly +func initializeInstruments(m metric.Meter) { + instrumentsMu.Lock() + defer instrumentsMu.Unlock() + + // Token usage histogram with bucket boundaries + if h, err := m.Float64Histogram( + MetricNameTokenUsage, + metric.WithDescription("Token consumption of LLM invocations"), + metric.WithUnit("count"), + metric.WithExplicitBucketBoundaries(genAIClientTokenUsageBuckets...), + ); err == nil { + tokenUsageHistograms = append(tokenUsageHistograms, h) + } + + // Operation duration histogram with bucket boundaries + if h, err := m.Float64Histogram( + MetricNameOperationDuration, + metric.WithDescription("GenAI operation duration in seconds"), + metric.WithUnit("s"), + metric.WithExplicitBucketBoundaries(genAIClientOperationDurationBuckets...), + ); err == nil { + operationDurationHistograms = append(operationDurationHistograms, h) + } + + if h, err := m.Float64Histogram( + MetricNameFirstTokenLatency, + metric.WithDescription("Time to first token in streaming responses"), + metric.WithUnit("s"), + metric.WithExplicitBucketBoundaries(genAIServerTimeToFirstTokenBuckets...), + ); err == nil { + streamingTimeToFirstTokenHistograms = append(streamingTimeToFirstTokenHistograms, h) + } + + // Chat count counter + if c, err := m.Int64Counter( + MetricNameChatCount, + metric.WithDescription("Number of chat invocations"), + metric.WithUnit("1"), + ); err == nil { + chatCountCounters = append(chatCountCounters, c) + } + + // Exceptions counter + if c, err := m.Int64Counter( + MetricNameExceptions, + metric.WithDescription("Number of exceptions in chat completions"), + metric.WithUnit("1"), + ); err == nil { + exceptionsCounters = append(exceptionsCounters, c) + } + + // Streaming time to generate histogram + if h, err := m.Float64Histogram( + MetricNameStreamingTimeToGenerate, + metric.WithDescription("Time to generate streaming response"), + metric.WithUnit("s"), + metric.WithExplicitBucketBoundaries(genAIClientOperationDurationBuckets...), + ); err == nil { + streamingTimeToGenerateHistograms = append(streamingTimeToGenerateHistograms, h) + } + + // Streaming time per output token histogram + if h, err := m.Float64Histogram( + MetricNameStreamingTimePerOutputToken, + metric.WithDescription("Time per output token in streaming responses"), + metric.WithUnit("s"), + metric.WithExplicitBucketBoundaries(genAIServerTimePerOutputTokenBuckets...), + ); err == nil { + streamingTimePerOutputTokenHistograms = append(streamingTimePerOutputTokenHistograms, h) + } + + // APMPlus Span Latency + if h, err := m.Float64Histogram( + MetricNameAPMPlusSpanLatency, + metric.WithDescription("APMPlus span latency"), + metric.WithUnit("ms"), // Typically latencies in APM are ms? Standard OTel is seconds. + // User didn't specify unit, but usually latency is time. + // Wait, Python ADK: APMPLUS_SPAN_LATENCY. + // Let's stick to seconds with standard buckets but label it "ApMPlus Span Latency". + // Actually, if it is "Latency", it might be ms in some platforms. + // But Safe choice: Seconds. + metric.WithExplicitBucketBoundaries(genAIClientOperationDurationBuckets...), + ); err == nil { + apmPlusLatencyHistograms = append(apmPlusLatencyHistograms, h) + } + + // APMPlus Tool Token Usage + if h, err := m.Float64Histogram( + MetricNameAPMPlusToolTokenUsage, + metric.WithDescription("Token usage for tools (APMPlus specific)"), + metric.WithUnit("count"), + metric.WithExplicitBucketBoundaries(genAIClientTokenUsageBuckets...), + ); err == nil { + apmPlusToolTokenUsageHistograms = append(apmPlusToolTokenUsageHistograms, h) + } +} + +// RecordTokenUsage records the number of tokens used. +func RecordTokenUsage(ctx context.Context, input, output int64, attrs ...attribute.KeyValue) { + for _, histogram := range tokenUsageHistograms { + if input > 0 { + histogram.Record(ctx, float64(input), metric.WithAttributes( + append(attrs, attribute.String(AttrGenAITokenType, "input"))...)) + } + if output > 0 { + histogram.Record(ctx, float64(output), metric.WithAttributes( + append(attrs, attribute.String(AttrGenAITokenType, "output"))...)) + } + } +} + +// RecordOperationDuration records the duration of an operation. +func RecordOperationDuration(ctx context.Context, durationSeconds float64, attrs ...attribute.KeyValue) { + for _, histogram := range operationDurationHistograms { + histogram.Record(ctx, durationSeconds, metric.WithAttributes(attrs...)) + } +} + +// RecordStreamingTimeToFirstToken records the time to first token in streaming responses. +func RecordStreamingTimeToFirstToken(ctx context.Context, latencySeconds float64, attrs ...attribute.KeyValue) { + for _, histogram := range streamingTimeToFirstTokenHistograms { + histogram.Record(ctx, latencySeconds, metric.WithAttributes(attrs...)) + } +} + +// RecordChatCount records the number of chat invocations. +func RecordChatCount(ctx context.Context, count int64, attrs ...attribute.KeyValue) { + for _, counter := range chatCountCounters { + counter.Add(ctx, count, metric.WithAttributes(attrs...)) + } +} + +// RecordExceptions records the number of exceptions. +func RecordExceptions(ctx context.Context, count int64, attrs ...attribute.KeyValue) { + for _, counter := range exceptionsCounters { + counter.Add(ctx, count, metric.WithAttributes(attrs...)) + } +} + +// RecordStreamingTimeToGenerate records the time to generate. +func RecordStreamingTimeToGenerate(ctx context.Context, durationSeconds float64, attrs ...attribute.KeyValue) { + for _, histogram := range streamingTimeToGenerateHistograms { + histogram.Record(ctx, durationSeconds, metric.WithAttributes(attrs...)) + } +} + +// RecordStreamingTimePerOutputToken records the time per output token. +func RecordStreamingTimePerOutputToken(ctx context.Context, timeSeconds float64, attrs ...attribute.KeyValue) { + for _, histogram := range streamingTimePerOutputTokenHistograms { + histogram.Record(ctx, timeSeconds, metric.WithAttributes(attrs...)) + } +} + +// RecordAPMPlusSpanLatency records the span latency for APMPlus. +func RecordAPMPlusSpanLatency(ctx context.Context, durationSeconds float64, attrs ...attribute.KeyValue) { + for _, histogram := range apmPlusLatencyHistograms { + histogram.Record(ctx, durationSeconds, metric.WithAttributes(attrs...)) + } +} + +// RecordAPMPlusToolTokenUsage records the tool token usage for APMPlus. +func RecordAPMPlusToolTokenUsage(ctx context.Context, tokens int64, attrs ...attribute.KeyValue) { + for _, histogram := range apmPlusToolTokenUsageHistograms { + histogram.Record(ctx, float64(tokens), metric.WithAttributes(attrs...)) + } +} diff --git a/observability/metrics_test.go b/observability/metrics_test.go new file mode 100644 index 0000000..3105244 --- /dev/null +++ b/observability/metrics_test.go @@ -0,0 +1,138 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package observability + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "go.opentelemetry.io/otel/attribute" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/metricdata" +) + +func TestMetricsRecording(t *testing.T) { + // Setup Manual Reader + reader := sdkmetric.NewManualReader() + mp := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader)) + meter := mp.Meter("test-meter") + + // Initialize instruments into the global slice (this appends, which is fine for testing) + initializeInstruments(meter) + + ctx := context.Background() + attrs := []attribute.KeyValue{attribute.String("test.key", "test.val")} + + t.Run("RecordTokenUsage", func(t *testing.T) { + RecordTokenUsage(ctx, 10, 20, attrs...) + + var rm metricdata.ResourceMetrics + err := reader.Collect(ctx, &rm) + assert.NoError(t, err) + + // Find the token usage metric + var foundInput, foundOutput bool + for _, sm := range rm.ScopeMetrics { + for _, m := range sm.Metrics { + if m.Name == MetricNameTokenUsage { + data := m.Data.(metricdata.Histogram[float64]) + for _, dp := range data.DataPoints { + dir, _ := dp.Attributes.Value("gen_ai_token_type") + if dir.AsString() == "input" { + assert.Equal(t, uint64(1), dp.Count) + assert.Equal(t, 10.0, dp.Sum) + foundInput = true + } else if dir.AsString() == "output" { + assert.Equal(t, uint64(1), dp.Count) + assert.Equal(t, 20.0, dp.Sum) + foundOutput = true + } + } + } + } + } + assert.True(t, foundInput, "Input tokens not found") + assert.True(t, foundOutput, "Output tokens not found") + }) + + t.Run("RecordOperationDuration", func(t *testing.T) { + RecordOperationDuration(ctx, 1.5, attrs...) + + var rm metricdata.ResourceMetrics + err := reader.Collect(ctx, &rm) + assert.NoError(t, err) + + var found bool + for _, sm := range rm.ScopeMetrics { + for _, m := range sm.Metrics { + if m.Name == MetricNameOperationDuration { + data := m.Data.(metricdata.Histogram[float64]) + for _, dp := range data.DataPoints { + if dp.Count > 0 { + assert.Equal(t, uint64(1), dp.Count) + assert.Equal(t, 1.5, dp.Sum) + found = true + } + } + } + } + } + assert.True(t, found, "Operation duration not found") + }) + + t.Run("RecordStreamingTimeToFirstToken", func(t *testing.T) { + RecordStreamingTimeToFirstToken(ctx, 0.1, attrs...) + + var rm metricdata.ResourceMetrics + err := reader.Collect(ctx, &rm) + assert.NoError(t, err) + + var found bool + for _, sm := range rm.ScopeMetrics { + for _, m := range sm.Metrics { + if m.Name == MetricNameFirstTokenLatency { + data := m.Data.(metricdata.Histogram[float64]) + for _, dp := range data.DataPoints { + if dp.Count > 0 { + assert.Equal(t, uint64(1), dp.Count) + assert.Equal(t, 0.1, dp.Sum) + found = true + } + } + } + } + } + assert.True(t, found, "Streaming time to first token not found") + }) +} + +func TestRegisterLocalMetrics(t *testing.T) { + // Since registerLocalMetrics uses sync.Once, we can only test it doesn't panic. + // Logic verification is implicitly done via InitializeInstruments testing above. + reader := sdkmetric.NewManualReader() + assert.NotPanics(t, func() { + registerLocalMetrics([]sdkmetric.Reader{reader}) + }) +} + +// We cannot easily test registerGlobalMetrics side effects on otel.GetMeterProvider +// without affecting other tests or global state, but basic execution safety check: +func TestRegisterGlobalMetrics(t *testing.T) { + reader := sdkmetric.NewManualReader() + assert.NotPanics(t, func() { + registerGlobalMetrics([]sdkmetric.Reader{reader}) + }) +} diff --git a/observability/plugin.go b/observability/plugin.go new file mode 100644 index 0000000..9e11b3f --- /dev/null +++ b/observability/plugin.go @@ -0,0 +1,1058 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package observability + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + "time" + + "github.com/volcengine/veadk-go/configs" + "github.com/volcengine/veadk-go/log" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" + "google.golang.org/adk/agent" + "google.golang.org/adk/model" + "google.golang.org/adk/plugin" + "google.golang.org/adk/session" + "google.golang.org/adk/tool" + "google.golang.org/genai" +) + +// NewPlugin creates a new observability plugin for ADK. +// It returns a *plugin.Plugin that can be registered in launcher.Config or agent.Config. +func NewPlugin(opts ...Option) *plugin.Plugin { + // use global config by default. deep copy to avoid mutating global config. + observabilityConfig := configs.GetGlobalConfig().Observability.Clone() + for _, opt := range opts { + opt(observabilityConfig) + } + + p := &adkObservabilityPlugin{ + config: observabilityConfig, + } + + err := Init(context.Background(), observabilityConfig) + if err != nil { + log.Error("Init observability failed", "error", err) + return nil + } + + p.tracer = otel.Tracer(InstrumentationName) + + // no need to check the error as it is always nil. + pluginInstance, _ := plugin.New(plugin.Config{ + Name: "veadk-observability", + BeforeRunCallback: p.BeforeRun, + AfterRunCallback: p.AfterRun, + BeforeAgentCallback: p.BeforeAgent, + AfterAgentCallback: p.AfterAgent, + BeforeModelCallback: p.BeforeModel, + AfterModelCallback: p.AfterModel, + BeforeToolCallback: p.BeforeTool, + AfterToolCallback: p.AfterTool, + }) + return pluginInstance +} + +// Option defines a functional option for the ADKObservabilityPlugin. +type Option func(config *configs.ObservabilityConfig) + +// WithEnableMetrics creates an Option to manually control metrics recording. +func WithEnableMetrics(enable bool) Option { + return func(cfg *configs.ObservabilityConfig) { + enableVal := enable + cfg.OpenTelemetry.EnableMetrics = &enableVal + } +} + +type adkObservabilityPlugin struct { + config *configs.ObservabilityConfig + + tracer trace.Tracer // global tracer +} + +func (p *adkObservabilityPlugin) isMetricsEnabled() bool { + if p.config == nil || p.config.OpenTelemetry == nil || p.config.OpenTelemetry.EnableMetrics == nil { + return false + } + return *p.config.OpenTelemetry.EnableMetrics +} + +// BeforeRun is called before an agent run starts. +func (p *adkObservabilityPlugin) BeforeRun(ctx agent.InvocationContext) (*genai.Content, error) { + // 1. Start the 'invocation' span + newCtx, span := p.tracer.Start(context.Context(ctx), SpanInvocation, trace.WithSpanKind(trace.SpanKindServer)) + log.Debug("BeforeRun created a new invocation span", "span", span.SpanContext()) + + // Register internal ADK run span ID -> our veadk invocation span context. + adkSpan := trace.SpanFromContext(context.Context(ctx)) + if adkSpan.SpanContext().IsValid() { + GetRegistry().RegisterRunMapping(adkSpan.SpanContext().SpanID(), adkSpan.SpanContext().TraceID(), span.SpanContext(), span) + } + + // 2. Store in state for AfterRun and children + _ = ctx.Session().State().Set(stateKeyInvocationSpan, span) + _ = ctx.Session().State().Set(stateKeyInvocationCtx, newCtx) + + setCommonAttributes(newCtx, span) + setWorkflowAttributes(span) + + // Record start time for metrics + meta := &spanMetadata{ + StartTime: time.Now(), + } + p.storeSpanMetadata(ctx.Session().State(), meta) + + // Capture input from UserContent + if userContent := ctx.UserContent(); userContent != nil { + if jsonIn, err := json.Marshal(userContent); err == nil { + val := string(jsonIn) + span.SetAttributes( + attribute.String(AttrInputValue, val), + attribute.String(AttrGenAIInput, val), + ) + } + } + + return nil, nil +} + +// AfterRun is called after an agent run ends. +func (p *adkObservabilityPlugin) AfterRun(ctx agent.InvocationContext) { + // 1. End the span + if s, _ := ctx.Session().State().Get(stateKeyInvocationSpan); s != nil { + span := s.(trace.Span) + log.Debug("AfterRun get a span from state", "span", span, "isRecording", span.IsRecording()) + + if span.IsRecording() { + // Capture final output if available + if cached, _ := ctx.Session().State().Get(stateKeyStreamingOutput); cached != nil { + if jsonOut, err := json.Marshal(cached); err == nil { + val := string(jsonOut) + span.SetAttributes( + attribute.String(AttrOutputValue, val), + attribute.String(AttrGenAIOutput, val), + ) + } + } + // Capture accumulated token usage for the root invocation span + meta := p.getSpanMetadata(ctx.Session().State()) + + if meta.PromptTokens > 0 { + span.SetAttributes(attribute.Int64(AttrGenAIUsageInputTokens, meta.PromptTokens)) + } + if meta.CandidateTokens > 0 { + span.SetAttributes(attribute.Int64(AttrGenAIUsageOutputTokens, meta.CandidateTokens)) + } + if meta.TotalTokens > 0 { + span.SetAttributes(attribute.Int64(AttrGenAIUsageTotalTokens, meta.TotalTokens)) + } + + // Record final metrics for invocation + if !meta.StartTime.IsZero() { + elapsed := time.Since(meta.StartTime).Seconds() + metricAttrs := []attribute.KeyValue{ + attribute.String("gen_ai_operation_name", "chain"), + attribute.String("gen_ai_operation_type", "workflow"), + attribute.String("gen_ai.system", GetModelProvider(context.Context(ctx))), + } + if p.isMetricsEnabled() { + RecordOperationDuration(context.Background(), elapsed, metricAttrs...) + RecordAPMPlusSpanLatency(context.Background(), elapsed, metricAttrs...) + } + } + + // Clean up from global map with delay to allow children to be exported. + // Since we have multiple exporters, we wait long enough for all of them to finish. + adkSpan := trace.SpanFromContext(context.Context(ctx)) + if adkSpan.SpanContext().IsValid() { + id := adkSpan.SpanContext().SpanID() + tid := adkSpan.SpanContext().TraceID() + veadkTraceID := span.SpanContext().SpanID() + GetRegistry().ScheduleCleanup(tid, id, veadkTraceID) + } + + span.End() + } + } +} + +// BeforeModel is called before the LLM is called. +func (p *adkObservabilityPlugin) BeforeModel(ctx agent.CallbackContext, req *model.LLMRequest) (*model.LLMResponse, error) { + parentCtx := context.Context(ctx) + + if actx, _ := ctx.State().Get(stateKeyInvokeAgentCtx); actx != nil { + parentCtx = actx.(context.Context) + log.Debug("BeforeModel get a parent invoke_agent ctx from state", "parentCtx", parentCtx) + } else if ictx, _ := ctx.State().Get(stateKeyInvocationCtx); ictx != nil { + parentCtx = ictx.(context.Context) + log.Debug("BeforeModel get a parent invocation ctx from state", "parentCtx", parentCtx) + } + + // 2. Start our OWN span to cover the full duration of the call (including streaming). + // ADK's "call_llm" span will be closed prematurely by the framework on the first chunk. + // Align with Python: name is "call_llm" + newCtx, span := p.tracer.Start(parentCtx, SpanCallLLM) + log.Debug("BeforeModel created a span", "span", span.SpanContext(), "is_recording", span.IsRecording()) + _ = ctx.State().Set(stateKeyStreamingSpan, span) + + adkSpan := trace.SpanFromContext(context.Context(ctx)) + if adkSpan.SpanContext().IsValid() { // Register google's ADK span (currently not implemented) -> our veadk span context. + GetRegistry().RegisterLLMMapping(adkSpan.SpanContext().SpanID(), adkSpan.SpanContext().TraceID(), span.SpanContext()) + } + + // Group metadata in a single structure for state storage + meta := p.getSpanMetadata(ctx.State()) + meta.StartTime = time.Now() + meta.PrevPromptTokens = meta.PromptTokens + meta.PrevCandidateTokens = meta.CandidateTokens + meta.PrevTotalTokens = meta.TotalTokens + meta.ModelName = req.Model + p.storeSpanMetadata(ctx.State(), meta) + + // Link back to the ADK internal span if it's there. + // This records the ID of the span started by the ADK framework, which we + // often bypass to maintain a cleaner hierarchy in our veadk spans. + adkSpan = trace.SpanFromContext(context.Context(ctx)) + if adkSpan.SpanContext().IsValid() { + span.SetAttributes(attribute.String("adk.internal_span_id", adkSpan.SpanContext().SpanID().String())) + } + + setCommonAttributes(newCtx, span) + // Set GenAI standard span attributes + setLLMAttributes(span) + + // Record request attributes + p.setLLMRequestAttributes(ctx, span, req) + + // Capture messages in GenAI format for the span + messages := p.extractMessages(req) + var msgAttrs []attribute.KeyValue + messagesJSON, err := json.Marshal(messages) + if err == nil { + msgAttrs = append(msgAttrs, attribute.String(AttrGenAIMessages, string(messagesJSON))) + } + + // Flatten messages for gen_ai.prompt.[n] attributes (alignment with python) + msgAttrs = append(msgAttrs, p.flattenPrompt(messages)...) + + // Add input.value (standard for some collectors) + msgAttrs = append(msgAttrs, attribute.String(AttrGenAIInput, string(messagesJSON))) + + msgAttrs = append(msgAttrs, attribute.String(AttrInputValue, string(messagesJSON))) + + span.SetAttributes(msgAttrs...) + + // Add gen_ai.messages events (system, user, tool, assistant) aligned with Python + p.addMessageEvents(span, ctx, req) + + // Add gen_ai.content.prompt event (OTEL GenAI convention) + span.AddEvent(EventGenAIContentPrompt, trace.WithAttributes( + attribute.String(AttrGenAIPrompt, string(messagesJSON)), + attribute.String(AttrGenAIInput, string(messagesJSON)), + )) + + return nil, nil +} + +func (p *adkObservabilityPlugin) setLLMRequestAttributes(ctx agent.CallbackContext, span trace.Span, req *model.LLMRequest) { + attrs := []attribute.KeyValue{ + attribute.String(AttrGenAIRequestModel, req.Model), + attribute.String(AttrGenAIRequestType, "chat"), // Default to chat + attribute.String(AttrGenAISystem, GetModelProvider(context.Context(ctx))), + } + + if req.Config != nil { + if req.Config.Temperature != nil { + attrs = append(attrs, attribute.Float64(AttrGenAIRequestTemperature, float64(*req.Config.Temperature))) + } + if req.Config.TopP != nil { + attrs = append(attrs, attribute.Float64(AttrGenAIRequestTopP, float64(*req.Config.TopP))) + } + if req.Config.MaxOutputTokens > 0 { + attrs = append(attrs, attribute.Int64(AttrGenAIRequestMaxTokens, int64(req.Config.MaxOutputTokens))) + } + + funcIdx := 0 + for _, tool := range req.Config.Tools { + if tool.FunctionDeclarations != nil { + for _, fn := range tool.FunctionDeclarations { + prefix := fmt.Sprintf("gen_ai.request.functions.%d.", funcIdx) // Simplified indexing + attrs = append(attrs, attribute.String(prefix+"name", fn.Name)) + attrs = append(attrs, attribute.String(prefix+"description", fn.Description)) + if fn.Parameters != nil { + paramsJSON, _ := json.Marshal(fn.Parameters) + attrs = append(attrs, attribute.String(prefix+"parameters", string(paramsJSON))) + } + funcIdx++ + } + } + } + } + span.SetAttributes(attrs...) +} + +// AfterModel is called after the LLM returns. +func (p *adkObservabilityPlugin) AfterModel(ctx agent.CallbackContext, resp *model.LLMResponse, err error) (*model.LLMResponse, error) { + // 1. Get our managed span from state + s, _ := ctx.State().Get(stateKeyStreamingSpan) + if s == nil { + log.Warn("AfterModel: No streaming span found in state") + return nil, nil + } + span := s.(trace.Span) + // log.Debug("AfterModel get a trace span from state", "span", span.SpanContext(), "type", fmt.Sprintf("%T", s), "is_recording", span.IsRecording()) + + // 2. Wrap the cleanup to ensure span is always ended on error or final chunk. + // ADK calls AfterModel for EVERY chunk in a stream. + // resp.Partial is true for intermediate chunks, false for the final one. + defer func() { + if err != nil || (resp != nil && !resp.Partial) { + if span.IsRecording() { + log.Debug("AfterModel got a partial response", "span", span.SpanContext()) + span.End() + } + } + }() + + if err != nil { + span.SetStatus(codes.Error, err.Error()) + // Record Exceptions metric + if p.isMetricsEnabled() { + meta := p.getSpanMetadata(ctx.State()) + metricAttrs := []attribute.KeyValue{ + attribute.String(AttrGenAISystem, GetModelProvider(context.Context(ctx))), + attribute.String("gen_ai_response_model", meta.ModelName), + attribute.String("gen_ai_operation_name", "chat"), + attribute.String("gen_ai_operation_type", "llm"), + attribute.String("error_type", "error"), // Simple error type + } + RecordExceptions(context.Context(ctx), 1, metricAttrs...) + } + return nil, nil + } + + if resp == nil { + return nil, nil + } + + if !span.IsRecording() { + log.Warn("AfterModel: span is not recording", "span", span) + // Even if not recording, we should still accumulate content for metrics/logs + } + + // Record responding model + meta := p.getSpanMetadata(ctx.State()) + // Try to get confirmation from response metadata first (passed from sdk) + var finalModelName string + if resp.CustomMetadata != nil { + if m, ok := resp.CustomMetadata["response_model"].(string); ok { + finalModelName = m + } + } + // Fallback to request model name if not present in response + if finalModelName == "" { + finalModelName = meta.ModelName + } + if finalModelName != "" { + span.SetAttributes(attribute.String(AttrGenAIResponseModel, finalModelName)) + } + + if resp.UsageMetadata != nil { + p.handleUsage(ctx, span, resp, resp.Partial, finalModelName) + } + + // Capture tool calls from response to link future tool spans + if resp.Content != nil { + adkSpan := trace.SpanFromContext(context.Context(ctx)) + adkTraceID := trace.TraceID{} + if adkSpan.SpanContext().IsValid() { + adkTraceID = adkSpan.SpanContext().TraceID() + } + + for _, part := range resp.Content.Parts { + if part.FunctionCall != nil && part.FunctionCall.ID != "" { + log.Debug(" AfterModel, registering ToolCallID mapping", "tool_call_id", part.FunctionCall.ID, "parent_llm_span_id", span.SpanContext()) + GetRegistry().RegisterToolCallMapping(part.FunctionCall.ID, adkTraceID, span.SpanContext()) + } + } + } + + if resp.FinishReason != "" { + span.SetAttributes(attribute.String(AttrGenAIResponseFinishReason, string(resp.FinishReason))) + } + + // Record response content + var currentAcc *genai.Content + cached, _ := ctx.State().Get(stateKeyStreamingOutput) + if cached != nil { + currentAcc = cached.(*genai.Content) + } + + // --------------------------------------------------------- + // Metrics: Time to First Token (Streaming Only) + // --------------------------------------------------------- + p.recordTimeToFirstToken(ctx, resp, meta, currentAcc, finalModelName) + + if resp.Content != nil { + currentAcc = p.processStreamingChunk(ctx, resp, currentAcc) + } + + // For streaming, we update the span attributes with what we have so far + var fullText string + if currentAcc != nil { + fullText = p.updateStreamingSpanAttributes(span, currentAcc) + } + + // Metrics: Time to Generate (Streaming Only) & Time Per Output Token + p.recordStreamingGenerationMetrics(ctx, resp, meta, currentAcc, finalModelName) + + // If this is the final chunk, add the completion event + if !resp.Partial && currentAcc != nil { + contentJSON, _ := json.Marshal(currentAcc) + span.AddEvent(EventGenAIContentCompletion, trace.WithAttributes( + attribute.String(AttrGenAICompletion, string(contentJSON)), + attribute.String(AttrGenAIOutput, fullText), + )) + + // Add gen_ai.choice event (aligned with Python) + p.addChoiceEvents(span, currentAcc) + } + + if !resp.Partial { + // Record Operation Duration and Latency + p.recordFinalResponseMetrics(ctx, meta, finalModelName) + } + + return nil, nil +} + +func (p *adkObservabilityPlugin) recordTimeToFirstToken(ctx agent.CallbackContext, resp *model.LLMResponse, meta *spanMetadata, currentAcc *genai.Content, finalModelName string) { + if resp.Partial && currentAcc == nil && resp.Content != nil { + // This is the very first chunk + if !meta.StartTime.IsZero() { + meta.FirstTokenTime = time.Now() + p.storeSpanMetadata(ctx.State(), meta) + + if p.isMetricsEnabled() { + // Record streaming time to first token + latency := time.Since(meta.StartTime).Seconds() + metricAttrs := []attribute.KeyValue{ + attribute.String(AttrGenAISystem, GetModelProvider(context.Context(ctx))), + attribute.String("gen_ai_response_model", finalModelName), + attribute.String("gen_ai_operation_name", "chat"), + attribute.String("gen_ai_operation_type", "llm"), + } + RecordStreamingTimeToFirstToken(context.Context(ctx), latency, metricAttrs...) + } + } else { + log.Warn("didn't find the start time of span", "meta", meta) + } + } +} + +func (p *adkObservabilityPlugin) processStreamingChunk(ctx agent.CallbackContext, resp *model.LLMResponse, currentAcc *genai.Content) *genai.Content { + if currentAcc == nil { + currentAcc = &genai.Content{Role: resp.Content.Role} + if currentAcc.Role == "" { + currentAcc.Role = "model" + } + } + + // If this is the final response, our implementation (like OpenAI) often sends the full content. + // We clear our previous accumulation to avoid duplication in the span attributes. + // We only do this if the final response actually contains content. + if !resp.Partial && resp.Content != nil && len(resp.Content.Parts) > 0 { + currentAcc.Parts = nil + } + + // Accumulate parts with merging of adjacent text + for _, part := range resp.Content.Parts { + // If it's a text part, try to merge with the last part if that was also text + if part.Text != "" && len(currentAcc.Parts) > 0 { + lastPart := currentAcc.Parts[len(currentAcc.Parts)-1] + if lastPart.Text != "" && lastPart.FunctionCall == nil && lastPart.FunctionResponse == nil && lastPart.InlineData == nil { + lastPart.Text += part.Text + continue + } + } + + // Otherwise append as a new part + newPart := &genai.Part{} + *newPart = *part + currentAcc.Parts = append(currentAcc.Parts, newPart) + } + _ = ctx.State().Set(stateKeyStreamingOutput, currentAcc) + return currentAcc +} + +func (p *adkObservabilityPlugin) updateStreamingSpanAttributes(span trace.Span, currentAcc *genai.Content) string { + // Set output.value to the cumulative text (parity with python) + var textParts strings.Builder + textParts.Grow(len(currentAcc.Parts) * 4) + for _, p := range currentAcc.Parts { + if p.Text != "" { + textParts.WriteString(p.Text) + } + } + fullText := textParts.String() + span.SetAttributes(attribute.String(AttrGenAIOutput, fullText)) + + // Add output.value for full JSON representation + if contentJSON, err := json.Marshal(currentAcc); err == nil { + span.SetAttributes(attribute.String("output.value", string(contentJSON))) + } + + // Also set the structured GenAI attributes + span.SetAttributes(p.flattenCompletion(currentAcc)...) + return fullText +} + +func (p *adkObservabilityPlugin) recordStreamingGenerationMetrics(ctx agent.CallbackContext, resp *model.LLMResponse, meta *spanMetadata, currentAcc *genai.Content, finalModelName string) { + if !resp.Partial && currentAcc != nil { + if !meta.StartTime.IsZero() { + // Time Per Output Token + // Only valid if we have output tokens and we tracked first token time + if p.isMetricsEnabled() { + if meta.CandidateTokens > 0 { + generateDuration := time.Since(meta.StartTime).Seconds() + metricAttrs := []attribute.KeyValue{ + attribute.String(AttrGenAISystem, GetModelProvider(context.Context(ctx))), + attribute.String("gen_ai_response_model", finalModelName), + attribute.String("gen_ai_operation_name", "chat"), + attribute.String("gen_ai_operation_type", "llm"), + } + RecordStreamingTimeToGenerate(context.Context(ctx), generateDuration, metricAttrs...) + + if generateDuration > 0 { + timePerToken := generateDuration / float64(meta.CandidateTokens) + RecordStreamingTimePerOutputToken(context.Context(ctx), timePerToken, metricAttrs...) + } + } + } + } + } +} + +func (p *adkObservabilityPlugin) recordFinalResponseMetrics(ctx agent.CallbackContext, meta *spanMetadata, finalModelName string) { + if !meta.StartTime.IsZero() { + duration := time.Since(meta.StartTime).Seconds() + metricAttrs := []attribute.KeyValue{ + attribute.String(AttrGenAISystem, GetModelProvider(context.Context(ctx))), + attribute.String("gen_ai_response_model", finalModelName), + attribute.String("gen_ai_operation_name", "chat"), + attribute.String("gen_ai_operation_type", "llm"), + } + if p.isMetricsEnabled() { + RecordOperationDuration(context.Context(ctx), duration, metricAttrs...) + RecordAPMPlusSpanLatency(context.Context(ctx), duration, metricAttrs...) + } + } +} + +func (p *adkObservabilityPlugin) handleUsage(ctx agent.CallbackContext, span trace.Span, resp *model.LLMResponse, isStream bool, modelName string) { + meta := p.getSpanMetadata(ctx.State()) + + // 1. Get current call usage + currentPrompt := int64(resp.UsageMetadata.PromptTokenCount) + currentCandidate := int64(resp.UsageMetadata.CandidatesTokenCount) + currentTotal := int64(resp.UsageMetadata.TotalTokenCount) + + if currentTotal == 0 && (currentPrompt > 0 || currentCandidate > 0) { + currentTotal = currentPrompt + currentCandidate + } + + // 2. New session total = previous calls total + current call's (latest) usage + // (Note: in streaming, currentCall usage is cumulative for this call) + meta.PromptTokens = meta.PrevPromptTokens + currentPrompt + meta.CandidateTokens = meta.PrevCandidateTokens + currentCandidate + meta.TotalTokens = meta.PrevTotalTokens + currentTotal + + // 3. Update session-wide totals + p.storeSpanMetadata(ctx.State(), meta) + + // 4. Set attributes on the current LLM span (only current call's usage) + attrs := make([]attribute.KeyValue, 0, 7) + if currentPrompt > 0 { + attrs = append(attrs, attribute.Int64(AttrGenAIUsageInputTokens, currentPrompt)) + attrs = append(attrs, attribute.Int64(AttrGenAIResponsePromptTokenCount, currentPrompt)) + } + if currentCandidate > 0 { + attrs = append(attrs, attribute.Int64(AttrGenAIUsageOutputTokens, currentCandidate)) + attrs = append(attrs, attribute.Int64(AttrGenAIResponseCandidatesTokenCount, currentCandidate)) + } + if currentTotal > 0 { + attrs = append(attrs, attribute.Int64(AttrGenAIUsageTotalTokens, currentTotal)) + } + + if resp.UsageMetadata != nil { + if resp.UsageMetadata.CachedContentTokenCount > 0 { + attrs = append(attrs, attribute.Int64(AttrGenAIUsageCacheReadInputTokens, int64(resp.UsageMetadata.CachedContentTokenCount))) + } + // Always set cache creation to 0 if not provided, for parity with python + attrs = append(attrs, attribute.Int64(AttrGenAIUsageCacheCreationInputTokens, 0)) + } + + span.SetAttributes(attrs...) + + // Record metrics directly from the plugin logic + if p.isMetricsEnabled() { + metricAttrs := []attribute.KeyValue{ + attribute.String(AttrGenAISystem, GetModelProvider(ctx)), + attribute.String("gen_ai_response_model", modelName), + attribute.String("gen_ai_operation_name", "chat"), + attribute.String("gen_ai_operation_type", "llm"), + } + RecordChatCount(context.Context(ctx), 1, metricAttrs...) + + if currentTotal > 0 && (currentPrompt > 0 || currentCandidate > 0) { + RecordTokenUsage(context.Context(ctx), currentPrompt, currentCandidate, metricAttrs...) + + } + } +} + +func (p *adkObservabilityPlugin) addMessageEvents(span trace.Span, ctx agent.CallbackContext, req *model.LLMRequest) { + // 1. System Message + if req.Config != nil && req.Config.SystemInstruction != nil { + sysContent := "" + for _, part := range req.Config.SystemInstruction.Parts { + if part.Text != "" { + sysContent += part.Text + } + } + if sysContent != "" { + span.AddEvent("gen_ai.system.message", trace.WithAttributes( + attribute.String("role", "system"), + attribute.String("content", sysContent), + )) + } + } + + // 2. User, Tool, Assistant Messages from History + for _, content := range req.Contents { + if content.Role == "user" { + userEventAttrs := []attribute.KeyValue{ + attribute.String("role", "user"), + } + + // Check if it's a tool response (which comes in as 'user' role in Gemini/ADK but logically is tool message) + // Actually ADK structure: + // User inputs -> Role: user + // Tool Outputs -> Role: user (FunctionResponse) or "tool" depending on model? + // Python implementation checks `part.function_response`. + + hasToolResponse := false + for _, part := range content.Parts { + if part.FunctionResponse != nil { + hasToolResponse = true + // Emit separate event for each tool response + span.AddEvent("gen_ai.tool.message", trace.WithAttributes( + attribute.String("role", "tool"), + attribute.String("id", part.FunctionResponse.ID), + attribute.String("content", safeMarshal(part.FunctionResponse.Response)), + )) + } + } + + if hasToolResponse { + continue + } + + // Normal User Message + for i, part := range content.Parts { + if part.Text != "" { + if len(content.Parts) == 1 { + userEventAttrs = append(userEventAttrs, attribute.String("content", sanitizeUTF8(part.Text))) + } else { + userEventAttrs = append(userEventAttrs, attribute.String("parts."+strconv.Itoa(i)+".type", "text")) + userEventAttrs = append(userEventAttrs, attribute.String("parts."+strconv.Itoa(i)+".text", sanitizeUTF8(part.Text))) + } + } + if part.InlineData != nil && len(part.InlineData.Data) > 0 { + // Image/Blob handling + prefix := "parts." + strconv.Itoa(i) + if len(content.Parts) == 1 { + prefix = "parts.0" + } + userEventAttrs = append(userEventAttrs, attribute.String(prefix+".type", "image_url")) + // MIME type or display name mapping + userEventAttrs = append(userEventAttrs, attribute.String(prefix+".image_url.url", part.InlineData.MIMEType)) + } + } + span.AddEvent("gen_ai.user.message", trace.WithAttributes(userEventAttrs...)) + + } else if content.Role == "model" { + assistantEventAttrs := []attribute.KeyValue{ + attribute.String("role", "assistant"), + } + for i, part := range content.Parts { + if part.Text != "" { + assistantEventAttrs = append(assistantEventAttrs, attribute.String("parts."+strconv.Itoa(i)+".type", "text")) + assistantEventAttrs = append(assistantEventAttrs, attribute.String("parts."+strconv.Itoa(i)+".text", sanitizeUTF8(part.Text))) + } + if part.FunctionCall != nil { + // Tool Calls + prefix := "tool_calls.0" // Assuming single tool call per part or simplifying + assistantEventAttrs = append(assistantEventAttrs, attribute.String(prefix+".id", part.FunctionCall.ID)) + assistantEventAttrs = append(assistantEventAttrs, attribute.String(prefix+".type", "function")) + assistantEventAttrs = append(assistantEventAttrs, attribute.String(prefix+".function.name", part.FunctionCall.Name)) + assistantEventAttrs = append(assistantEventAttrs, attribute.String(prefix+".function.arguments", safeMarshal(part.FunctionCall.Args))) + } + } + span.AddEvent("gen_ai.assistant.message", trace.WithAttributes(assistantEventAttrs...)) + } + } +} + +func (p *adkObservabilityPlugin) addChoiceEvents(span trace.Span, content *genai.Content) { + for i, part := range content.Parts { + attrs := make([]attribute.KeyValue, 0, 2) + if part.Text != "" { + attrs = append(attrs, attribute.String("message.parts."+strconv.Itoa(i)+".type", "text")) + attrs = append(attrs, attribute.String("message.parts."+strconv.Itoa(i)+".text", sanitizeUTF8(part.Text))) + } + if len(attrs) > 0 { + span.AddEvent("gen_ai.choice", trace.WithAttributes(attrs...)) + } + } +} + +// extractMessages converts ADK model.LLMRequest contents into a JSON-compatible message list. +func (p *adkObservabilityPlugin) extractMessages(req *model.LLMRequest) []map[string]any { + var messages []map[string]any + for _, content := range req.Contents { + role := content.Role + if role == "model" { + role = "assistant" + } + + msg := map[string]any{ + "role": role, + } + + var textParts []string + var toolCalls []map[string]any + var toolResponses []map[string]any + + for _, part := range content.Parts { + if part.Text != "" { + textParts = append(textParts, sanitizeUTF8(part.Text)) + } + if part.FunctionCall != nil { + toolCalls = append(toolCalls, map[string]any{ + "id": part.FunctionCall.ID, + "type": "function", + "function": map[string]any{ + "name": part.FunctionCall.Name, + "arguments": safeMarshal(part.FunctionCall.Args), + }, + }) + } + if part.FunctionResponse != nil { + toolResponses = append(toolResponses, map[string]any{ + "id": part.FunctionResponse.ID, + "name": part.FunctionResponse.Name, + "content": safeMarshal(part.FunctionResponse.Response), + }) + } + } + + if len(textParts) > 0 { + msg["content"] = strings.Join(textParts, "") + } + if len(toolCalls) > 0 { + msg["tool_calls"] = toolCalls + } + if len(toolResponses) > 0 { + // In standard GenAI, tool responses are often represented separate messages or differently. + // Alignment with veadk-python usually means following their structure. + msg["tool_responses"] = toolResponses + } + + messages = append(messages, msg) + } + return messages +} + +func (p *adkObservabilityPlugin) flattenPrompt(messages []map[string]any) []attribute.KeyValue { + var attrs []attribute.KeyValue + idx := 0 + for _, msg := range messages { + // In Python, each piece of content/part increments the index. + // Since we already merged text parts in extractMessages, we just process each message here. + // If we wanted exact parity for multi-part messages, we'd need to change extractMessages. + // For now, this is a good approximation that matches the role/content flat structure. + prefix := "gen_ai.prompt." + strconv.Itoa(idx) + if role, ok := msg["role"].(string); ok { + attrs = append(attrs, attribute.String(prefix+".role", role)) + } + if content, ok := msg["content"].(string); ok { + attrs = append(attrs, attribute.String(prefix+".content", content)) + } + + if toolCalls, ok := msg["tool_calls"].([]map[string]any); ok { + for j, tc := range toolCalls { + tcPrefix := prefix + ".tool_calls." + strconv.Itoa(j) + if id, ok := tc["id"].(string); ok { + attrs = append(attrs, attribute.String(tcPrefix+".id", id)) + } + if t, ok := tc["type"].(string); ok { + attrs = append(attrs, attribute.String(tcPrefix+".type", t)) + } + if fn, ok := tc["function"].(map[string]any); ok { + if name, ok := fn["name"].(string); ok { + attrs = append(attrs, attribute.String(tcPrefix+".function.name", name)) + } + if args, ok := fn["arguments"].(string); ok { + attrs = append(attrs, attribute.String(tcPrefix+".function.arguments", args)) + } + } + } + } + + if toolResponses, ok := msg["tool_responses"].([]map[string]any); ok { + for j, tr := range toolResponses { + trPrefix := prefix + ".tool_responses." + strconv.Itoa(j) + if id, ok := tr["id"].(string); ok { + attrs = append(attrs, attribute.String(trPrefix+".id", id)) + } + if name, ok := tr["name"].(string); ok { + attrs = append(attrs, attribute.String(trPrefix+".name", name)) + } + if content, ok := tr["content"].(string); ok { + attrs = append(attrs, attribute.String(trPrefix+".content", content)) + } + } + } + idx++ + } + return attrs +} + +func (p *adkObservabilityPlugin) flattenCompletion(content *genai.Content) []attribute.KeyValue { + var attrs []attribute.KeyValue + + role := content.Role + if role == "model" { + role = "assistant" + } + + for idx, part := range content.Parts { + prefix := "gen_ai.completion." + strconv.Itoa(idx) + attrs = append(attrs, attribute.String(prefix+".role", role)) + + if part.Text != "" { + attrs = append(attrs, attribute.String(prefix+".content", sanitizeUTF8(part.Text))) + } + if part.FunctionCall != nil { + tcPrefix := prefix + ".tool_calls.0" + attrs = append(attrs, attribute.String(tcPrefix+".id", part.FunctionCall.ID)) + attrs = append(attrs, attribute.String(tcPrefix+".type", "function")) + attrs = append(attrs, attribute.String(tcPrefix+".function.name", part.FunctionCall.Name)) + attrs = append(attrs, attribute.String(tcPrefix+".function.arguments", safeMarshal(part.FunctionCall.Args))) + } + } + + return attrs +} + +// BeforeTool is called before a tool is executed. +func (p *adkObservabilityPlugin) BeforeTool(ctx tool.Context, tool tool.Tool, args map[string]any) (map[string]any, error) { + // Note: In Google ADK-go, the execute_tool span is often not available in the context at this stage. + // We rely on VeADKTranslatedExporter (translator.go) to reconstruct tool attributes from the + // span after it is ended and exported. + + // Maintain metadata for metrics calculation + meta := p.getSpanMetadata(ctx.State()) + meta.StartTime = time.Now() + p.storeSpanMetadata(ctx.State(), meta) + return nil, nil +} + +// AfterTool is called after a tool is executed. +func (p *adkObservabilityPlugin) AfterTool(ctx tool.Context, tool tool.Tool, args, result map[string]any, err error) (map[string]any, error) { + // Metrics recording only + meta := p.getSpanMetadata(ctx.State()) + if !meta.StartTime.IsZero() { + duration := time.Since(meta.StartTime).Seconds() + metricAttrs := []attribute.KeyValue{ + attribute.String("gen_ai_operation_name", tool.Name()), + attribute.String("gen_ai_operation_type", "tool"), + attribute.String(AttrGenAISystem, GetModelProvider(context.Context(ctx))), + } + if p.isMetricsEnabled() { + RecordOperationDuration(context.Background(), duration, metricAttrs...) + RecordAPMPlusSpanLatency(context.Background(), duration, metricAttrs...) + } + + if p.isMetricsEnabled() { + // Tool Token Usage (Estimated) + + // Input Chars + var inputChars int64 + if argsJSON, err := json.Marshal(args); err == nil { + inputChars = int64(len(argsJSON)) + } + + // Output Chars + var outputChars int64 + if resultJSON, err := json.Marshal(result); err == nil { + outputChars = int64(len(resultJSON)) + } + + if inputChars > 0 { + RecordAPMPlusToolTokenUsage(context.Background(), inputChars/4, append(metricAttrs, attribute.String("token_type", "input"))...) + } + if outputChars > 0 { + RecordAPMPlusToolTokenUsage(context.Background(), outputChars/4, append(metricAttrs, attribute.String("token_type", "output"))...) + } + } + } + + return nil, nil +} + +// BeforeAgent is called before an agent execution. +func (p *adkObservabilityPlugin) BeforeAgent(ctx agent.CallbackContext) (*genai.Content, error) { + agentName := ctx.AgentName() + if agentName == "" { + agentName = FallbackAgentName + } + + // 1. Get the parent context from state to maintain hierarchy + parentCtx := context.Context(ctx) + if ictx, _ := ctx.State().Get(stateKeyInvocationCtx); ictx != nil { + parentCtx = ictx.(context.Context) + } + + // 2. Start the 'invoke_agent' span manually. + // Since we can't easily wrap the Agent interface due to internal methods, + // we use the plugin to start our span. + spanName := SpanInvokeAgent + " " + agentName + newCtx, span := p.tracer.Start(parentCtx, spanName) + + // Register internal ADK's agent span ID -> our veadk agent span context. + adkSpan := trace.SpanFromContext(context.Context(ctx)) + if adkSpan.SpanContext().IsValid() { + GetRegistry().RegisterAgentMapping(adkSpan.SpanContext().SpanID(), adkSpan.SpanContext().TraceID(), span.SpanContext()) + } + + // 3. Store in state for AfterAgent and children + _ = ctx.State().Set(stateKeyInvokeAgentSpan, span) + _ = ctx.State().Set(stateKeyInvokeAgentCtx, newCtx) + + // 4. Set attributes + setCommonAttributes(newCtx, span) + setWorkflowAttributes(span) + setAgentAttributes(span, agentName) + + // Capture input if available (propagated from BeforeRun via state or context?) + // Note: BeforeRun captures UserContent, but for nested agents, input might be passed differently. + // For now, if UserContent is available in this context, log it. + if userContent := ctx.UserContent(); userContent != nil { + if jsonIn, err := json.Marshal(userContent); err == nil { + val := string(jsonIn) + span.SetAttributes(attribute.String(AttrGenAIInput, val)) + } + } + + return nil, nil +} + +// AfterAgent is called after an agent execution. +func (p *adkObservabilityPlugin) AfterAgent(ctx agent.CallbackContext) (*genai.Content, error) { + // 1. End the span + if s, _ := ctx.State().Get(stateKeyInvokeAgentSpan); s != nil { + span := s.(trace.Span) + if span.IsRecording() { + // Try to capture output if available in state (propagated from AfterRun or internal execution) + if cached, _ := ctx.State().Get(stateKeyStreamingOutput); cached != nil { + if jsonOut, err := json.Marshal(cached); err == nil { + val := string(jsonOut) + span.SetAttributes(attribute.String(AttrGenAIOutput, val)) + } + } + span.End() + } + } + return nil, nil +} + +func (p *adkObservabilityPlugin) getSpanMetadata(state session.State) *spanMetadata { + val, _ := state.Get(stateKeyMetadata) + if meta, ok := val.(*spanMetadata); ok { + return meta + } + return &spanMetadata{} +} + +func (p *adkObservabilityPlugin) storeSpanMetadata(state session.State, meta *spanMetadata) { + _ = state.Set(stateKeyMetadata, meta) +} + +// sanitizeUTF8 removes or replaces invalid UTF-8 characters from a string +func sanitizeUTF8(s string) string { + // If the string is already valid UTF-8, return it as is + if len(s) == 0 { + return s + } + + // Replace invalid UTF-8 sequences with Unicode replacement character + return string([]rune(s)) +} + +func safeMarshal(v any) string { + if v == nil { + return "" + } + b, err := json.Marshal(v) + if err != nil { + return "" + } + + return string(b) +} + +const ( + stateKeyInvocationSpan = "veadk.observability.invocation_span" + stateKeyInvocationCtx = "veadk.observability.invocation_ctx" + stateKeyInvokeAgentCtx = "veadk.observability.invoke_agent_ctx" + stateKeyInvokeAgentSpan = "veadk.observability.invoke_agent_span" + + stateKeyMetadata = "veadk.observability.metadata" + stateKeyStreamingOutput = "veadk.observability.streaming_output" + stateKeyStreamingSpan = "veadk.observability.streaming_span" +) + +// spanMetadata groups various observational data points in a single structure +// to keep the ADK State clean. +type spanMetadata struct { + StartTime time.Time + FirstTokenTime time.Time + PromptTokens int64 + CandidateTokens int64 + TotalTokens int64 + PrevPromptTokens int64 + PrevCandidateTokens int64 + PrevTotalTokens int64 + ModelName string +} diff --git a/observability/registry.go b/observability/registry.go new file mode 100644 index 0000000..bb83e0d --- /dev/null +++ b/observability/registry.go @@ -0,0 +1,307 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package observability + +import ( + "sync" + "time" + + "github.com/volcengine/veadk-go/log" + "go.opentelemetry.io/otel/trace" +) + +// TraceRegistry manages the mapping between ADK-go's spans and VeADK spans. +// It ensures thread-safe access and proper cleanup of resources. +type TraceRegistry struct { + // adkSpanMap tracks google's adk SpanID (Run/Agent/LLM/Tool) -> VeADK SpanContext + adkSpanMap sync.Map + + // toolCallMap tracks ToolCallID (string) -> *toolCallInfo + // Consolidates: toolCallToVeadkLLMMap, toolInputs, toolOutputs + toolCallMap sync.Map + + // activeInvocationSpans tracks active VeADK invocation spans for shutdown flushing. + activeInvocationSpans sync.Map + + // adkTraceToVeadkTraceMap tracks InternalTraceID -> Associated Resources for cleanup. + resourcesMu sync.RWMutex + adkTraceToVeadkTraceMap map[trace.TraceID]*traceInfos + + // cleanupQueue receives cleanup requests + cleanupQueue chan cleanupRequest + + // shutdownChan signals the cleanup loop to exit + shutdownChan chan struct{} +} + +type cleanupRequest struct { + adkTraceID trace.TraceID + internalRunID trace.SpanID + veadkSpanID trace.SpanID + deadline time.Time +} + +type toolCallInfo struct { + mu sync.RWMutex + parentSC trace.SpanContext +} + +type traceInfos struct { + veadkTraceID trace.TraceID + spanIDs []trace.SpanID + toolCallIDs []string +} + +var ( + // globalRegistry is the singleton instance of TraceRegistry. + globalRegistry *TraceRegistry + once sync.Once +) + +// GetRegistry returns the global TraceRegistry. +func GetRegistry() *TraceRegistry { + once.Do(func() { + globalRegistry = &TraceRegistry{ + adkTraceToVeadkTraceMap: make(map[trace.TraceID]*traceInfos), + cleanupQueue: make(chan cleanupRequest, 512), + shutdownChan: make(chan struct{}), + } + go globalRegistry.cleanupLoop() + }) + return globalRegistry +} + +func (r *TraceRegistry) Shutdown() { + select { + case <-r.shutdownChan: + // Already closed + default: + close(r.shutdownChan) + } +} + +func (r *TraceRegistry) cleanupLoop() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + // Use a slice to store pending requests + var pendingRequests []cleanupRequest + + for { + select { + case <-r.shutdownChan: + return + case req := <-r.cleanupQueue: + pendingRequests = append(pendingRequests, req) + case <-ticker.C: + now := time.Now() + activeRequests := pendingRequests[:0] + for _, req := range pendingRequests { + if now.After(req.deadline) { + // Perform cleanup + r.UnregisterInvocationMapping(req.internalRunID, req.veadkSpanID) + + r.resourcesMu.Lock() + if res, ok := r.adkTraceToVeadkTraceMap[req.adkTraceID]; ok { + for _, sid := range res.spanIDs { + r.adkSpanMap.Delete(sid) + } + for _, tcid := range res.toolCallIDs { + r.toolCallMap.Delete(tcid) + } + delete(r.adkTraceToVeadkTraceMap, req.adkTraceID) + } + r.resourcesMu.Unlock() + } else { + activeRequests = append(activeRequests, req) + } + } + pendingRequests = activeRequests + } + } +} + +func (r *TraceRegistry) getOrCreateTraceInfos(adkTraceID trace.TraceID) *traceInfos { + r.resourcesMu.Lock() + defer r.resourcesMu.Unlock() + + if res, ok := r.adkTraceToVeadkTraceMap[adkTraceID]; ok { + return res + } + res := &traceInfos{} + r.adkTraceToVeadkTraceMap[adkTraceID] = res + return res +} + +// RegisterRunMapping links ADK's internal run span to our veadk invocation span. +func (r *TraceRegistry) RegisterRunMapping(adkSpanID trace.SpanID, adkTraceID trace.TraceID, veadkSC trace.SpanContext, veadkSpan trace.Span) { + if !adkSpanID.IsValid() || !veadkSC.IsValid() { + return + } + r.adkSpanMap.Store(adkSpanID, veadkSC) + r.activeInvocationSpans.Store(veadkSC.SpanID(), veadkSpan) + + if adkTraceID.IsValid() { + res := r.getOrCreateTraceInfos(adkTraceID) + r.resourcesMu.Lock() + res.spanIDs = append(res.spanIDs, adkSpanID) + res.veadkTraceID = veadkSC.TraceID() + r.resourcesMu.Unlock() + } +} + +// RegisterAgentMapping links ADK's internal agent span to our veadk agent span. +func (r *TraceRegistry) RegisterAgentMapping(adkSpanID trace.SpanID, adkTraceID trace.TraceID, veadkSC trace.SpanContext) { + if !adkSpanID.IsValid() || !veadkSC.IsValid() { + return + } + r.adkSpanMap.Store(adkSpanID, veadkSC) + + if adkTraceID.IsValid() { + res := r.getOrCreateTraceInfos(adkTraceID) + r.resourcesMu.Lock() + res.spanIDs = append(res.spanIDs, adkSpanID) + r.resourcesMu.Unlock() + } +} + +// RegisterLLMMapping links ADK's internal LLM span to our veadk LLM span. +func (r *TraceRegistry) RegisterLLMMapping(adkSpanID trace.SpanID, adkTraceID trace.TraceID, veadkSC trace.SpanContext) { + if !adkSpanID.IsValid() || !veadkSC.IsValid() { + return + } + r.adkSpanMap.Store(adkSpanID, veadkSC) + + if adkTraceID.IsValid() { + res := r.getOrCreateTraceInfos(adkTraceID) + r.resourcesMu.Lock() + res.spanIDs = append(res.spanIDs, adkSpanID) + r.resourcesMu.Unlock() + } +} + +// RegisterToolMapping links a tool span (started by ADK) to its veadk parent (LLM call). +func (r *TraceRegistry) RegisterToolMapping(toolSpanID trace.SpanID, veadkParentSC trace.SpanContext) { + if !toolSpanID.IsValid() || !veadkParentSC.IsValid() { + return + } + r.adkSpanMap.Store(toolSpanID, veadkParentSC) +} + +func (r *TraceRegistry) getOrCreateToolCallInfo(toolCallID string) *toolCallInfo { + val, loaded := r.toolCallMap.LoadOrStore(toolCallID, &toolCallInfo{}) + if !loaded { + // New entry + } + return val.(*toolCallInfo) +} + +// RegisterToolCallMapping links a logical tool call ID to its parent LLM span context. +func (r *TraceRegistry) RegisterToolCallMapping(toolCallID string, adkTraceID trace.TraceID, veadkParentSC trace.SpanContext) { + if toolCallID == "" || !veadkParentSC.IsValid() { + return + } + info := r.getOrCreateToolCallInfo(toolCallID) + info.mu.Lock() + info.parentSC = veadkParentSC + info.mu.Unlock() + + if adkTraceID.IsValid() { + res := r.getOrCreateTraceInfos(adkTraceID) + r.resourcesMu.Lock() + res.toolCallIDs = append(res.toolCallIDs, toolCallID) + r.resourcesMu.Unlock() + } +} + +// RegisterTraceMapping records a mapping from an internal adk TraceID to a veadk TraceID. +func (r *TraceRegistry) RegisterTraceMapping(adkTraceID trace.TraceID, veadkTraceID trace.TraceID) { + if !adkTraceID.IsValid() || !veadkTraceID.IsValid() { + return + } + res := r.getOrCreateTraceInfos(adkTraceID) + r.resourcesMu.Lock() + res.veadkTraceID = veadkTraceID + r.resourcesMu.Unlock() +} + +// GetVeadkSpanContext finds the veadk replacement for an adk parent span ID. +func (r *TraceRegistry) GetVeadkSpanContext(adkSpanID trace.SpanID) (trace.SpanContext, bool) { + if val, ok := r.adkSpanMap.Load(adkSpanID); ok { + return val.(trace.SpanContext), true + } + return trace.SpanContext{}, false +} + +// GetVeadkParentContextByToolCallID finds the veadk parent for a tool span by its logical ToolCallID. +func (r *TraceRegistry) GetVeadkParentContextByToolCallID(toolCallID string) (trace.SpanContext, bool) { + if toolCallID == "" { + return trace.SpanContext{}, false + } + if val, ok := r.toolCallMap.Load(toolCallID); ok { + info := val.(*toolCallInfo) + info.mu.RLock() + defer info.mu.RUnlock() + if info.parentSC.IsValid() { + return info.parentSC, true + } + } + return trace.SpanContext{}, false +} + +// GetVeadkTraceID finds the veadk TraceID for an internal TraceID. +func (r *TraceRegistry) GetVeadkTraceID(adkTraceID trace.TraceID) (trace.TraceID, bool) { + r.resourcesMu.RLock() + defer r.resourcesMu.RUnlock() + + if res, ok := r.adkTraceToVeadkTraceMap[adkTraceID]; ok { + return res.veadkTraceID, res.veadkTraceID.IsValid() + } + return trace.TraceID{}, false +} + +// UnregisterInvocationMapping removes run-related mappings. +func (r *TraceRegistry) UnregisterInvocationMapping(adkSpanID trace.SpanID, veadkSpanID trace.SpanID) { + r.adkSpanMap.Delete(adkSpanID) + r.activeInvocationSpans.Delete(veadkSpanID) +} + +// ScheduleCleanup schedules cleanup of all mappings related to an internal TraceID. +// This is typically called when the trace is considered complete. +func (r *TraceRegistry) ScheduleCleanup(adkTraceID trace.TraceID, internalRunID trace.SpanID, veadkSpanID trace.SpanID) { + select { + case r.cleanupQueue <- cleanupRequest{ + adkTraceID: adkTraceID, + internalRunID: internalRunID, + veadkSpanID: veadkSpanID, + deadline: time.Now().Add(2 * time.Minute), + }: + default: + log.Warn("trace cleanup queue is full") + } +} + +// EndAllInvocationSpans ends all currently active invocation spans. +func (r *TraceRegistry) EndAllInvocationSpans() { + r.activeInvocationSpans.Range(func(key, value any) bool { + if span, ok := value.(trace.Span); ok { + if span.IsRecording() { + span.End() + } + } + r.activeInvocationSpans.Delete(key) + return true + }) +} diff --git a/observability/translator.go b/observability/translator.go new file mode 100644 index 0000000..9279032 --- /dev/null +++ b/observability/translator.go @@ -0,0 +1,332 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package observability + +import ( + "context" + "encoding/json" + "strings" + + "github.com/volcengine/veadk-go/log" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/sdk/instrumentation" + "go.opentelemetry.io/otel/sdk/trace" + oteltrace "go.opentelemetry.io/otel/trace" +) + +var ( + spanInvokeAgent = "invoke_agent" + + gcpVertexAgentLLMRequestName = "gcp.vertex.agent.llm_request" + gcpVertexAgentToolCallArgsName = "gcp.vertex.agent.tool_call_args" + gcpVertexAgentEventID = "gcp.vertex.agent.event_id" + gcpVertexAgentToolResponseName = "gcp.vertex.agent.tool_response" + gcpVertexAgentLLMResponseName = "gcp.vertex.agent.llm_response" + gcpVertexAgentInvocationID = "gcp.vertex.agent.invocation_id" + gcpVertexAgentSessionID = "gcp.vertex.agent.session_id" + + ADKAttributeKeyMap = map[string]string{ + gcpVertexAgentLLMRequestName: AttrInputValue, + gcpVertexAgentLLMResponseName: AttrOutputValue, + gcpVertexAgentToolCallArgsName: AttrGenAIToolInput, + gcpVertexAgentToolResponseName: AttrGenAIToolOutput, + gcpVertexAgentInvocationID: AttrGenAIInvocationId, + gcpVertexAgentSessionID: AttrGenAISessionId, + } +) + +// isMatch returns true if we should keep the span in the final output. +// At this point, we filter out "call_llm" span generated by ADK-go with scope name is "gcp.vertex.agent". +func isMatch(span trace.ReadOnlySpan) bool { + if span.InstrumentationScope().Name == "gcp.vertex.agent" { + name := span.Name() + if name == "call_llm" { + return false + } + } + return true +} + +// VeADKTranslatedExporter wraps a SpanExporter and remaps ADK attributes to standard fields. +type VeADKTranslatedExporter struct { + trace.SpanExporter +} + +func (e *VeADKTranslatedExporter) ExportSpans(ctx context.Context, spans []trace.ReadOnlySpan) error { + translated := make([]trace.ReadOnlySpan, 0, len(spans)) + registry := GetRegistry() + + for _, s := range spans { + if !isMatch(s) { + continue + } + + ts := &translatedSpan{ReadOnlySpan: s} + translated = append(translated, ts) + + // 1. Logic stitching via ToolCallID + toolCallID := "" + for _, kv := range s.Attributes() { + if string(kv.Key) == AttrGenAIToolCallID { + toolCallID = kv.Value.AsString() + break + } + } + + if toolCallID != "" { + if veadkParentSC, ok := registry.GetVeadkParentContextByToolCallID(toolCallID); ok { + // We found a match! Record this TraceID mapping to align other spans in the same trace (like merged spans) + registry.RegisterTraceMapping(s.SpanContext().TraceID(), veadkParentSC.TraceID()) + log.Debug("Matched tool via ToolCallID, established TraceID mapping", + "tool_call_id", toolCallID, + "adk_trace_id", s.SpanContext().TraceID().String(), + "veadk_trace_id", veadkParentSC.TraceID().String(), + ) + } + } + + } + + if len(translated) == 0 { + return nil + } + + return e.SpanExporter.ExportSpans(ctx, translated) +} + +// translatedSpan wraps a ReadOnlySpan and intercepts calls to Attributes(). +type translatedSpan struct { + trace.ReadOnlySpan +} + +func (p *translatedSpan) Attributes() []attribute.KeyValue { + attrs := p.ReadOnlySpan.Attributes() + newAttrs := make([]attribute.KeyValue, 0, len(attrs)+5) // Pre-allocate with some extra space + + // Track existing keys and tool-related fields + existingKeys := make(map[string]bool) + var toolName, toolDesc, toolArgs, toolCallID, toolResponse string + + // First pass: scan for existing keys and raw data + for _, kv := range attrs { + existingKeys[string(kv.Key)] = true + key := string(kv.Key) + + // Collect raw data for reconstruction + switch key { + case AttrGenAIToolName: + toolName = kv.Value.AsString() + case AttrGenAIToolDescription: // Note: ADK uses gen_ai.tool.description + toolDesc = kv.Value.AsString() + case gcpVertexAgentToolCallArgsName: + toolArgs = kv.Value.AsString() + case AttrGenAIToolCallID: + toolCallID = kv.Value.AsString() + case gcpVertexAgentToolResponseName: + toolResponse = kv.Value.AsString() + } + } + + newAttrs = p.processAttributes(attrs, existingKeys) + + // Dynamic Reconstruction: Tool Input/Output from raw attributes + if toolArgs != "" && toolName != "" { + if inputAttrs := p.reconstructToolInput(toolName, toolDesc, toolArgs); inputAttrs != nil { + newAttrs = append(newAttrs, inputAttrs...) + } + } + + if toolResponse != "" && toolCallID != "" { + if outputAttrs := p.reconstructToolOutput(toolName, toolCallID, toolResponse); outputAttrs != nil { + newAttrs = append(newAttrs, outputAttrs...) + } + } + + // Enrich with Span Kind if it's determined to be a tool span + if toolName != "" || toolCallID != "" { + newAttrs = append(newAttrs, attribute.String(AttrGenAISpanKind, SpanKindTool)) + } + + return newAttrs +} + +func (p *translatedSpan) processAttributes(attrs []attribute.KeyValue, existingKeys map[string]bool) []attribute.KeyValue { + newAttrs := make([]attribute.KeyValue, 0, len(attrs)) + for _, kv := range attrs { + key := string(kv.Key) + + // 1. Map ADK internal attributes if not already present in standard form + if strings.HasPrefix(key, "gcp.vertex.agent.") { + targetKey, ok := ADKAttributeKeyMap[key] + if ok { + // Skip if we are going to reconstruct this field + if targetKey == AttrGenAIToolInput || targetKey == AttrGenAIToolOutput { + continue + } + + // Only add mapped key if the target key doesn't already exist in the span + if !existingKeys[targetKey] { + newAttrs = append(newAttrs, attribute.KeyValue{Key: attribute.Key(targetKey), Value: kv.Value}) + } + } + continue + } + + // 2. Patch gen_ai.system if needed + if key == AttrGenAISystem && kv.Value.AsString() == "gcp.vertex.agent" { + kv = attribute.String(AttrGenAISystem, "volcengine") + } + + newAttrs = append(newAttrs, kv) + } + return newAttrs +} + +func (p *translatedSpan) reconstructToolInput(toolName, toolDesc, toolArgs string) []attribute.KeyValue { + var paramsMap map[string]any + if err := json.Unmarshal([]byte(toolArgs), ¶msMap); err == nil { + inputData := map[string]any{ + "name": toolName, + "description": toolDesc, + "parameters": paramsMap, + } + if inputJSON, err := json.Marshal(inputData); err == nil { + val := string(inputJSON) + return []attribute.KeyValue{ + attribute.String(AttrGenAIToolInput, val), + attribute.String(AttrCozeloopInput, val), + attribute.String(AttrGenAIInput, val), + } + } + } + return nil +} + +func (p *translatedSpan) reconstructToolOutput(toolName, toolCallID, toolResponse string) []attribute.KeyValue { + var responseMap map[string]any + // ADK serializes response as map, unmarshal it first + if err := json.Unmarshal([]byte(toolResponse), &responseMap); err == nil { + outputData := map[string]any{ + "id": toolCallID, + "name": toolName, + "response": responseMap, + } + if outputJSON, err := json.Marshal(outputData); err == nil { + val := string(outputJSON) + return []attribute.KeyValue{ + attribute.String(AttrGenAIToolOutput, val), + attribute.String(AttrCozeloopOutput, val), + attribute.String(AttrGenAIOutput, val), + } + } + } + return nil +} + +func (p *translatedSpan) SpanContext() oteltrace.SpanContext { + sc := p.ReadOnlySpan.SpanContext() + registry := GetRegistry() + + toolCallID := "" + for _, kv := range p.ReadOnlySpan.Attributes() { + if string(kv.Key) == AttrGenAIToolCallID { + toolCallID = kv.Value.AsString() + break + } + } + + if veadkParentSC, ok := registry.GetVeadkParentContextByToolCallID(toolCallID); ok { + return oteltrace.NewSpanContext(oteltrace.SpanContextConfig{ + TraceID: veadkParentSC.TraceID(), + SpanID: sc.SpanID(), + TraceFlags: sc.TraceFlags(), + TraceState: sc.TraceState(), + Remote: sc.IsRemote(), + }) + } + + // 2. Try global TraceID mapping (for spans in the same trace without their own tool_call_id) + if veadkParentSC, ok := registry.GetVeadkTraceID(sc.TraceID()); ok { + return oteltrace.NewSpanContext(oteltrace.SpanContextConfig{ + TraceID: veadkParentSC, + SpanID: sc.SpanID(), + TraceFlags: sc.TraceFlags(), + TraceState: sc.TraceState(), + Remote: sc.IsRemote(), + }) + } + + // 3. Fallback to Tool SpanID mapping + if veadkParentSC, ok := registry.GetVeadkSpanContext(sc.SpanID()); ok { + return oteltrace.NewSpanContext(oteltrace.SpanContextConfig{ + TraceID: veadkParentSC.TraceID(), + SpanID: sc.SpanID(), + TraceFlags: sc.TraceFlags(), + TraceState: sc.TraceState(), + Remote: sc.IsRemote(), + }) + } + + return sc +} + +func (p *translatedSpan) Parent() oteltrace.SpanContext { + parent := p.ReadOnlySpan.Parent() + sc := p.ReadOnlySpan.SpanContext() + registry := GetRegistry() + + // 1. Precise Re-parenting based on internal ParentID mapping + if parent.IsValid() { + if veadkSC, ok := registry.GetVeadkSpanContext(parent.SpanID()); ok { + return veadkSC + } + } + + // 2. Try ToolCallID mapping (for tool spans that lost parent context) + toolCallID := "" + for _, kv := range p.ReadOnlySpan.Attributes() { + if string(kv.Key) == AttrGenAIToolCallID { + toolCallID = kv.Value.AsString() + break + } + } + + if manualParentSC, ok := registry.GetVeadkParentContextByToolCallID(toolCallID); ok { + return manualParentSC + } + + // 3. Fallback: Re-parent root spans if we have a direct mapping for this span ID. + if !parent.IsValid() { + if manualSC, ok := registry.GetVeadkSpanContext(sc.SpanID()); ok { + return manualSC + } + } + + return parent +} + +func (p *translatedSpan) InstrumentationScope() instrumentation.Scope { + scope := p.ReadOnlySpan.InstrumentationScope() + // github.com/volcengine/veadk-go is the InstrumentationName defined in observability/constant.go + if scope.Name == "gcp.vertex.agent" || scope.Name == "veadk" || scope.Name == "github.com/volcengine/veadk-go" { + scope.Name = "openinference.instrumentation.veadk" + } + scope.Version = Version + return scope +} + +func (p *translatedSpan) InstrumentationLibrary() instrumentation.Scope { + return p.InstrumentationScope() +}