Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions internal/context/callback_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,20 @@ func NewCallbackContextWithDelta(ctx agent.InvocationContext, stateDelta map[str
func newCallbackContext(ctx agent.InvocationContext, stateDelta map[string]any) *callbackContext {
rCtx := NewReadonlyContext(ctx)
eventActions := &session.EventActions{StateDelta: stateDelta}

var artifacts *internalArtifacts
if ctx.Artifacts() != nil {
artifacts = &internalArtifacts{
Artifacts: ctx.Artifacts(),
eventActions: eventActions,
}
}

return &callbackContext{
ReadonlyContext: rCtx,
invocationCtx: ctx,
eventActions: eventActions,
artifacts: &internalArtifacts{
Artifacts: ctx.Artifacts(),
eventActions: eventActions,
},
artifacts: artifacts,
}
}

Expand All @@ -77,6 +83,9 @@ type callbackContext struct {
}

func (c *callbackContext) Artifacts() agent.Artifacts {
if c.artifacts == nil {
return nil
}
return c.artifacts
}

Expand Down
16 changes: 12 additions & 4 deletions internal/toolinternal/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,20 @@ func NewToolContext(ctx agent.InvocationContext, functionCallID string, actions
}
cbCtx := contextinternal.NewCallbackContextWithDelta(ctx, actions.StateDelta)

var artifacts *internalArtifacts
if ctx.Artifacts() != nil {
artifacts = &internalArtifacts{
Artifacts: ctx.Artifacts(),
eventActions: actions,
}
}

return &toolContext{
CallbackContext: cbCtx,
invocationContext: ctx,
functionCallID: functionCallID,
eventActions: actions,
artifacts: &internalArtifacts{
Artifacts: ctx.Artifacts(),
eventActions: actions,
},
artifacts: artifacts,
}
}

Expand All @@ -81,6 +86,9 @@ type toolContext struct {
}

func (c *toolContext) Artifacts() agent.Artifacts {
if c.artifacts == nil {
return nil
}
return c.artifacts
}

Expand Down
6 changes: 6 additions & 0 deletions tool/loadartifactstool/load_artifacts_tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package loadartifactstool
import (
"context"
"encoding/json"
"errors"
"fmt"

"golang.org/x/sync/errgroup"
Expand All @@ -32,6 +33,8 @@ import (
"google.golang.org/adk/tool"
)

var ErrArtifactServiceNotInitialized = errors.New("artifact service is not initialized")

// artifactsTool is a tool that loads artifacts and adds them to the session.
type artifactsTool struct {
name string
Expand Down Expand Up @@ -118,6 +121,9 @@ func (t *artifactsTool) Run(ctx tool.Context, args any) (map[string]any, error)
// ProcessRequest processes the LLM request. It packs the tool, appends initial
// instructions, and processes any load artifacts function calls.
func (t *artifactsTool) ProcessRequest(ctx tool.Context, req *model.LLMRequest) error {
if ctx.Artifacts() == nil {
return ErrArtifactServiceNotInitialized
}
if err := toolutils.PackTool(req, t); err != nil {
return err
}
Expand Down
22 changes: 22 additions & 0 deletions tool/loadartifactstool/load_artifacts_tool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package loadartifactstool_test

import (
"errors"
"strings"
"testing"

Expand Down Expand Up @@ -293,3 +294,24 @@ func createToolContext(t *testing.T) tool.Context {

return toolinternal.NewToolContext(ctx, "", nil)
}

func TestLoadArtifactsTool_ProcessRequest_NilArtifactService(t *testing.T) {
loadArtifactsTool := loadartifactstool.New()

ctx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{
Artifacts: nil,
})
tc := toolinternal.NewToolContext(ctx, "", nil)

llmRequest := &model.LLMRequest{}

requestProcessor, ok := loadArtifactsTool.(toolinternal.RequestProcessor)
if !ok {
t.Fatal("loadArtifactsTool does not implement RequestProcessor")
}

err := requestProcessor.ProcessRequest(tc, llmRequest)
if !errors.Is(err, loadartifactstool.ErrArtifactServiceNotInitialized) {
t.Errorf("ProcessRequest() with nil artifact service returned error = %v, want %v", err, loadartifactstool.ErrArtifactServiceNotInitialized)
}
}