Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions agent/workflowagents/parallelagent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,14 @@ func run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] {
defer close(doneChan)

for res := range resultsChan {
if !yield(res.event, res.err) {
shouldContinue := yield(res.event, res.err)

// Signal sub-agent that event processing (including session append) is complete
if res.ackChan != nil {
close(res.ackChan)
}

if !shouldContinue {
break
}
}
Expand All @@ -116,30 +123,37 @@ func run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] {

func runSubAgent(ctx agent.InvocationContext, agent agent.Agent, results chan<- result, done <-chan bool) error {
for event, err := range agent.Run(ctx) {
ackChan := make(chan struct{})

select {
case <-done:
return nil
case <-ctx.Done():
select {
case <-done:
case results <- result{
err: ctx.Err(),
}:
}
return ctx.Err()
case results <- result{
event: event,
err: err,
event: event,
err: err,
ackChan: ackChan,
}:
if err != nil {
return err
}

// Wait for runner to finish processing before continuing to next iteration
select {
case <-ackChan:
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
}
return nil
}

type result struct {
event *session.Event
err error
event *session.Event
err error
ackChan chan struct{}
}
142 changes: 142 additions & 0 deletions agent/workflowagents/parallelagent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,32 @@
"fmt"
"iter"
rand "math/rand/v2"
"net/http"
"path/filepath"
"slices"
"strings"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"google.golang.org/genai"

"google.golang.org/adk/agent"
"google.golang.org/adk/agent/llmagent"
"google.golang.org/adk/agent/workflowagents/loopagent"
"google.golang.org/adk/agent/workflowagents/parallelagent"
"google.golang.org/adk/internal/httprr"
"google.golang.org/adk/internal/testutil"
"google.golang.org/adk/model"
"google.golang.org/adk/model/gemini"
"google.golang.org/adk/runner"
"google.golang.org/adk/session"
"google.golang.org/adk/tool"
"google.golang.org/adk/tool/functiontool"
)

const modelName = "gemini-2.0-flash-exp"

func TestNewParallelAgent(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -224,3 +235,134 @@
}
}
}

func TestParallelAgentWithTools(t *testing.T) {
agent1 := createAgentWithGemini(t, "agent1")
agent2 := createAgentWithGemini(t, "agent2")

parallelAgent, err := parallelagent.New(parallelagent.Config{
AgentConfig: agent.Config{
Name: "parallel_test",
SubAgents: []agent.Agent{agent1, agent2},
},
})
if err != nil {
t.Fatalf("Failed to create parallel agent: %v", err)
}

runner := testutil.NewTestAgentRunner(t, parallelAgent)
stream := runner.Run(t, "test_session", "Search for AI news")

events, err := testutil.CollectEvents(stream)
if err != nil {
t.Fatalf("Agent run failed: %v", err)
}

if len(events) < 2 {
t.Errorf("Expected at least 2 events from parallel agents, got %d", len(events))
}

// Count FunctionCall and FunctionResponse events per branch
branchCalls := make(map[string]int)
branchResponses := make(map[string]int)

for _, ev := range events {
branch := ev.Branch
if ev.LLMResponse.Content != nil {
for _, part := range ev.LLMResponse.Content.Parts {
if part.FunctionCall != nil {
branchCalls[branch]++
}
if part.FunctionResponse != nil {
branchResponses[branch]++
}
}
}
}

for branch, calls := range branchCalls {
responses := branchResponses[branch]
if calls > responses {
t.Errorf("Branch %s: session has %d FunctionCalls but only %d FunctionResponses. "+
"This indicates race condition: agent read session before FunctionResponse was appended.",
branch, calls, responses)
}
}
}

func createAgentWithGemini(t *testing.T, name string) agent.Agent {
t.Helper()

searchTool, err := functiontool.New(
functiontool.Config{
Name: fmt.Sprintf("search_tool_%s", name),
Description: "Search for information on the web",
},
func(ctx tool.Context, args struct{ Query string }) (string, error) {
return fmt.Sprintf("search result for '%s' from %s", args.Query, name), nil
},
)
if err != nil {
t.Fatalf("Failed to create search tool: %v", err)
}

analyzeTool, err := functiontool.New(
functiontool.Config{
Name: fmt.Sprintf("analyze_tool_%s", name),
Description: "Analyze data and return insights",
},
func(ctx tool.Context, args struct{ Data string }) (string, error) {
return fmt.Sprintf("analysis result for '%s' from %s", args.Data, name), nil
},
)
if err != nil {
t.Fatalf("Failed to create analyze tool: %v", err)
}

model := newGeminiModelForTest(t, modelName, name)

a, err := llmagent.New(llmagent.Config{
Name: name,
Description: fmt.Sprintf("Test agent %s that searches for information", name),
Model: model,
Tools: []tool.Tool{searchTool, analyzeTool},
Instruction: "Use the search tool to find information, then provide a brief response.",
})
if err != nil {
t.Fatalf("Failed to create agent %s: %v", name, err)
}

return a
}

func newGeminiModelForTest(t *testing.T, modelName string, agentName string) model.LLM {

Check failure on line 338 in agent/workflowagents/parallelagent/agent_test.go

View workflow job for this annotation

GitHub Actions / lint

File is not properly formatted (gofumpt)
t.Helper()

trace := filepath.Join("testdata", fmt.Sprintf("%s_%s.httprr",
strings.ReplaceAll(t.Name(), "/", "_"), agentName))

apiKey := "fakeKey"
transport, recording := newGeminiTestTransport(t, trace)
if recording {
apiKey = ""
}

model, err := gemini.NewModel(t.Context(), modelName, &genai.ClientConfig{
HTTPClient: &http.Client{Transport: transport},
APIKey: apiKey,
})
if err != nil {
t.Fatalf("Failed to create Gemini model: %v", err)
}
return model
}

func newGeminiTestTransport(t *testing.T, rrfile string) (http.RoundTripper, bool) {
t.Helper()
rr, err := testutil.NewGeminiTransport(rrfile)
if err != nil {
t.Fatal(err)
}
recording, _ := httprr.Recording(rrfile)
return rr, recording
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
httprr trace v1
994 1148
POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent HTTP/1.1
Host: generativelanguage.googleapis.com
User-Agent: Go-http-client/1.1
Content-Length: 758
Content-Type: application/json

{"contents":[{"parts":[{"text":"Search for AI news"}],"role":"user"}],"generationConfig":{},"systemInstruction":{"parts":[{"text":"Use the search tool to find information, then provide a brief response."}],"role":"user"},"tools":[{"functionDeclarations":[{"description":"Search for information on the web","name":"search_tool_agent1","parametersJsonSchema":{"additionalProperties":false,"properties":{"Query":{"type":"string"}},"required":["Query"],"type":"object"},"responseJsonSchema":{"type":"string"}},{"description":"Analyze data and return insights","name":"analyze_tool_agent1","parametersJsonSchema":{"additionalProperties":false,"properties":{"Data":{"type":"string"}},"required":["Data"],"type":"object"},"responseJsonSchema":{"type":"string"}}]}]}HTTP/2.0 200 OK
Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
Content-Type: application/json; charset=UTF-8
Date: Mon, 12 Jan 2026 03:32:51 GMT
Server: scaffolding on HTTPServer2
Server-Timing: gfet4t7; dur=661
Vary: Origin
Vary: X-Origin
Vary: Referer
X-Content-Type-Options: nosniff
X-Frame-Options: SAMEORIGIN
X-Xss-Protection: 0

{
"candidates": [
{
"content": {
"parts": [
{
"functionCall": {
"name": "search_tool_agent1",
"args": {
"Query": "AI news"
}
}
}
],
"role": "model"
},
"finishReason": "STOP",
"avgLogprobs": -1.7506775394495991e-05
}
],
"usageMetadata": {
"promptTokenCount": 41,
"candidatesTokenCount": 9,
"totalTokenCount": 50,
"promptTokensDetails": [
{
"modality": "TEXT",
"tokenCount": 41
}
],
"candidatesTokensDetails": [
{
"modality": "TEXT",
"tokenCount": 9
}
]
},
"modelVersion": "gemini-2.0-flash-exp",
"responseId": "Ymtkad2GMvz_2roPv6WLmQs"
}
1237 1033
POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent HTTP/1.1
Host: generativelanguage.googleapis.com
User-Agent: Go-http-client/1.1
Content-Length: 1000
Content-Type: application/json

{"contents":[{"parts":[{"text":"Search for AI news"}],"role":"user"},{"parts":[{"functionCall":{"args":{"Query":"AI news"},"name":"search_tool_agent1"}}],"role":"model"},{"parts":[{"functionResponse":{"name":"search_tool_agent1","response":{"result":"search result for 'AI news' from agent1"}}}],"role":"user"}],"generationConfig":{},"systemInstruction":{"parts":[{"text":"Use the search tool to find information, then provide a brief response."}],"role":"user"},"tools":[{"functionDeclarations":[{"description":"Search for information on the web","name":"search_tool_agent1","parametersJsonSchema":{"additionalProperties":false,"properties":{"Query":{"type":"string"}},"required":["Query"],"type":"object"},"responseJsonSchema":{"type":"string"}},{"description":"Analyze data and return insights","name":"analyze_tool_agent1","parametersJsonSchema":{"additionalProperties":false,"properties":{"Data":{"type":"string"}},"required":["Data"],"type":"object"},"responseJsonSchema":{"type":"string"}}]}]}HTTP/2.0 200 OK
Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
Content-Type: application/json; charset=UTF-8
Date: Mon, 12 Jan 2026 03:32:52 GMT
Server: scaffolding on HTTPServer2
Server-Timing: gfet4t7; dur=656
Vary: Origin
Vary: X-Origin
Vary: Referer
X-Content-Type-Options: nosniff
X-Frame-Options: SAMEORIGIN
X-Xss-Protection: 0

{
"candidates": [
{
"content": {
"parts": [
{
"text": "I have searched for AI news."
}
],
"role": "model"
},
"finishReason": "STOP",
"avgLogprobs": -0.03738286665507725
}
],
"usageMetadata": {
"promptTokenCount": 67,
"candidatesTokenCount": 7,
"totalTokenCount": 74,
"promptTokensDetails": [
{
"modality": "TEXT",
"tokenCount": 67
}
],
"candidatesTokensDetails": [
{
"modality": "TEXT",
"tokenCount": 7
}
]
},
"modelVersion": "gemini-2.0-flash-exp",
"responseId": "Y2tkaaH-NJKd0-kP9puPwAc"
}
Loading
Loading