Merged
21 changes: 14 additions & 7 deletions internal/llminternal/base_flow.go
@@ -146,7 +146,9 @@ func (f *Flow) runOneStep(ctx agent.InvocationContext) iter.Seq2[*session.Event,
if ctx.Ended() {
return
}
spans := telemetry.StartTrace(ctx, "call_llm")
sctx, callLLMSpan := telemetry.StartTrace(ctx, "call_llm")
ctx = ctx.WithContext(sctx)
defer callLLMSpan.End()
// Create event to pass to callback state delta
stateDelta := make(map[string]any)
// Calls the LLM.
@@ -180,7 +182,7 @@ func (f *Flow) runOneStep(ctx agent.InvocationContext) iter.Seq2[*session.Event,

// Build the event and yield.
modelResponseEvent := f.finalizeModelResponseEvent(ctx, resp, tools, stateDelta)
telemetry.TraceLLMCall(spans, ctx, req, modelResponseEvent)
telemetry.TraceLLMCall(callLLMSpan, ctx, req, modelResponseEvent)
if !yield(modelResponseEvent, nil) {
return
}
@@ -495,13 +497,15 @@ func (f *Flow) handleFunctionCalls(ctx agent.InvocationContext, toolsDict map[st
toolNames := slices.Collect(maps.Keys(toolsDict))
var result map[string]any
for _, fnCall := range fnCalls {
sctx, span := telemetry.StartTrace(ctx, "execute_tool "+fnCall.Name)
defer span.End()
toolCallCtx := ctx.WithContext(sctx)
var confirmation *toolconfirmation.ToolConfirmation
if toolConfirmations != nil {
confirmation = toolConfirmations[fnCall.ID]
}
toolCtx := toolinternal.NewToolContext(ctx, fnCall.ID, &session.EventActions{StateDelta: make(map[string]any)}, confirmation)
toolCtx := toolinternal.NewToolContext(toolCallCtx, fnCall.ID, &session.EventActions{StateDelta: make(map[string]any)}, confirmation)

spans := telemetry.StartTrace(ctx, "execute_tool "+fnCall.Name)
curTool, found := toolsDict[fnCall.Name]
if !found {
err := newToolNotFoundError(fnCall.Name, toolNames)
@@ -543,7 +547,7 @@ func (f *Flow) handleFunctionCalls(ctx agent.InvocationContext, toolsDict map[st
if traceTool == nil {
traceTool = &fakeTool{name: fnCall.Name}
}
telemetry.TraceToolCall(spans, traceTool, fnCall.Args, ev)
telemetry.TraceToolCall(span, traceTool, fnCall.Args, ev)

fnResponseEvents = append(fnResponseEvents, ev)
}
@@ -552,8 +556,11 @@ func (f *Flow) handleFunctionCalls(ctx agent.InvocationContext, toolsDict map[st
return mergedEvent, err
}
// this is needed for debug traces of parallel calls
spans := telemetry.StartTrace(ctx, "execute_tool (merged)")
telemetry.TraceMergedToolCalls(spans, mergedEvent)
if mergedEvent != nil {
_, span := telemetry.StartTrace(ctx, "execute_tool (merged)")
telemetry.TraceMergedToolCalls(span, mergedEvent)
span.End()
}
return mergedEvent, nil
}

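For context, a minimal runnable sketch of the new call pattern used above, assuming only the OpenTelemetry Go API; the startTrace helper and the span names are illustrative stand-ins, not part of this change:

package main

import (
	"context"
	"fmt"

	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/trace"
)

// startTrace mirrors the new telemetry.StartTrace shape: a single tracer
// returning the derived context together with its span.
func startTrace(ctx context.Context, name string) (context.Context, trace.Span) {
	return otel.Tracer("gcp.vertex.agent").Start(ctx, name)
}

func main() {
	ctx := context.Background()

	// Caller-side pattern from runOneStep: keep the derived context so child
	// spans (for example the per-tool spans) nest under the call_llm span.
	ctx, llmSpan := startTrace(ctx, "call_llm")
	defer llmSpan.End()

	toolCtx, toolSpan := startTrace(ctx, "execute_tool example")
	defer toolSpan.End()

	fmt.Println("tool span context valid:", trace.SpanContextFromContext(toolCtx).IsValid())
}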
233 changes: 81 additions & 152 deletions internal/telemetry/telemetry.go
@@ -12,20 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// Package telemetry sets up the open telemetry exporters to the ADK.
// Package telemetry implements tracing for ADK.
//
// WARNING: telemetry provided by ADK (internaltelemetry package) may change (e.g. attributes and their names)
// because we're in process to standardize and unify telemetry across all ADKs.
// because we're in process to standardize and unify tracing across all ADKs.
package telemetry

import (
"context"
"encoding/json"
"sync"

"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
"google.golang.org/genai"

@@ -35,26 +33,10 @@ import (
"google.golang.org/adk/tool"
)

type tracerProviderHolder struct {
tp trace.TracerProvider
}

type tracerProviderConfig struct {
spanProcessors []sdktrace.SpanProcessor
mu *sync.RWMutex
}

var (
once sync.Once
localTracer tracerProviderHolder
localTracerConfig = tracerProviderConfig{
spanProcessors: []sdktrace.SpanProcessor{},
mu: &sync.RWMutex{},
}
)
var tracer = otel.GetTracerProvider().Tracer(systemName)

const (
SystemName = "gcp.vertex.agent"
systemName = "gcp.vertex.agent"
genAiOperationName = "gen_ai.operation.name"
genAiToolDescription = "gen_ai.tool.description"
genAiToolName = "gen_ai.tool.name"
@@ -84,165 +66,112 @@ const (
mergeToolName = "(merged tools)"
)

// AddSpanProcessor adds a span processor to the local tracer config.
func AddSpanProcessor(processor sdktrace.SpanProcessor) {
localTracerConfig.mu.Lock()
defer localTracerConfig.mu.Unlock()
localTracerConfig.spanProcessors = append(localTracerConfig.spanProcessors, processor)
}

// RegisterTelemetry sets up the local tracer that will be used to emit traces.
// We use local tracer to respect the global tracer configurations.
func RegisterTelemetry() {
once.Do(func() {
traceProvider := sdktrace.NewTracerProvider()
localTracerConfig.mu.RLock()
spanProcessors := localTracerConfig.spanProcessors
localTracerConfig.mu.RUnlock()
for _, processor := range spanProcessors {
traceProvider.RegisterSpanProcessor(processor)
}
localTracer = tracerProviderHolder{tp: traceProvider}
})
}

// If the global tracer is not set, the default NoopTracerProvider will be used.
// That means that the spans are NOT recording/exporting
// If the local tracer is not set, we'll set up tracer with all registered span processors.
func getTracers() []trace.Tracer {
if localTracer.tp == nil {
RegisterTelemetry()
}
return []trace.Tracer{
localTracer.tp.Tracer(SystemName),
otel.GetTracerProvider().Tracer(SystemName),
}
}

// StartTrace returns two spans to start emitting events, one from global tracer and second from the local.
func StartTrace(ctx context.Context, traceName string) []trace.Span {
tracers := getTracers()
spans := make([]trace.Span, len(tracers))
for i, tracer := range tracers {
_, span := tracer.Start(ctx, traceName)
spans[i] = span
}
return spans
func StartTrace(ctx context.Context, traceName string) (context.Context, trace.Span) {
return tracer.Start(ctx, traceName)
}

// TraceMergedToolCalls traces the tool execution events.
func TraceMergedToolCalls(spans []trace.Span, fnResponseEvent *session.Event) {
if fnResponseEvent == nil {
return
}
for _, span := range spans {
attributes := []attribute.KeyValue{
attribute.String(genAiOperationName, executeToolName),
attribute.String(genAiToolName, mergeToolName),
attribute.String(genAiToolDescription, mergeToolName),
// Setting empty llm request and response (as UI expect these) while not
// applicable for tool_response.
attribute.String(gcpVertexAgentLLMRequestName, "{}"),
attribute.String(gcpVertexAgentLLMRequestName, "{}"),
attribute.String(gcpVertexAgentToolCallArgsName, "N/A"),
attribute.String(gcpVertexAgentEventID, fnResponseEvent.ID),
attribute.String(gcpVertexAgentToolResponseName, safeSerialize(fnResponseEvent)),
}
span.SetAttributes(attributes...)
span.End()
func TraceMergedToolCalls(span trace.Span, fnResponseEvent *session.Event) {
attributes := []attribute.KeyValue{
attribute.String(genAiOperationName, executeToolName),
attribute.String(genAiToolName, mergeToolName),
attribute.String(genAiToolDescription, mergeToolName),
// Setting empty llm request and response (as UI expect these) while not
// applicable for tool_response.
attribute.String(gcpVertexAgentLLMRequestName, "{}"),
attribute.String(gcpVertexAgentLLMResponseName, "{}"),
attribute.String(gcpVertexAgentToolCallArgsName, "N/A"),
attribute.String(gcpVertexAgentEventID, fnResponseEvent.ID),
attribute.String(gcpVertexAgentToolResponseName, safeSerialize(fnResponseEvent)),
}
span.SetAttributes(attributes...)
}

// TraceToolCall traces the tool execution events.
func TraceToolCall(spans []trace.Span, tool tool.Tool, fnArgs map[string]any, fnResponseEvent *session.Event) {
func TraceToolCall(span trace.Span, tool tool.Tool, fnArgs map[string]any, fnResponseEvent *session.Event) {
if fnResponseEvent == nil {
return
}
for _, span := range spans {
attributes := []attribute.KeyValue{
attribute.String(genAiOperationName, executeToolName),
attribute.String(genAiToolName, tool.Name()),
attribute.String(genAiToolDescription, tool.Description()),
// TODO: add tool type

// Setting empty llm request and response (as UI expect these) while not
// applicable for tool_response.
attribute.String(gcpVertexAgentLLMRequestName, "{}"),
attribute.String(gcpVertexAgentLLMRequestName, "{}"),
attribute.String(gcpVertexAgentToolCallArgsName, safeSerialize(fnArgs)),
attribute.String(gcpVertexAgentEventID, fnResponseEvent.ID),
}
attributes := []attribute.KeyValue{
attribute.String(genAiOperationName, executeToolName),
attribute.String(genAiToolName, tool.Name()),
attribute.String(genAiToolDescription, tool.Description()),
// TODO: add tool type

// Setting empty llm request and response (as UI expect these) while not
// applicable for tool_response.
attribute.String(gcpVertexAgentLLMRequestName, "{}"),
attribute.String(gcpVertexAgentLLMResponseName, "{}"),
attribute.String(gcpVertexAgentToolCallArgsName, safeSerialize(fnArgs)),
attribute.String(gcpVertexAgentEventID, fnResponseEvent.ID),
}

toolCallID := "<not specified>"
toolResponse := "<not specified>"
toolCallID := "<not specified>"
toolResponse := "<not specified>"

if fnResponseEvent.LLMResponse.Content != nil {
responseParts := fnResponseEvent.LLMResponse.Content.Parts
if fnResponseEvent.LLMResponse.Content != nil {
responseParts := fnResponseEvent.LLMResponse.Content.Parts

if len(responseParts) > 0 {
functionResponse := responseParts[0].FunctionResponse
if functionResponse != nil {
if functionResponse.ID != "" {
toolCallID = functionResponse.ID
}
if functionResponse.Response != nil {
toolResponse = safeSerialize(functionResponse.Response)
}
if len(responseParts) > 0 {
functionResponse := responseParts[0].FunctionResponse
if functionResponse != nil {
if functionResponse.ID != "" {
toolCallID = functionResponse.ID
}
if functionResponse.Response != nil {
toolResponse = safeSerialize(functionResponse.Response)
}
}
}
}

attributes = append(attributes, attribute.String(genAiToolCallID, toolCallID))
attributes = append(attributes, attribute.String(gcpVertexAgentToolResponseName, toolResponse))
attributes = append(attributes, attribute.String(genAiToolCallID, toolCallID))
attributes = append(attributes, attribute.String(gcpVertexAgentToolResponseName, toolResponse))

span.SetAttributes(attributes...)
span.End()
}
span.SetAttributes(attributes...)
}

// TraceLLMCall fills the call_llm event details.
func TraceLLMCall(spans []trace.Span, agentCtx agent.InvocationContext, llmRequest *model.LLMRequest, event *session.Event) {
func TraceLLMCall(span trace.Span, agentCtx agent.InvocationContext, llmRequest *model.LLMRequest, event *session.Event) {
sessionID := agentCtx.Session().ID()
for _, span := range spans {
attributes := []attribute.KeyValue{
attribute.String(genAiSystemName, SystemName),
attribute.String(genAiRequestModelName, llmRequest.Model),
attribute.String(gcpVertexAgentInvocationID, event.InvocationID),
attribute.String(gcpVertexAgentSessionID, sessionID),
attribute.String(genAiConversationID, sessionID),
attribute.String(gcpVertexAgentEventID, event.ID),
attribute.String(gcpVertexAgentLLMRequestName, safeSerialize(llmRequestToTrace(llmRequest))),
attribute.String(gcpVertexAgentLLMResponseName, safeSerialize(event.LLMResponse)),
}
attributes := []attribute.KeyValue{
attribute.String(genAiSystemName, systemName),
attribute.String(genAiRequestModelName, llmRequest.Model),
attribute.String(gcpVertexAgentInvocationID, event.InvocationID),
attribute.String(gcpVertexAgentSessionID, sessionID),
attribute.String(genAiConversationID, sessionID),
attribute.String(gcpVertexAgentEventID, event.ID),
attribute.String(gcpVertexAgentLLMRequestName, safeSerialize(llmRequestToTrace(llmRequest))),
attribute.String(gcpVertexAgentLLMResponseName, safeSerialize(event.LLMResponse)),
}

if llmRequest.Config.TopP != nil {
attributes = append(attributes, attribute.Float64(genAiRequestTopP, float64(*llmRequest.Config.TopP)))
}
if llmRequest.Config.TopP != nil {
attributes = append(attributes, attribute.Float64(genAiRequestTopP, float64(*llmRequest.Config.TopP)))
}

if llmRequest.Config.MaxOutputTokens != 0 {
attributes = append(attributes, attribute.Int(genAiRequestMaxTokens, int(llmRequest.Config.MaxOutputTokens)))
if llmRequest.Config.MaxOutputTokens != 0 {
attributes = append(attributes, attribute.Int(genAiRequestMaxTokens, int(llmRequest.Config.MaxOutputTokens)))
}
if event.FinishReason != "" {
attributes = append(attributes, attribute.String(genAiResponseFinishReason, string(event.FinishReason)))
}
if event.UsageMetadata != nil {
if event.UsageMetadata.PromptTokenCount > 0 {
attributes = append(attributes, attribute.Int(genAiResponsePromptTokenCount, int(event.UsageMetadata.PromptTokenCount)))
}
if event.FinishReason != "" {
attributes = append(attributes, attribute.String(genAiResponseFinishReason, string(event.FinishReason)))
if event.UsageMetadata.CandidatesTokenCount > 0 {
attributes = append(attributes, attribute.Int(genAiResponseCandidatesTokenCount, int(event.UsageMetadata.CandidatesTokenCount)))
}
if event.UsageMetadata != nil {
if event.UsageMetadata.PromptTokenCount > 0 {
attributes = append(attributes, attribute.Int(genAiResponsePromptTokenCount, int(event.UsageMetadata.PromptTokenCount)))
}
if event.UsageMetadata.CandidatesTokenCount > 0 {
attributes = append(attributes, attribute.Int(genAiResponseCandidatesTokenCount, int(event.UsageMetadata.CandidatesTokenCount)))
}
if event.UsageMetadata.CachedContentTokenCount > 0 {
attributes = append(attributes, attribute.Int(genAiResponseCachedContentTokenCount, int(event.UsageMetadata.CachedContentTokenCount)))
}
if event.UsageMetadata.TotalTokenCount > 0 {
attributes = append(attributes, attribute.Int(genAiResponseTotalTokenCount, int(event.UsageMetadata.TotalTokenCount)))
}
if event.UsageMetadata.CachedContentTokenCount > 0 {
attributes = append(attributes, attribute.Int(genAiResponseCachedContentTokenCount, int(event.UsageMetadata.CachedContentTokenCount)))
}
if event.UsageMetadata.TotalTokenCount > 0 {
attributes = append(attributes, attribute.Int(genAiResponseTotalTokenCount, int(event.UsageMetadata.TotalTokenCount)))
}

span.SetAttributes(attributes...)
span.End()
}

span.SetAttributes(attributes...)
}

func safeSerialize(obj any) string {
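With the local tracer provider gone, spans from this package are recorded only when the application installs a real provider globally; otherwise the default no-op provider drops them. A sketch of that wiring, assuming only the OpenTelemetry SDK (the in-memory exporter stands in for a production exporter and is not part of this change):

package main

import (
	"context"
	"log"

	"go.opentelemetry.io/otel"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
	"go.opentelemetry.io/otel/sdk/trace/tracetest"
)

func main() {
	ctx := context.Background()

	// Register an SDK tracer provider as the global one; the package-level
	// tracer above is obtained from the global provider.
	exporter := tracetest.NewInMemoryExporter()
	tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter))
	otel.SetTracerProvider(tp)
	defer func() {
		if err := tp.Shutdown(ctx); err != nil {
			log.Printf("tracer provider shutdown: %v", err)
		}
	}()

	// The delegating global provider forwards even tracers obtained before
	// SetTracerProvider ran, so spans started afterwards are exported.
	_, span := otel.Tracer("gcp.vertex.agent").Start(ctx, "call_llm")
	span.End()

	log.Printf("exported spans: %d", len(exporter.GetSpans()))
}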
3 changes: 0 additions & 3 deletions server/adkrest/handler.go
@@ -32,9 +32,6 @@ import (
func NewHandler(config *launcher.Config, sseWriteTimeout time.Duration) http.Handler {
adkExporter := services.NewAPIServerSpanExporter()
processor := sdktrace.NewSimpleSpanProcessor(adkExporter)
// TODO(#479) remove this together with local tracer provider.
// nolint:staticcheck
telemetry.RegisterLocalSpanProcessor(processor)
config.TelemetryOptions = append(config.TelemetryOptions, telemetry.WithSpanProcessors(processor))

router := mux.NewRouter().StrictSlash(true)
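For context, a minimal sketch of the exporter/processor pair this handler wires up, assuming only the OpenTelemetry SDK; logExporter is a hypothetical stand-in for services.NewAPIServerSpanExporter:

package main

import (
	"context"
	"log"

	"go.opentelemetry.io/otel"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
)

// logExporter stands in for the API server span exporter: any type that
// implements sdktrace.SpanExporter can be wrapped in a SimpleSpanProcessor
// and handed to the telemetry options exactly once.
type logExporter struct{}

func (logExporter) ExportSpans(_ context.Context, spans []sdktrace.ReadOnlySpan) error {
	for _, s := range spans {
		log.Printf("span %q ended", s.Name())
	}
	return nil
}

func (logExporter) Shutdown(context.Context) error { return nil }

func main() {
	ctx := context.Background()

	processor := sdktrace.NewSimpleSpanProcessor(logExporter{})
	tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(processor))
	otel.SetTracerProvider(tp)
	defer func() { _ = tp.Shutdown(ctx) }()

	_, span := otel.Tracer("gcp.vertex.agent").Start(ctx, "execute_tool demo")
	span.End()
}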
13 changes: 0 additions & 13 deletions telemetry/telemetry.go
@@ -20,8 +20,6 @@ import (

"go.opentelemetry.io/otel"
sdktrace "go.opentelemetry.io/otel/sdk/trace"

internal "google.golang.org/adk/internal/telemetry"
)

// Providers wraps all telemetry providers and provides [Shutdown] function.
@@ -104,14 +102,3 @@ func New(ctx context.Context, opts ...Option) (*Providers, error) {
}
return newInternal(cfg)
}

// RegisterLocalSpanProcessor registers the span processor to local trace provider instance.
// Any processor should be registered BEFORE any of the events are emitted, otherwise
// the registration will be ignored.
// In addition to the RegisterLocalSpanProcessor function, global trace provider configs
// are respected.
//
// Deprecated: Configure processors via [Option]s passed to [New]. TODO(#479) remove this together with local tracer provider.
func RegisterLocalSpanProcessor(processor sdktrace.SpanProcessor) {
internal.AddSpanProcessor(processor)
}
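With the deprecated helper removed, the replacement path is the option-based constructor, as the handler change above shows. A hedged sketch of that call site; the exporter is a stand-in, and the Shutdown signature is assumed to take a context and return an error:

package main

import (
	"context"
	"log"

	sdktrace "go.opentelemetry.io/otel/sdk/trace"
	"go.opentelemetry.io/otel/sdk/trace/tracetest"

	"google.golang.org/adk/telemetry"
)

func main() {
	ctx := context.Background()

	// Assumed replacement for RegisterLocalSpanProcessor: pass the processor
	// through WithSpanProcessors when constructing the providers.
	processor := sdktrace.NewSimpleSpanProcessor(tracetest.NewInMemoryExporter())
	providers, err := telemetry.New(ctx, telemetry.WithSpanProcessors(processor))
	if err != nil {
		log.Fatalf("telemetry setup: %v", err)
	}
	// Shutdown flushes and stops the registered providers; exact signature assumed.
	defer func() {
		if err := providers.Shutdown(ctx); err != nil {
			log.Printf("telemetry shutdown: %v", err)
		}
	}()
}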