From c98cffc848afd5057f1385211a66933d97ba3b1b Mon Sep 17 00:00:00 2001 From: Wada Yasuhiro <19990967+yasu89@users.noreply.github.com> Date: Tue, 27 Jan 2026 10:46:00 +0900 Subject: [PATCH] Support Gemini Computer use Signed-off-by: Wada Yasuhiro <19990967+yasu89@users.noreply.github.com> --- .../McpToolsConfigurationTests.java | 5 +- .../McpToolCallbackParameterlessToolIT.java | 5 +- .../ToolCallingAutoConfigurationTests.java | 9 ++- .../ai/mcp/AsyncMcpToolCallback.java | 7 +- .../springframework/ai/mcp/McpToolUtils.java | 8 +- .../ai/mcp/SyncMcpToolCallback.java | 7 +- .../ai/mcp/AsyncMcpToolCallbackTest.java | 23 +++--- .../ai/mcp/SyncMcpToolCallbackTests.java | 29 +++---- .../ai/mcp/ToolUtilsTests.java | 3 +- .../redis/RedisChatMemoryRepository.java | 68 ++++++++++------ .../ai/google/genai/GoogleGenAiChatModel.java | 68 ++++++++++++++-- .../google/genai/GoogleGenAiChatOptions.java | 80 ++++++++++++++++++- .../GoogleGenAiComputerUseEnvironment.java | 43 ++++++++++ .../MistralAiChatCompletionRequestTests.java | 5 +- .../ai/ollama/OllamaChatRequestTests.java | 5 +- .../chat/OpenAiSdkChatOptionsTests.java | 17 ++-- .../ai/openai/ChatCompletionRequestTests.java | 5 +- .../client/DefaultChatClientUtilsTests.java | 5 +- ...tChatClientObservationConventionTests.java | 3 +- .../pages/api/chat/google-genai-chat.adoc | 3 + .../ai/chat/messages/ToolResponseMessage.java | 10 ++- .../model/tool/DefaultToolCallingManager.java | 7 +- .../springframework/ai/tool/ToolCallback.java | 5 +- .../tool/augment/AugmentedToolCallback.java | 5 +- .../DefaultToolCallResultConverter.java | 14 ++-- ...efaultToolExecutionExceptionProcessor.java | 4 +- .../ai/tool/execution/ToolCallResult.java | 77 ++++++++++++++++++ .../execution/ToolCallResultConverter.java | 10 +-- .../ToolExecutionExceptionProcessor.java | 8 +- .../tool/function/FunctionToolCallback.java | 5 +- .../ai/tool/method/MethodToolCallback.java | 5 +- .../ToolCallingContentObservationFilter.java | 6 +- .../ToolCallingObservationContext.java | 13 +-- ...ltChatModelObservationConventionTests.java | 5 +- .../tool/DefaultToolCallingManagerIT.java | 5 +- .../tool/DefaultToolCallingManagerTest.java | 33 ++++---- .../tool/DefaultToolCallingManagerTests.java | 7 +- .../tool/ToolCallingChatOptionsTests.java | 5 +- .../augment/AugmentedToolCallbackTest.java | 31 +++---- .../DefaultToolCallResultConverterTests.java | 54 ++++++------- ...tToolExecutionExceptionProcessorTests.java | 16 ++-- .../function/FunctionToolCallbackTest.java | 13 +-- ...thodToolCallbackExceptionHandlingTest.java | 5 +- .../MethodToolCallbackGenericTypesTest.java | 17 ++-- ...ToolCallingObservationConventionTests.java | 3 +- ...lCallingContentObservationFilterTests.java | 11 +-- .../ToolCallingObservationContextTests.java | 5 +- 47 files changed, 549 insertions(+), 228 deletions(-) create mode 100644 models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/common/GoogleGenAiComputerUseEnvironment.java create mode 100644 spring-ai-model/src/main/java/org/springframework/ai/tool/execution/ToolCallResult.java diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java index b9db603a0bc..00e9effb5d8 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java @@ -36,6 +36,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.resolution.ToolCallbackResolver; import org.springframework.ai.util.json.schema.JsonSchemaGenerator; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -210,8 +211,8 @@ public ToolDefinition getToolDefinition() { } @Override - public String call(String toolInput) { - return "~~ not implemented ~~"; + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content("~~ not implemented ~~").build(); } } }; } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/McpToolCallbackParameterlessToolIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/McpToolCallbackParameterlessToolIT.java index 3bb10339208..3fd8469fcdf 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/McpToolCallbackParameterlessToolIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/McpToolCallbackParameterlessToolIT.java @@ -46,6 +46,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -184,8 +185,8 @@ void testMcpServerClientIntegrationWithIncompleteSchemaSyncTool() { assertThat(schemaMap.get("additionalProperties")).isEqualTo(false); // Test that the callback can be called successfully - String result = toolCallback.call("{}"); - assertThat(result).isNotNull().contains("Current time:"); + ToolCallResult result = toolCallback.call("{}"); + assertThat(result.content()).isNotNull().contains("Current time:"); }); stopHttpServer(httpServer); diff --git a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java index 709a7d1d155..73f89ac1467 100644 --- a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java +++ b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java @@ -36,6 +36,7 @@ import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.execution.ToolExecutionException; import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor; import org.springframework.ai.tool.function.FunctionToolCallback; @@ -162,8 +163,8 @@ void throwExceptionOnErrorDefault() { var exception = new ToolExecutionException(toolDefinition, cause); // Default behavior should not throw exception - String result = toolExecutionExceptionProcessor.process(exception); - assertThat(result).isEqualTo("Test error"); + ToolCallResult result = toolExecutionExceptionProcessor.process(exception); + assertThat(result.content()).isEqualTo("Test error"); }); } @@ -364,8 +365,8 @@ public ToolExecutionExceptionProcessor toolExecutionExceptionProcessor() { static class CustomToolExecutionExceptionProcessor implements ToolExecutionExceptionProcessor { @Override - public String process(ToolExecutionException exception) { - return "Custom error handling"; + public ToolCallResult process(ToolExecutionException exception) { + return ToolCallResult.builder().content("Custom error handling").build(); } } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java index 5a482dbd970..afc8ddb1ddc 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java @@ -31,6 +31,7 @@ import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.execution.ToolExecutionException; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -100,12 +101,12 @@ public String getOriginalToolName() { } @Override - public String call(String toolCallInput) { + public ToolCallResult call(String toolCallInput) { return this.call(toolCallInput, null); } @Override - public String call(String toolCallInput, @Nullable ToolContext toolContext) { + public ToolCallResult call(String toolCallInput, @Nullable ToolContext toolContext) { // Handle the possible null parameter situation in streaming mode. if (!StringUtils.hasText(toolCallInput)) { @@ -143,7 +144,7 @@ public String call(String toolCallInput, @Nullable ToolContext toolContext) { throw new ToolExecutionException(this.getToolDefinition(), new IllegalStateException("Error calling tool: " + response.content())); } - return ModelOptionsUtils.toJsonString(response.content()); + return ToolCallResult.builder().content(ModelOptionsUtils.toJsonString(response.content())).build(); } /** diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java index 2cd4b0de3ae..9a0a559ffdd 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java @@ -42,6 +42,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.util.json.schema.JsonSchemaUtils; import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; @@ -257,14 +258,15 @@ private static SharedSyncToolSpecification toSharedSyncToolSpecification(ToolCal return new SharedSyncToolSpecification(tool, (exchangeOrContext, request) -> { try { - String callResult = toolCallback.call(ModelOptionsUtils.toJsonString(request.arguments()), + ToolCallResult callResult = toolCallback.call(ModelOptionsUtils.toJsonString(request.arguments()), new ToolContext(Map.of(TOOL_CONTEXT_MCP_EXCHANGE_KEY, exchangeOrContext))); if (mimeType != null && mimeType.toString().startsWith("image")) { McpSchema.Annotations annotations = new McpSchema.Annotations(List.of(Role.ASSISTANT), null); return new McpSchema.CallToolResult( - List.of(new McpSchema.ImageContent(annotations, callResult, mimeType.toString())), false); + List.of(new McpSchema.ImageContent(annotations, callResult.content(), mimeType.toString())), + false); } - return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(callResult)), false); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(callResult.content())), false); } catch (Exception e) { return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(e.getMessage())), true); diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java index 012ded0cdc2..17c69e3aaf6 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java @@ -30,6 +30,7 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.execution.ToolExecutionException; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -101,12 +102,12 @@ public String getOriginalToolName() { } @Override - public String call(String toolCallInput) { + public ToolCallResult call(String toolCallInput) { return this.call(toolCallInput, null); } @Override - public String call(String toolCallInput, @Nullable ToolContext toolContext) { + public ToolCallResult call(String toolCallInput, @Nullable ToolContext toolContext) { // Handle the possible null parameter situation in streaming mode. if (!StringUtils.hasText(toolCallInput)) { @@ -142,7 +143,7 @@ public String call(String toolCallInput, @Nullable ToolContext toolContext) { throw new ToolExecutionException(this.getToolDefinition(), new IllegalStateException("Error calling tool: " + response.content())); } - return ModelOptionsUtils.toJsonString(response.content()); + return ToolCallResult.builder().content(ModelOptionsUtils.toJsonString(response.content())).build(); } /** diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java b/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java index 347c5428902..7fdfe682572 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java @@ -30,6 +30,7 @@ import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.execution.ToolExecutionException; import static org.assertj.core.api.Assertions.assertThat; @@ -98,10 +99,10 @@ void callShouldSucceedWithValidInput() { .prefixedToolName("prefixed_testTool") .build(); - String result = callback.call("{\"param\":\"value\"}"); + ToolCallResult result = callback.call("{\"param\":\"value\"}"); // Assert - assertThat(result).contains("Success response"); + assertThat(result.content()).contains("Success response"); // Verify the correct tool name was used in the request ArgumentCaptor requestCaptor = ArgumentCaptor @@ -128,10 +129,10 @@ void callShouldHandleNullInput() { .prefixedToolName("testTool") .build(); - String result = callback.call(null); + ToolCallResult result = callback.call(null); // Assert - assertThat(result).contains("Success with empty input"); + assertThat(result.content()).contains("Success with empty input"); // Verify empty JSON object was used ArgumentCaptor requestCaptor = ArgumentCaptor @@ -156,10 +157,10 @@ void callShouldHandleEmptyInput() { .prefixedToolName("testTool") .build(); - String result = callback.call(""); + ToolCallResult result = callback.call(""); // Assert - assertThat(result).contains("Success with empty input"); + assertThat(result.content()).contains("Success with empty input"); // Verify empty JSON object was used ArgumentCaptor requestCaptor = ArgumentCaptor @@ -187,10 +188,10 @@ void callShouldIncludeToolContext() { .prefixedToolName("testTool") .build(); - String result = callback.call("{\"param\":\"value\"}", toolContext); + ToolCallResult result = callback.call("{\"param\":\"value\"}", toolContext); // Assert - assertThat(result).contains("Success with context"); + assertThat(result.content()).contains("Success with context"); // Verify the context was included in the request ArgumentCaptor requestCaptor = ArgumentCaptor @@ -323,11 +324,11 @@ void callShouldHandleComplexJsonResponse() { .prefixedToolName("testTool") .build(); - String result = callback.call("{\"input\":\"test\"}"); + ToolCallResult result = callback.call("{\"input\":\"test\"}"); // Assert - assertThat(result).contains("Part 1"); - assertThat(result).contains("Part 2"); + assertThat(result.content()).contains("Part 1"); + assertThat(result.content()).contains("Part 2"); } } diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java index 7f81162eb56..d121127b9fa 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java @@ -31,6 +31,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.execution.ToolExecutionException; import static org.assertj.core.api.Assertions.assertThat; @@ -93,7 +94,7 @@ void callShouldHandleJsonInputAndOutput() { .toolContextToMcpMetaConverter(ToolContextToMcpMetaConverter.defaultConverter()) .build(); - String response = callback.call("{\"param\":\"value\"}"); + ToolCallResult response = callback.call("{\"param\":\"value\"}"); assertThat(response).isNotNull(); } @@ -111,7 +112,7 @@ void callShouldHandleToolContext() { .toolContextToMcpMetaConverter(ToolContextToMcpMetaConverter.defaultConverter()) .build(); - String response = callback.call("{\"param\":\"value\"}", new ToolContext(Map.of("foo", "bar"))); + ToolCallResult response = callback.call("{\"param\":\"value\"}", new ToolContext(Map.of("foo", "bar"))); assertThat(response).isNotNull(); } @@ -130,16 +131,16 @@ void callShouldHandleNullOrEmptyInput() { .build(); // Test with null input - String responseNull = callback.call(null); - assertThat(responseNull).isEqualTo("[]"); + ToolCallResult responseNull = callback.call(null); + assertThat(responseNull.content()).isEqualTo("[]"); // Test with empty string input - String responseEmpty = callback.call(""); - assertThat(responseEmpty).isEqualTo("[]"); + ToolCallResult responseEmpty = callback.call(""); + assertThat(responseEmpty.content()).isEqualTo("[]"); // Test with whitespace-only input - String responseWhitespace = callback.call(" "); - assertThat(responseWhitespace).isEqualTo("[]"); + ToolCallResult responseWhitespace = callback.call(" "); + assertThat(responseWhitespace.content()).isEqualTo("[]"); } @Test @@ -197,9 +198,9 @@ void callShouldHandleEmptyResponse() { .toolContextToMcpMetaConverter(ToolContextToMcpMetaConverter.defaultConverter()) .build(); - String response = callback.call("{\"param\":\"value\"}"); + ToolCallResult response = callback.call("{\"param\":\"value\"}"); - assertThat(response).isEqualTo("[]"); + assertThat(response.content()).isEqualTo("[]"); } @Test @@ -218,10 +219,10 @@ void callShouldHandleMultipleContentItems() { .toolContextToMcpMetaConverter(ToolContextToMcpMetaConverter.defaultConverter()) .build(); - String response = callback.call("{\"param\":\"value\"}"); + ToolCallResult response = callback.call("{\"param\":\"value\"}"); assertThat(response).isNotNull(); - assertThat(response).isEqualTo("[{\"text\":\"First content\"},{\"text\":\"Second content\"}]"); + assertThat(response.content()).isEqualTo("[{\"text\":\"First content\"},{\"text\":\"Second content\"}]"); } @Test @@ -239,10 +240,10 @@ void callShouldHandleNonTextContent() { .toolContextToMcpMetaConverter(ToolContextToMcpMetaConverter.defaultConverter()) .build(); - String response = callback.call("{\"param\":\"value\"}"); + ToolCallResult response = callback.call("{\"param\":\"value\"}"); assertThat(response).isNotNull(); - assertThat(response).isEqualTo("[{\"data\":\"base64data\",\"mimeType\":\"image/png\"}]"); + assertThat(response.content()).isEqualTo("[{\"data\":\"base64data\",\"mimeType\":\"image/png\"}]"); } @Test diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java b/mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java index 8864155f600..0baf886a601 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java @@ -39,6 +39,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -321,7 +322,7 @@ private ToolCallback createMockToolCallback(String name, String result) { .inputSchema("{}") .build(); when(callback.getToolDefinition()).thenReturn(definition); - when(callback.call(anyString(), any())).thenReturn(result); + when(callback.call(anyString(), any())).thenReturn(ToolCallResult.builder().content(result).build()); return callback; } diff --git a/memory/repository/spring-ai-model-chat-memory-repository-redis/src/main/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryRepository.java b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/main/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryRepository.java index 81f51836757..a1ebb018cd1 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-redis/src/main/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryRepository.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/main/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryRepository.java @@ -1,5 +1,16 @@ package org.springframework.ai.chat.memory.repository.redis; +import java.net.URI; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; + import com.google.gson.Gson; import com.google.gson.JsonArray; import com.google.gson.JsonElement; @@ -7,22 +18,15 @@ import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.memory.ChatMemoryRepository; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.MessageType; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.ToolResponseMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.content.Media; -import org.springframework.ai.content.MediaContent; -import org.springframework.util.Assert; -import org.springframework.util.MimeType; import redis.clients.jedis.JedisPooled; import redis.clients.jedis.Pipeline; import redis.clients.jedis.json.Path2; -import redis.clients.jedis.search.*; +import redis.clients.jedis.search.Document; +import redis.clients.jedis.search.FTCreateParams; +import redis.clients.jedis.search.IndexDataType; +import redis.clients.jedis.search.Query; import redis.clients.jedis.search.RediSearchUtil; +import redis.clients.jedis.search.SearchResult; import redis.clients.jedis.search.aggr.AggregationBuilder; import redis.clients.jedis.search.aggr.AggregationResult; import redis.clients.jedis.search.aggr.Reducers; @@ -34,16 +38,17 @@ import redis.clients.jedis.search.schemafields.TagField; import redis.clients.jedis.search.schemafields.TextField; -import java.net.URI; -import java.time.Duration; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Base64; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicLong; +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.content.Media; +import org.springframework.ai.content.MediaContent; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; /** * Redis implementation of {@link ChatMemoryRepository} using Redis (JSON + Query Engine). @@ -457,7 +462,15 @@ else if (MessageType.TOOL.toString().equals(type)) { String responseData = responseJson.has("responseData") ? responseJson.get("responseData").getAsString() : ""; - toolResponses.add(new ToolResponseMessage.ToolResponse(id, name, responseData)); + Map responseMetadata = new HashMap<>(); + if (responseJson.has("metadata") && responseJson.get("metadata").isJsonObject()) { + JsonObject responseMetadataJson = responseJson.getAsJsonObject("metadata"); + responseMetadataJson.entrySet().forEach(entry -> { + responseMetadata.put(entry.getKey(), gson.fromJson(entry.getValue(), Object.class)); + }); + } + toolResponses + .add(new ToolResponseMessage.ToolResponse(id, name, responseData, responseMetadata)); } } @@ -1160,7 +1173,14 @@ else if (MessageType.TOOL.toString().equals(type)) { String responseData = responseJson.has("responseData") ? responseJson.get("responseData").getAsString() : ""; - toolResponses.add(new ToolResponseMessage.ToolResponse(id, name, responseData)); + Map responseMetadata = new HashMap<>(); + if (responseJson.has("metadata") && responseJson.get("metadata").isJsonObject()) { + JsonObject responseMetadataJson = responseJson.getAsJsonObject("metadata"); + responseMetadataJson.entrySet().forEach(entry -> { + responseMetadata.put(entry.getKey(), gson.fromJson(entry.getValue(), Object.class)); + }); + } + toolResponses.add(new ToolResponseMessage.ToolResponse(id, name, responseData, responseMetadata)); } } diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java index e6378477c84..92b35e55a71 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java @@ -28,11 +28,13 @@ import com.google.genai.Client; import com.google.genai.ResponseStream; import com.google.genai.types.Candidate; +import com.google.genai.types.ComputerUse; import com.google.genai.types.Content; import com.google.genai.types.FinishReason; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.FunctionResponse; +import com.google.genai.types.FunctionResponsePart; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.GenerateContentResponse; import com.google.genai.types.GoogleSearch; @@ -328,16 +330,29 @@ else if (message instanceof AssistantMessage assistantMessage) { return parts; } else if (message instanceof ToolResponseMessage toolResponseMessage) { - - return toolResponseMessage.getResponses() - .stream() - .map(response -> Part.builder() + return toolResponseMessage.getResponses().stream().map(response -> { + var partsJson = response.metadata().get("gemini.functionResponse.partsJson"); + List functionResponseParts = new ArrayList<>(); + if (partsJson instanceof String[] partsJsonArray) { + for (String partJson : partsJsonArray) { + functionResponseParts.add(FunctionResponsePart.fromJson(partJson)); + } + } + else if (partsJson instanceof Iterable partsJsonIterable) { + for (Object part : partsJsonIterable) { + if (part instanceof String partJson) { + functionResponseParts.add(FunctionResponsePart.fromJson(partJson)); + } + } + } + return Part.builder() .functionResponse(FunctionResponse.builder() .name(response.name()) .response(parseJsonToMap(response.responseData())) + .parts(functionResponseParts) .build()) - .build()) - .toList(); + .build(); + }).toList(); } else { throw new IllegalArgumentException("Gemini doesn't support message type: " + message.getClass()); @@ -515,6 +530,13 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setGoogleSearchRetrieval(ModelOptionsUtils.mergeOption( runtimeOptions.getGoogleSearchRetrieval(), this.defaultOptions.getGoogleSearchRetrieval())); + requestOptions.setComputerUse(ModelOptionsUtils.mergeOption(runtimeOptions.getComputerUse(), + this.defaultOptions.getComputerUse())); + requestOptions.setComputerUseEnvironment(ModelOptionsUtils.mergeOption( + runtimeOptions.getComputerUseEnvironment(), this.defaultOptions.getComputerUseEnvironment())); + requestOptions.setComputerUseExcludedPredefinedFunctions( + ModelOptionsUtils.mergeOption(runtimeOptions.getComputerUseExcludedPredefinedFunctions(), + this.defaultOptions.getComputerUseExcludedPredefinedFunctions())); requestOptions.setSafetySettings(ModelOptionsUtils.mergeOption(runtimeOptions.getSafetySettings(), this.defaultOptions.getSafetySettings())); requestOptions @@ -527,6 +549,10 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setToolContext(this.defaultOptions.getToolContext()); requestOptions.setGoogleSearchRetrieval(this.defaultOptions.getGoogleSearchRetrieval()); + requestOptions.setComputerUse(this.defaultOptions.getComputerUse()); + requestOptions.setComputerUseEnvironment(this.defaultOptions.getComputerUseEnvironment()); + requestOptions.setComputerUseExcludedPredefinedFunctions( + this.defaultOptions.getComputerUseExcludedPredefinedFunctions()); requestOptions.setSafetySettings(this.defaultOptions.getSafetySettings()); requestOptions.setLabels(this.defaultOptions.getLabels()); } @@ -806,6 +832,19 @@ GeminiRequest createGeminiRequest(Prompt prompt) { tools.add(googleSearchRetrievalTool); } + if (requestOptions.getComputerUse()) { + var computerUseBuilder = ComputerUse.builder(); + if (requestOptions.getComputerUseEnvironment() != null) { + computerUseBuilder.environment(requestOptions.getComputerUseEnvironment().getEnvironment()); + } + if (!CollectionUtils.isEmpty(requestOptions.getComputerUseExcludedPredefinedFunctions())) { + computerUseBuilder + .excludedPredefinedFunctions(requestOptions.getComputerUseExcludedPredefinedFunctions()); + } + final var computerUseTool = Tool.builder().computerUse(computerUseBuilder.build()).build(); + tools.add(computerUseTool); + } + if (!CollectionUtils.isEmpty(tools)) { configBuilder.tools(tools); } @@ -1135,6 +1174,23 @@ public enum ChatModel implements ChatModelDescription { */ GEMINI_2_5_FLASH_LIGHT("gemini-2.5-flash-lite"), + /** + * gemini-2.5-computer-use-preview is a specialized model for browser and + * GUI automation. It allows agents to understand screen content and generate UI + * actions such as mouse clicks, keyboard input, and navigation to control + * applications. + *

+ * Inputs: Text, Images - 128,000 tokens | Outputs: Text - 64,000 tokens + *

+ * Knowledge cutoff: October 2025 + *

+ * Model ID: gemini-2.5-computer-use-preview-10-2025 + *

+ * See: gemini-2.5-computer-use-preview-10-2025 + */ + GEMINI_2_5_COMPUTER_USE_PREVIEW("gemini-2.5-computer-use-preview-10-2025"), + GEMINI_3_PRO_PREVIEW("gemini-3-pro-preview"); public final String value; diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java index a95f2da7d79..0845d8a7e1f 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java @@ -31,6 +31,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.google.genai.GoogleGenAiChatModel.ChatModel; +import org.springframework.ai.google.genai.common.GoogleGenAiComputerUseEnvironment; import org.springframework.ai.google.genai.common.GoogleGenAiSafetySetting; import org.springframework.ai.google.genai.common.GoogleGenAiThinkingLevel; import org.springframework.ai.model.tool.StructuredOutputChatOptions; @@ -202,6 +203,24 @@ public class GoogleGenAiChatOptions implements ToolCallingChatOptions, Structure @JsonIgnore private Boolean googleSearchRetrieval = false; + /** + * Enable Gemini Computer Use tool. + */ + @JsonIgnore + private Boolean computerUse = false; + + /** + * Gemini Computer Use environment. + */ + @JsonIgnore + private GoogleGenAiComputerUseEnvironment computerUseEnvironment; + + /** + * Predefined functions to exclude from Computer Use tool. + */ + @JsonIgnore + private List computerUseExcludedPredefinedFunctions; + @JsonIgnore private List safetySettings = new ArrayList<>(); @@ -229,6 +248,9 @@ public static GoogleGenAiChatOptions fromOptions(GoogleGenAiChatOptions fromOpti options.setResponseSchema(fromOptions.getResponseSchema()); options.setToolNames(fromOptions.getToolNames()); options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval()); + options.setComputerUse(fromOptions.getComputerUse()); + options.setComputerUseEnvironment(fromOptions.getComputerUseEnvironment()); + options.setComputerUseExcludedPredefinedFunctions(fromOptions.getComputerUseExcludedPredefinedFunctions()); options.setSafetySettings(fromOptions.getSafetySettings()); options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()); options.setToolContext(fromOptions.getToolContext()); @@ -458,6 +480,30 @@ public void setGoogleSearchRetrieval(Boolean googleSearchRetrieval) { this.googleSearchRetrieval = googleSearchRetrieval; } + public Boolean getComputerUse() { + return this.computerUse; + } + + public void setComputerUse(Boolean computerUse) { + this.computerUse = computerUse; + } + + public GoogleGenAiComputerUseEnvironment getComputerUseEnvironment() { + return this.computerUseEnvironment; + } + + public void setComputerUseEnvironment(GoogleGenAiComputerUseEnvironment computerUseEnvironment) { + this.computerUseEnvironment = computerUseEnvironment; + } + + public List getComputerUseExcludedPredefinedFunctions() { + return this.computerUseExcludedPredefinedFunctions; + } + + public void setComputerUseExcludedPredefinedFunctions(List computerUseExcludedPredefinedFunctions) { + this.computerUseExcludedPredefinedFunctions = computerUseExcludedPredefinedFunctions; + } + public List getSafetySettings() { return this.safetySettings; } @@ -507,6 +553,10 @@ public boolean equals(Object o) { return false; } return Objects.equals(this.googleSearchRetrieval, that.googleSearchRetrieval) + && Objects.equals(this.computerUse, that.computerUse) + && this.computerUseEnvironment == that.computerUseEnvironment + && Objects.equals(this.computerUseExcludedPredefinedFunctions, + that.computerUseExcludedPredefinedFunctions) && Objects.equals(this.stopSequences, that.stopSequences) && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP) && Objects.equals(this.topK, that.topK) && Objects.equals(this.candidateCount, that.candidateCount) @@ -530,7 +580,8 @@ public int hashCode() { return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount, this.frequencyPenalty, this.presencePenalty, this.thinkingBudget, this.includeThoughts, this.thinkingLevel, this.maxOutputTokens, this.model, this.responseMimeType, this.responseSchema, - this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings, + this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.computerUse, + this.computerUseEnvironment, this.computerUseExcludedPredefinedFunctions, this.safetySettings, this.internalToolExecutionEnabled, this.toolContext, this.labels); } @@ -543,8 +594,10 @@ public String toString() { + ", candidateCount=" + this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\'' + ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks + ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" - + this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + ", labels=" + this.labels - + '}'; + + this.googleSearchRetrieval + ", computerUse=" + this.computerUse + ", computerUseEnvironment=" + + this.computerUseEnvironment + ", computerUseExcludedPredefinedFunctions=" + + this.computerUseExcludedPredefinedFunctions + ", safetySettings=" + this.safetySettings + ", labels=" + + this.labels + '}'; } @Override @@ -656,6 +709,27 @@ public Builder googleSearchRetrieval(boolean googleSearch) { return this; } + public Builder computerUse(boolean computerUse) { + this.options.computerUse = computerUse; + return this; + } + + public Builder computerUseEnvironment(GoogleGenAiComputerUseEnvironment environment) { + this.options.computerUseEnvironment = environment; + return this; + } + + public Builder computerUseExcludedPredefinedFunctions(List excludedPredefinedFunctions) { + this.options.computerUseExcludedPredefinedFunctions = excludedPredefinedFunctions; + return this; + } + + public Builder computerUseExcludedPredefinedFunctions(String... excludedPredefinedFunctions) { + Assert.notNull(excludedPredefinedFunctions, "excludedPredefinedFunctions must not be null"); + this.options.computerUseExcludedPredefinedFunctions = Arrays.asList(excludedPredefinedFunctions); + return this; + } + public Builder safetySettings(List safetySettings) { Assert.notNull(safetySettings, "safetySettings must not be null"); this.options.safetySettings = safetySettings; diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/common/GoogleGenAiComputerUseEnvironment.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/common/GoogleGenAiComputerUseEnvironment.java new file mode 100644 index 00000000000..4c2d776b370 --- /dev/null +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/common/GoogleGenAiComputerUseEnvironment.java @@ -0,0 +1,43 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.google.genai.common; + +import com.google.genai.types.Environment; + +/** + * Supported environments for Gemini Computer Use. + * + * @author Wada Yasuhiro + * @since 1.1.0 + */ +public enum GoogleGenAiComputerUseEnvironment { + + ENVIRONMENT_UNSPECIFIED(Environment.Known.ENVIRONMENT_UNSPECIFIED), + + ENVIRONMENT_BROWSER(Environment.Known.ENVIRONMENT_BROWSER); + + private final Environment.Known environmentEnum; + + GoogleGenAiComputerUseEnvironment(Environment.Known environmentEnum) { + this.environmentEnum = environmentEnum; + } + + public Environment getEnvironment() { + return new Environment(this.environmentEnum); + } + +} diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTests.java index c770d63dca6..b54b7593ec1 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTests.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTests.java @@ -37,6 +37,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -313,8 +314,8 @@ public ToolDefinition getToolDefinition() { } @Override - public String call(String toolInput) { - return "Mission accomplished!"; + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content("Mission accomplished!").build(); } } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index 14d658ce403..e99c5f2ebc9 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -36,6 +36,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import static org.assertj.core.api.Assertions.assertThat; @@ -280,8 +281,8 @@ public ToolDefinition getToolDefinition() { } @Override - public String call(String toolInput) { - return "Mission accomplished!"; + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content("Mission accomplished!").build(); } } diff --git a/models/spring-ai-openai-sdk/src/test/java/org/springframework/ai/openaisdk/chat/OpenAiSdkChatOptionsTests.java b/models/spring-ai-openai-sdk/src/test/java/org/springframework/ai/openaisdk/chat/OpenAiSdkChatOptionsTests.java index f1ef3494681..aee4ad49bf1 100644 --- a/models/spring-ai-openai-sdk/src/test/java/org/springframework/ai/openaisdk/chat/OpenAiSdkChatOptionsTests.java +++ b/models/spring-ai-openai-sdk/src/test/java/org/springframework/ai/openaisdk/chat/OpenAiSdkChatOptionsTests.java @@ -29,6 +29,7 @@ import org.springframework.ai.openaisdk.OpenAiSdkChatOptions; import org.springframework.ai.openaisdk.OpenAiSdkChatOptions.StreamOptions; import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.execution.ToolCallResult; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -449,8 +450,8 @@ public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() } @Override - public String call(String toolInput) { - return "result1"; + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content("result1").build(); } }; @@ -465,8 +466,8 @@ public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() } @Override - public String call(String toolInput) { - return "result2"; + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content("result2").build(); } }; @@ -492,8 +493,8 @@ public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() } @Override - public String call(String toolInput) { - return "result"; + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content("result").build(); } }; List callbacks = List.of(callback); @@ -638,8 +639,8 @@ public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() } @Override - public String call(String toolInput) { - return "result"; + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content("result").build(); } }; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java index 08d67a5f100..f4c6463b846 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java @@ -29,6 +29,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.function.FunctionToolCallback; import static org.assertj.core.api.Assertions.assertThat; @@ -240,8 +241,8 @@ public ToolDefinition getToolDefinition() { } @Override - public String call(String toolInput) { - return "Mission accomplished!"; + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content("Mission accomplished!").build(); } } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java index 7b8d5491f90..5bb72cff9b3 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java @@ -34,6 +34,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.metadata.ToolMetadata; import static org.assertj.core.api.Assertions.assertThat; @@ -501,8 +502,8 @@ public ToolMetadata getToolMetadata() { } @Override - public String call(String toolInput) { - return "Mission accomplished!"; + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content("Mission accomplished!").build(); } } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java index 6e782214908..d8afb1a7f7e 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java @@ -41,6 +41,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import static org.assertj.core.api.Assertions.assertThat; @@ -91,7 +92,7 @@ public ToolDefinition getToolDefinition() { } @Override - public String call(String functionInput) { + public ToolCallResult call(String functionInput) { // TODO Auto-generated method stub throw new UnsupportedOperationException("Unimplemented method 'call'"); } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/google-genai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/google-genai-chat.adoc index 95d81faf763..7b46bc98a46 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/google-genai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/google-genai-chat.adoc @@ -108,6 +108,9 @@ The prefix `spring.ai.google.genai.chat` is the property prefix that lets you co | spring.ai.google.genai.chat.options.model | Supported https://ai.google.dev/gemini-api/docs/models[Google GenAI Chat models] to use include `gemini-2.0-flash`, `gemini-2.0-flash-lite`, `gemini-pro`, and `gemini-1.5-flash`. | gemini-2.0-flash | spring.ai.google.genai.chat.options.response-mime-type | Output response mimetype of the generated candidate text. | `text/plain`: (default) Text output or `application/json`: JSON response. | spring.ai.google.genai.chat.options.google-search-retrieval | Use Google search Grounding feature | `true` or `false`, default `false`. +| spring.ai.google.genai.chat.options.computer-use | Enable Gemini Computer Use tool. | `true` or `false`, default `false`. +| spring.ai.google.genai.chat.options.computer-use-environment | Computer Use environment. Valid values: `ENVIRONMENT_BROWSER`, `ENVIRONMENT_UNSPECIFIED`. | - +| spring.ai.google.genai.chat.options.computer-use-excluded-predefined-functions | List of predefined functions to exclude from Computer Use tool. | - | spring.ai.google.genai.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied, while a value closer to 0.0 will typically result in less surprising responses from the generative. | - | spring.ai.google.genai.chat.options.top-k | The maximum number of tokens to consider when sampling. The generative uses combined Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens. | - | spring.ai.google.genai.chat.options.top-p | The maximum cumulative probability of tokens to consider when sampling. The generative uses combined Top-k and nucleus sampling. Nucleus sampling considers the smallest set of tokens whose probability sum is at least topP. | - diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java index 4d369a9da63..a59a3f9a40c 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java @@ -70,7 +70,15 @@ public String toString() { + ", metadata=" + this.metadata + '}'; } - public record ToolResponse(String id, String name, String responseData) { + public record ToolResponse(String id, String name, String responseData, Map metadata) { + + public ToolResponse { + metadata = Map.copyOf(metadata); + } + + public ToolResponse(String id, String name, String responseData) { + this(id, name, responseData, Map.of()); + } } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java index ba350b2e18b..38505823659 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java @@ -37,6 +37,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.execution.ToolExecutionException; import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor; import org.springframework.ai.tool.observation.DefaultToolCallingObservationConvention; @@ -233,11 +234,11 @@ private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMess .toolCallArguments(finalToolInputArguments) .build(); - String toolCallResult = ToolCallingObservationDocumentation.TOOL_CALL + ToolCallResult toolCallResult = ToolCallingObservationDocumentation.TOOL_CALL .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) .observe(() -> { - String toolResult; + ToolCallResult toolResult; try { toolResult = toolCallback.call(finalToolInputArguments, toolContext); } @@ -249,7 +250,7 @@ private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMess }); toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, - toolCallResult != null ? toolCallResult : "")); + toolCallResult.content() == null ? "" : toolCallResult.content(), toolCallResult.metadata())); } return new InternalToolExecutionResult(ToolResponseMessage.builder().responses(toolResponses).build(), diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/ToolCallback.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/ToolCallback.java index 85aaaf1595b..6ab18e9f827 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/ToolCallback.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/ToolCallback.java @@ -22,6 +22,7 @@ import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.metadata.ToolMetadata; /** @@ -50,13 +51,13 @@ default ToolMetadata getToolMetadata() { * Execute tool with the given input and return the result to send back to the AI * model. */ - String call(String toolInput); + ToolCallResult call(String toolInput); /** * Execute tool with the given input and context, and return the result to send back * to the AI model. */ - default String call(String toolInput, @Nullable ToolContext toolContext) { + default ToolCallResult call(String toolInput, @Nullable ToolContext toolContext) { if (toolContext != null && !toolContext.getContext().isEmpty()) { logger.info("By default the tool context is not used, " + "override the method 'call(String toolInput, ToolContext toolcontext)' to support the use of tool context." diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/augment/AugmentedToolCallback.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/augment/AugmentedToolCallback.java index e69474f67a5..96e588448ea 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/augment/AugmentedToolCallback.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/augment/AugmentedToolCallback.java @@ -27,6 +27,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.augment.ToolInputSchemaAugmenter.AugmentedArgumentType; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.util.json.JsonParser; import org.springframework.util.Assert; @@ -103,12 +104,12 @@ public ToolDefinition getToolDefinition() { } @Override - public String call(String toolInput) { + public ToolCallResult call(String toolInput) { return this.delegate.call(this.handleAugmentedArguments(toolInput)); } @Override - public String call(String toolInput, @Nullable ToolContext tooContext) { + public ToolCallResult call(String toolInput, @Nullable ToolContext tooContext) { return this.delegate.call(this.handleAugmentedArguments(toolInput), tooContext); } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverter.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverter.java index 80ffa71deb5..1a91e2cefac 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverter.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverter.java @@ -42,10 +42,10 @@ public final class DefaultToolCallResultConverter implements ToolCallResultConve private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallResultConverter.class); @Override - public String convert(@Nullable Object result, @Nullable Type returnType) { + public ToolCallResult convert(@Nullable Object result, @Nullable Type returnType) { if (returnType == Void.TYPE) { logger.debug("The tool has no return type. Converting to conventional response."); - return JsonParser.toJson("Done"); + return ToolCallResult.builder().content(JsonParser.toJson("Done")).build(); } if (result instanceof RenderedImage) { final var buf = new ByteArrayOutputStream(1024 * 4); @@ -53,14 +53,18 @@ public String convert(@Nullable Object result, @Nullable Type returnType) { ImageIO.write((RenderedImage) result, "PNG", buf); } catch (IOException e) { - return "Failed to convert tool result to a base64 image: " + e.getMessage(); + return ToolCallResult.builder() + .content("Failed to convert tool result to a base64 image: " + e.getMessage()) + .build(); } final var imgB64 = Base64.getEncoder().encodeToString(buf.toByteArray()); - return JsonParser.toJson(Map.of("mimeType", "image/png", "data", imgB64)); + return ToolCallResult.builder() + .content(JsonParser.toJson(Map.of("mimeType", "image/png", "data", imgB64))) + .build(); } else { logger.debug("Converting tool result to JSON."); - return JsonParser.toJson(result); + return ToolCallResult.builder().content(JsonParser.toJson(result)).build(); } } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/DefaultToolExecutionExceptionProcessor.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/DefaultToolExecutionExceptionProcessor.java index 11fc180ecee..a908876bcf2 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/DefaultToolExecutionExceptionProcessor.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/DefaultToolExecutionExceptionProcessor.java @@ -55,7 +55,7 @@ public DefaultToolExecutionExceptionProcessor(boolean alwaysThrow, } @Override - public String process(ToolExecutionException exception) { + public ToolCallResult process(ToolExecutionException exception) { Assert.notNull(exception, "exception cannot be null"); Throwable cause = exception.getCause(); if (cause instanceof RuntimeException runtimeException) { @@ -79,7 +79,7 @@ public String process(ToolExecutionException exception) { } logger.debug("Exception thrown by tool: {}. Message: {}", exception.getToolDefinition().name(), message, exception); - return message; + return ToolCallResult.builder().content(message).build(); } public static Builder builder() { diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/ToolCallResult.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/ToolCallResult.java new file mode 100644 index 00000000000..ae10b08b300 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/ToolCallResult.java @@ -0,0 +1,77 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.execution; + +import java.util.Map; + +import org.jspecify.annotations.Nullable; + +/** + * Represents the result of a tool call, including optional metadata. + * + * @author Wada Yasuhiro + * @since 1.1.0 + */ +public final class ToolCallResult { + + private final @Nullable String content; + + private final Map metadata; + + private ToolCallResult(@Nullable String content, Map metadata) { + this.content = content; + this.metadata = Map.copyOf(metadata); + } + + public @Nullable String content() { + return this.content; + } + + public Map metadata() { + return this.metadata; + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private @Nullable String content; + + private Map metadata = Map.of(); + + private Builder() { + } + + public Builder content(@Nullable String content) { + this.content = content; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public ToolCallResult build() { + return new ToolCallResult(this.content, this.metadata); + } + + } + +} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/ToolCallResultConverter.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/ToolCallResultConverter.java index 0a7caf0206b..366ac5762aa 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/ToolCallResultConverter.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/ToolCallResultConverter.java @@ -21,8 +21,8 @@ import org.jspecify.annotations.Nullable; /** - * A functional interface to convert tool call results to a String that can be sent back - * to the AI model. + * A functional interface to convert tool call results to a {@link ToolCallResult} that + * can be sent back to the AI model. * * @author Thomas Vitale * @since 1.0.0 @@ -31,9 +31,9 @@ public interface ToolCallResultConverter { /** - * Given an Object returned by a tool, convert it to a String compatible with the - * given class type. + * Given an Object returned by a tool, convert it to a {@link ToolCallResult} + * compatible with the given class type. */ - String convert(@Nullable Object result, @Nullable Type returnType); + ToolCallResult convert(@Nullable Object result, @Nullable Type returnType); } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/ToolExecutionExceptionProcessor.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/ToolExecutionExceptionProcessor.java index 95dc4d98bcd..dc3b957f390 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/ToolExecutionExceptionProcessor.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/execution/ToolExecutionExceptionProcessor.java @@ -18,8 +18,8 @@ /** * A functional interface to process a {@link ToolExecutionException} by either converting - * the error message to a String that can be sent back to the AI model or throwing an - * exception to be handled by the caller. + * the error to a {@link ToolCallResult} that can be sent back to the AI model or throwing + * an exception to be handled by the caller. * * @author Thomas Vitale * @since 1.0.0 @@ -28,9 +28,9 @@ public interface ToolExecutionExceptionProcessor { /** - * Convert an exception thrown by a tool to a String that can be sent back to the AI + * Convert an exception thrown by a tool to a result that can be sent back to the AI * model or throw an exception to be handled by the caller. */ - String process(ToolExecutionException exception); + ToolCallResult process(ToolExecutionException exception); } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java index 579cdee1435..10a71e517db 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java @@ -31,6 +31,7 @@ import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.DefaultToolCallResultConverter; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.execution.ToolCallResultConverter; import org.springframework.ai.tool.execution.ToolExecutionException; import org.springframework.ai.tool.metadata.ToolMetadata; @@ -91,12 +92,12 @@ public ToolMetadata getToolMetadata() { } @Override - public String call(String toolInput) { + public ToolCallResult call(String toolInput) { return call(toolInput, null); } @Override - public String call(String toolInput, @Nullable ToolContext toolContext) { + public ToolCallResult call(String toolInput, @Nullable ToolContext toolContext) { Assert.hasText(toolInput, "toolInput cannot be null or empty"); logger.debug("Starting execution of tool: {}", this.toolDefinition.name()); diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java index 9c1b776ed03..56761569c16 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java @@ -33,6 +33,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.DefaultToolCallResultConverter; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.execution.ToolCallResultConverter; import org.springframework.ai.tool.execution.ToolExecutionException; import org.springframework.ai.tool.metadata.ToolMetadata; @@ -90,12 +91,12 @@ public ToolMetadata getToolMetadata() { } @Override - public String call(String toolInput) { + public ToolCallResult call(String toolInput) { return call(toolInput, null); } @Override - public String call(String toolInput, @Nullable ToolContext toolContext) { + public ToolCallResult call(String toolInput, @Nullable ToolContext toolContext) { Assert.hasText(toolInput, "toolInput cannot be null or empty"); logger.debug("Starting execution of tool: {}", this.toolDefinition.name()); diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/observation/ToolCallingContentObservationFilter.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/observation/ToolCallingContentObservationFilter.java index 9f1e5547e0b..ed74d927c8e 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/observation/ToolCallingContentObservationFilter.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/observation/ToolCallingContentObservationFilter.java @@ -39,11 +39,11 @@ public Observation.Context map(Observation.Context context) { .addHighCardinalityKeyValue(ToolCallingObservationDocumentation.HighCardinalityKeyNames.TOOL_CALL_ARGUMENTS .withValue(toolCallArguments)); - String toolCallResult = toolCallingObservationContext.getToolCallResult(); - if (toolCallResult != null) { + var toolCallResult = toolCallingObservationContext.getToolCallResult(); + if (toolCallResult != null && toolCallResult.content() != null) { toolCallingObservationContext .addHighCardinalityKeyValue(ToolCallingObservationDocumentation.HighCardinalityKeyNames.TOOL_CALL_RESULT - .withValue(toolCallResult)); + .withValue(toolCallResult.content())); } return toolCallingObservationContext; diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/observation/ToolCallingObservationContext.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/observation/ToolCallingObservationContext.java index c1bf2ff9f2d..b1e380c12d1 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/observation/ToolCallingObservationContext.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/observation/ToolCallingObservationContext.java @@ -23,6 +23,7 @@ import org.springframework.ai.observation.conventions.AiOperationType; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.metadata.ToolMetadata; import org.springframework.util.Assert; @@ -43,10 +44,10 @@ public final class ToolCallingObservationContext extends Observation.Context { private final String toolCallArguments; - private @Nullable String toolCallResult; + private @Nullable ToolCallResult toolCallResult; private ToolCallingObservationContext(ToolDefinition toolDefinition, ToolMetadata toolMetadata, - @Nullable String toolCallArguments, @Nullable String toolCallResult) { + @Nullable String toolCallArguments, @Nullable ToolCallResult toolCallResult) { Assert.notNull(toolDefinition, "toolDefinition cannot be null"); Assert.notNull(toolMetadata, "toolMetadata cannot be null"); @@ -72,11 +73,11 @@ public String getToolCallArguments() { return this.toolCallArguments; } - public @Nullable String getToolCallResult() { + public @Nullable ToolCallResult getToolCallResult() { return this.toolCallResult; } - public void setToolCallResult(@Nullable String toolCallResult) { + public void setToolCallResult(@Nullable ToolCallResult toolCallResult) { this.toolCallResult = toolCallResult; } @@ -92,7 +93,7 @@ public static final class Builder { private @Nullable String toolCallArguments; - private @Nullable String toolCallResult; + private @Nullable ToolCallResult toolCallResult; private Builder() { } @@ -112,7 +113,7 @@ public Builder toolCallArguments(String toolCallArguments) { return this; } - public Builder toolCallResult(@Nullable String toolCallResult) { + public Builder toolCallResult(@Nullable ToolCallResult toolCallResult) { this.toolCallResult = toolCallResult; return this; } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java index da12a249065..740d80b6f26 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java @@ -36,6 +36,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.metadata.ToolMetadata; import static org.assertj.core.api.Assertions.assertThat; @@ -248,8 +249,8 @@ public ToolMetadata getToolMetadata() { } @Override - public String call(String toolInput) { - return "Mission accomplished!"; + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content("Mission accomplished!").build(); } } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerIT.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerIT.java index 0aa31889cdf..30780f1b90c 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerIT.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerIT.java @@ -35,6 +35,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.metadata.ToolMetadata; import org.springframework.ai.tool.observation.DefaultToolCallingObservationConvention; import org.springframework.ai.tool.observation.ToolCallingObservationDocumentation; @@ -154,8 +155,8 @@ public ToolMetadata getToolMetadata() { } @Override - public String call(String toolInput) { - return "Mission accomplished!"; + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content("Mission accomplished!").build(); } } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java index bd60639c323..1c95febf100 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java @@ -30,6 +30,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.metadata.ToolMetadata; import static org.assertj.core.api.Assertions.assertThat; @@ -60,11 +61,11 @@ public ToolMetadata getToolMetadata() { } @Override - public String call(String toolInput) { + public ToolCallResult call(String toolInput) { // Verify the input is not null or empty assertThat(toolInput).isNotNull(); assertThat(toolInput).isNotEmpty(); - return "{\"result\": \"success\"}"; + return ToolCallResult.builder().content("{\"result\": \"success\"}").build(); } }; @@ -117,11 +118,11 @@ public ToolMetadata getToolMetadata() { } @Override - public String call(String toolInput) { + public ToolCallResult call(String toolInput) { // Verify the input is not null or empty assertThat(toolInput).isNotNull(); assertThat(toolInput).isNotEmpty(); - return "{\"result\": \"success\"}"; + return ToolCallResult.builder().content("{\"result\": \"success\"}").build(); } }; @@ -174,8 +175,8 @@ public ToolMetadata getToolMetadata() { } @Override - public String call(String toolInput) { - return "{\"result\": \"tool1_success\"}"; + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content("{\"result\": \"tool1_success\"}").build(); } }; @@ -195,8 +196,8 @@ public ToolMetadata getToolMetadata() { } @Override - public String call(String toolInput) { - return "{\"result\": \"tool2_success\"}"; + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content("{\"result\": \"tool2_success\"}").build(); } }; @@ -251,10 +252,10 @@ public ToolMetadata getToolMetadata() { } @Override - public String call(String toolInput) { + public ToolCallResult call(String toolInput) { assertThat(toolInput).contains("nested"); assertThat(toolInput).contains("array"); - return "{\"result\": \"processed\"}"; + return ToolCallResult.builder().content("{\"result\": \"processed\"}").build(); } }; @@ -297,10 +298,10 @@ public ToolMetadata getToolMetadata() { } @Override - public String call(String toolInput) { + public ToolCallResult call(String toolInput) { // Should still receive some input even if malformed assertThat(toolInput).isNotNull(); - return "{\"result\": \"handled\"}"; + return ToolCallResult.builder().content("{\"result\": \"handled\"}").build(); } }; @@ -344,8 +345,8 @@ public ToolMetadata getToolMetadata() { } @Override - public String call(String toolInput) { - return null; // Return null + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content(null).build(); // Return null } }; @@ -387,8 +388,8 @@ public ToolMetadata getToolMetadata() { } @Override - public String call(String toolInput) { - return "{\"result\": \"success\"}"; + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content("{\"result\": \"success\"}").build(); } }; diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java index 54abeef0ffc..caed2d8a7b1 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java @@ -34,6 +34,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.execution.ToolExecutionException; import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor; import org.springframework.ai.tool.metadata.ToolMetadata; @@ -423,8 +424,8 @@ public ToolMetadata getToolMetadata() { } @Override - public String call(String toolInput) { - return "Mission accomplished!"; + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content("Mission accomplished!").build(); } } @@ -443,7 +444,7 @@ public ToolDefinition getToolDefinition() { } @Override - public String call(String toolInput) { + public ToolCallResult call(String toolInput) { throw new ToolExecutionException(this.toolDefinition, new IllegalStateException("You failed this city!")); } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java index 6d5d599dccd..fbd8f7e88cf 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java @@ -25,6 +25,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -190,8 +191,8 @@ public ToolDefinition getToolDefinition() { } @Override - public String call(String toolInput) { - return "Mission accomplished!"; + public ToolCallResult call(String toolInput) { + return ToolCallResult.builder().content("Mission accomplished!").build(); } } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/tool/augment/AugmentedToolCallbackTest.java b/spring-ai-model/src/test/java/org/springframework/ai/tool/augment/AugmentedToolCallbackTest.java index 0a8e5964a7c..be1071b616d 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/tool/augment/AugmentedToolCallbackTest.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/tool/augment/AugmentedToolCallbackTest.java @@ -32,6 +32,7 @@ import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.annotation.ToolParam; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -268,7 +269,7 @@ void shouldCallDelegateWithProcessedInput() { when(mockToolDefinition.name()).thenReturn("testTool"); when(mockToolDefinition.description()).thenReturn("Test tool description"); when(mockToolDefinition.inputSchema()).thenReturn(originalSchema); - when(mockDelegate.call(anyString())).thenReturn("success"); + when(mockDelegate.call(anyString())).thenReturn(ToolCallResult.builder().content("success").build()); AtomicReference> capturedArgs = new AtomicReference<>(); Consumer> consumer = capturedArgs::set; @@ -285,10 +286,10 @@ void shouldCallDelegateWithProcessedInput() { """; // When - String result = callback.call(toolInput); + ToolCallResult result = callback.call(toolInput); // Then - assertEquals("success", result); + assertEquals("success", result.content()); verify(mockDelegate).call(toolInput); TestArguments args = capturedArgs.get().arguments(); @@ -316,7 +317,8 @@ void shouldCallDelegateWithContext() { when(mockToolDefinition.name()).thenReturn("testTool"); when(mockToolDefinition.description()).thenReturn("Test tool description"); when(mockToolDefinition.inputSchema()).thenReturn(originalSchema); - when(mockDelegate.call(anyString(), any(ToolContext.class))).thenReturn("success"); + when(mockDelegate.call(anyString(), any(ToolContext.class))) + .thenReturn(ToolCallResult.builder().content("success").build()); Consumer> consumer = args -> { }; @@ -333,10 +335,10 @@ void shouldCallDelegateWithContext() { """; // When - String result = callback.call(toolInput, mockToolContext); + ToolCallResult result = callback.call(toolInput, mockToolContext); // Then - assertEquals("success", result); + assertEquals("success", result.content()); verify(mockDelegate).call(toolInput, mockToolContext); } @@ -360,7 +362,7 @@ void shouldRemoveExtendedArgumentsWhenConfigured() throws Exception { when(mockToolDefinition.name()).thenReturn("testTool"); when(mockToolDefinition.description()).thenReturn("Test tool description"); when(mockToolDefinition.inputSchema()).thenReturn(originalSchema); - when(mockDelegate.call(anyString())).thenReturn("success"); + when(mockDelegate.call(anyString())).thenReturn(ToolCallResult.builder().content("success").build()); Consumer> consumer = args -> { }; @@ -411,7 +413,7 @@ void shouldPreserveExtendedArgumentsWhenNotConfiguredToRemove() throws Exception when(mockToolDefinition.name()).thenReturn("testTool"); when(mockToolDefinition.description()).thenReturn("Test tool description"); when(mockToolDefinition.inputSchema()).thenReturn(originalSchema); - when(mockDelegate.call(anyString())).thenReturn("success"); + when(mockDelegate.call(anyString())).thenReturn(ToolCallResult.builder().content("success").build()); Consumer> consumer = args -> { }; @@ -462,7 +464,7 @@ void shouldHandleNullConsumerGracefully() { when(mockToolDefinition.name()).thenReturn("testTool"); when(mockToolDefinition.description()).thenReturn("Test tool description"); when(mockToolDefinition.inputSchema()).thenReturn(originalSchema); - when(mockDelegate.call(anyString())).thenReturn("success"); + when(mockDelegate.call(anyString())).thenReturn(ToolCallResult.builder().content("success").build()); AugmentedToolCallback callback = new AugmentedToolCallback<>(mockDelegate, TestArguments.class, null, false); @@ -477,8 +479,8 @@ void shouldHandleNullConsumerGracefully() { // When & Then - should not throw exception assertDoesNotThrow(() -> { - String result = callback.call(toolInput); - assertEquals("success", result); + ToolCallResult result = callback.call(toolInput); + assertEquals("success", result.content()); }); } @@ -510,7 +512,8 @@ void shouldHandleCompleteWorkflowWithConsumerProcessing() { when(mockToolDefinition.name()).thenReturn("productTool"); when(mockToolDefinition.description()).thenReturn("Product management tool"); when(mockToolDefinition.inputSchema()).thenReturn(originalSchema); - when(mockDelegate.call(anyString())).thenReturn("Product processed successfully"); + when(mockDelegate.call(anyString())) + .thenReturn(ToolCallResult.builder().content("Product processed successfully").build()); AtomicReference> processedArgs = new AtomicReference<>(); Consumer> consumer = processedArgs::set; @@ -526,10 +529,10 @@ void shouldHandleCompleteWorkflowWithConsumerProcessing() { """; // When - String result = callback.call(toolInput); + ToolCallResult result = callback.call(toolInput); // Then - assertEquals("Product processed successfully", result); + assertEquals("Product processed successfully", result.content()); // Verify consumer was called with correct arguments SimpleArguments args = processedArgs.get().arguments(); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverterTests.java b/spring-ai-model/src/test/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverterTests.java index a7c5288fdb0..142d7fd89b4 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverterTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverterTests.java @@ -45,33 +45,33 @@ class DefaultToolCallResultConverterTests { @Test void convertWithNullReturnTypeShouldReturn() { - String result = this.converter.convert(null, null); - assertThat(result).isEqualTo("null"); + ToolCallResult result = this.converter.convert(null, null); + assertThat(result.content()).isEqualTo("null"); } @Test void convertVoidReturnTypeShouldReturnDoneJson() { - String result = this.converter.convert(null, void.class); - assertThat(result).isEqualTo("\"Done\""); + ToolCallResult result = this.converter.convert(null, void.class); + assertThat(result.content()).isEqualTo("\"Done\""); } @Test void convertStringReturnTypeShouldReturnJson() { - String result = this.converter.convert("test", String.class); - assertThat(result).isEqualTo("\"test\""); + ToolCallResult result = this.converter.convert("test", String.class); + assertThat(result.content()).isEqualTo("\"test\""); } @Test void convertNullReturnValueShouldReturnNullJson() { - String result = this.converter.convert(null, String.class); - assertThat(result).isEqualTo("null"); + ToolCallResult result = this.converter.convert(null, String.class); + assertThat(result.content()).isEqualTo("null"); } @Test void convertObjectReturnTypeShouldReturnJson() { TestObject testObject = new TestObject("test", 42); - String result = this.converter.convert(testObject, TestObject.class); - assertThat(result).containsIgnoringWhitespaces(""" + ToolCallResult result = this.converter.convert(testObject, TestObject.class); + assertThat(result.content()).containsIgnoringWhitespaces(""" "name": "test" """).containsIgnoringWhitespaces(""" "value": 42 @@ -81,8 +81,8 @@ void convertObjectReturnTypeShouldReturnJson() { @Test void convertCollectionReturnTypeShouldReturnJson() { List testList = List.of("one", "two", "three"); - String result = this.converter.convert(testList, List.class); - assertThat(result).isEqualTo(""" + ToolCallResult result = this.converter.convert(testList, List.class); + assertThat(result.content()).isEqualTo(""" ["one","two","three"] """.trim()); } @@ -90,8 +90,8 @@ void convertCollectionReturnTypeShouldReturnJson() { @Test void convertMapReturnTypeShouldReturnJson() { Map testMap = Map.of("one", 1, "two", 2); - String result = this.converter.convert(testMap, Map.class); - assertThat(result).containsIgnoringWhitespaces(""" + ToolCallResult result = this.converter.convert(testMap, Map.class); + assertThat(result.content()).containsIgnoringWhitespaces(""" "one": 1 """).containsIgnoringWhitespaces(""" "two": 2 @@ -108,9 +108,9 @@ void convertImageShouldReturnBase64Image() throws IOException { g.setColor(Color.WHITE); g.fillRect(0, 0, 64, 64); g.dispose(); - String result = this.converter.convert(img, BufferedImage.class); + ToolCallResult result = this.converter.convert(img, BufferedImage.class); - var b64Struct = JsonParser.fromJson(result, Base64Wrapper.class); + var b64Struct = JsonParser.fromJson(result.content(), Base64Wrapper.class); assertThat(b64Struct.mimeType).isEqualTo(MimeTypeUtils.IMAGE_PNG); assertThat(b64Struct.data).isNotNull(); @@ -125,30 +125,30 @@ void convertImageShouldReturnBase64Image() throws IOException { @Test void convertEmptyCollectionsShouldReturnEmptyJson() { - assertThat(this.converter.convert(List.of(), List.class)).isEqualTo("[]"); - assertThat(this.converter.convert(Map.of(), Map.class)).isEqualTo("{}"); - assertThat(this.converter.convert(new String[0], String[].class)).isEqualTo("[]"); + assertThat(this.converter.convert(List.of(), List.class).content()).isEqualTo("[]"); + assertThat(this.converter.convert(Map.of(), Map.class).content()).isEqualTo("{}"); + assertThat(this.converter.convert(new String[0], String[].class).content()).isEqualTo("[]"); } @Test void convertRecordReturnTypeShouldReturnJson() { TestRecord record = new TestRecord("recordName", 1); - String result = this.converter.convert(record, TestRecord.class); + ToolCallResult result = this.converter.convert(record, TestRecord.class); - assertThat(result).containsIgnoringWhitespaces("\"recordName\""); - assertThat(result).containsIgnoringWhitespaces("1"); + assertThat(result.content()).containsIgnoringWhitespaces("\"recordName\""); + assertThat(result.content()).containsIgnoringWhitespaces("1"); } @Test void convertSpecialCharactersInStringsShouldEscapeJson() { String specialChars = "Test with \"quotes\", newlines\n, tabs\t, and backslashes\\"; - String result = this.converter.convert(specialChars, String.class); + ToolCallResult result = this.converter.convert(specialChars, String.class); // Should properly escape JSON special characters - assertThat(result).contains("\\\"quotes\\\""); - assertThat(result).contains("\\n"); - assertThat(result).contains("\\t"); - assertThat(result).contains("\\\\"); + assertThat(result.content()).contains("\\\"quotes\\\""); + assertThat(result.content()).contains("\\n"); + assertThat(result.content()).contains("\\t"); + assertThat(result.content()).contains("\\\\"); } record TestRecord(String name, int value) { diff --git a/spring-ai-model/src/test/java/org/springframework/ai/tool/execution/DefaultToolExecutionExceptionProcessorTests.java b/spring-ai-model/src/test/java/org/springframework/ai/tool/execution/DefaultToolExecutionExceptionProcessorTests.java index cf0aca7022f..52900eb44d7 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/tool/execution/DefaultToolExecutionExceptionProcessorTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/tool/execution/DefaultToolExecutionExceptionProcessorTests.java @@ -55,9 +55,9 @@ class DefaultToolExecutionExceptionProcessorTests { void processReturnsMessage() { DefaultToolExecutionExceptionProcessor processor = DefaultToolExecutionExceptionProcessor.builder().build(); - String result = processor.process(this.toolExecutionException); + ToolCallResult result = processor.process(this.toolExecutionException); - assertThat(result).isEqualTo(this.toolException.getMessage()); + assertThat(result.content()).isEqualTo(this.toolException.getMessage()); } @Test @@ -66,9 +66,9 @@ void processReturnsFallbackMessageWhenNull() { ToolExecutionException exception = new ToolExecutionException(this.toolDefinition, new IllegalStateException()); - String result = processor.process(exception); + ToolCallResult result = processor.process(exception); - assertThat(result).isEqualTo("Exception occurred in tool: toolName (IllegalStateException)"); + assertThat(result.content()).isEqualTo("Exception occurred in tool: toolName (IllegalStateException)"); } @Test @@ -77,9 +77,9 @@ void processReturnsFallbackMessageWhenBlank() { ToolExecutionException exception = new ToolExecutionException(this.toolDefinition, new RuntimeException(" ")); - String result = processor.process(exception); + ToolCallResult result = processor.process(exception); - assertThat(result).isEqualTo("Exception occurred in tool: toolName (RuntimeException)"); + assertThat(result.content()).isEqualTo("Exception occurred in tool: toolName (RuntimeException)"); } @Test @@ -125,9 +125,9 @@ void processRethrowsOnlySelectExceptions() { ToolExecutionException exception = new ToolExecutionException(this.toolDefinition, new RuntimeException("This exception was not rethrown")); - String result = processor.process(exception); + ToolCallResult result = processor.process(exception); - assertThat(result).isEqualTo("This exception was not rethrown"); + assertThat(result.content()).isEqualTo("This exception was not rethrown"); } @Test diff --git a/spring-ai-model/src/test/java/org/springframework/ai/tool/function/FunctionToolCallbackTest.java b/spring-ai-model/src/test/java/org/springframework/ai/tool/function/FunctionToolCallbackTest.java index 4c3361d25e0..3f335e393db 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/tool/function/FunctionToolCallbackTest.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/tool/function/FunctionToolCallbackTest.java @@ -26,6 +26,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.execution.ToolExecutionException; import org.springframework.ai.tool.metadata.ToolMetadata; @@ -64,10 +65,10 @@ void testBiFunctionToolCall() { ToolContext toolContext = new ToolContext(Map.of("foo", "bar")); - String callResult = callback.call("\"test string param\"", toolContext); + ToolCallResult callResult = callback.call("\"test string param\"", toolContext); assertEquals("test string param", tool.calledValue.get()); - assertEquals("\"return value = test string param\"", callResult); + assertEquals("\"return value = test string param\"", callResult.content()); assertEquals(toolContext, tool.calledToolContext.get()); } @@ -82,10 +83,10 @@ void testFunctionToolCall() { ToolContext toolContext = new ToolContext(Map.of()); - String callResult = callback.call("\"test string param\"", toolContext); + ToolCallResult callResult = callback.call("\"test string param\"", toolContext); assertEquals("test string param", tool.calledValue.get()); - assertEquals("\"return value = test string param\"", callResult); + assertEquals("\"return value = test string param\"", callResult.content()); } @Test @@ -100,10 +101,10 @@ void testSupplierToolCall() { ToolContext toolContext = new ToolContext(Map.of()); - String callResult = callback.call("\"test string param\"", toolContext); + ToolCallResult callResult = callback.call("\"test string param\"", toolContext); assertEquals("not params", tool.calledValue.get()); - assertEquals("\"return value = \"", callResult); + assertEquals("\"return value = \"", callResult.content()); } @Test diff --git a/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackExceptionHandlingTest.java b/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackExceptionHandlingTest.java index a504f48ca8a..a3f9b24a40c 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackExceptionHandlingTest.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackExceptionHandlingTest.java @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.execution.ToolCallResult; import org.springframework.ai.tool.execution.ToolExecutionException; import static org.assertj.core.api.Assertions.assertThat; @@ -46,10 +47,10 @@ void testGenericListType() throws Exception { """; // Call the tool - String result = callback.call(toolInput); + ToolCallResult result = callback.call(toolInput); // Verify the result - assertThat(result).isEqualTo("3 strings processed: [one, two, three]"); + assertThat(result.content()).isEqualTo("3 strings processed: [one, two, three]"); // Verify String ivalidToolInput = """ diff --git a/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackGenericTypesTest.java b/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackGenericTypesTest.java index 6e05fd80c59..5eaac86183e 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackGenericTypesTest.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackGenericTypesTest.java @@ -25,6 +25,7 @@ import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import static org.assertj.core.api.Assertions.assertThat; @@ -61,10 +62,10 @@ void testGenericListType() throws Exception { """; // Call the tool - String result = callback.call(toolInput); + ToolCallResult result = callback.call(toolInput); // Verify the result - assertThat(result).isEqualTo("3 strings processed: [one, two, three]"); + assertThat(result.content()).isEqualTo("3 strings processed: [one, two, three]"); } @Test @@ -95,10 +96,10 @@ void testGenericMapType() throws Exception { """; // Call the tool - String result = callback.call(toolInput); + ToolCallResult result = callback.call(toolInput); // Verify the result - assertThat(result).isEqualTo("3 entries processed: {one=1, two=2, three=3}"); + assertThat(result.content()).isEqualTo("3 entries processed: {one=1, two=2, three=3}"); } @Test @@ -132,10 +133,10 @@ void testNestedGenericType() throws Exception { """; // Call the tool - String result = callback.call(toolInput); + ToolCallResult result = callback.call(toolInput); // Verify the result - assertThat(result).isEqualTo("2 maps processed: [{a=1, b=2}, {c=3, d=4}]"); + assertThat(result.content()).isEqualTo("2 maps processed: [{a=1, b=2}, {c=3, d=4}]"); } @Test @@ -167,10 +168,10 @@ void testToolContextType() throws Exception { ToolContext toolContext = new ToolContext(Map.of("foo", "bar")); // Call the tool - String result = callback.call(toolInput, toolContext); + ToolCallResult result = callback.call(toolInput, toolContext); // Verify the result - assertThat(result).isEqualTo("1 entries processed {foo=bar}"); + assertThat(result.content()).isEqualTo("1 entries processed {foo=bar}"); } /** diff --git a/spring-ai-model/src/test/java/org/springframework/ai/tool/observation/DefaultToolCallingObservationConventionTests.java b/spring-ai-model/src/test/java/org/springframework/ai/tool/observation/DefaultToolCallingObservationConventionTests.java index 46727be7ab0..1dd4fdbfeb0 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/tool/observation/DefaultToolCallingObservationConventionTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/tool/observation/DefaultToolCallingObservationConventionTests.java @@ -24,6 +24,7 @@ import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import static org.assertj.core.api.Assertions.assertThat; @@ -88,7 +89,7 @@ void shouldHaveHighCardinalityKeyValues() { ToolCallingObservationContext observationContext = ToolCallingObservationContext.builder() .toolDefinition(ToolDefinition.builder().name("toolA").description("description").inputSchema("{}").build()) .toolCallArguments(toolCallInput) - .toolCallResult("Mission accomplished!") + .toolCallResult(ToolCallResult.builder().content("Mission accomplished!").build()) .build(); assertThat(this.observationConvention.getHighCardinalityKeyValues(observationContext)).contains( KeyValue.of(ToolCallingObservationDocumentation.HighCardinalityKeyNames.TOOL_DEFINITION_DESCRIPTION diff --git a/spring-ai-model/src/test/java/org/springframework/ai/tool/observation/ToolCallingContentObservationFilterTests.java b/spring-ai-model/src/test/java/org/springframework/ai/tool/observation/ToolCallingContentObservationFilterTests.java index c10a144b9e4..c7232a33dbf 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/tool/observation/ToolCallingContentObservationFilterTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/tool/observation/ToolCallingContentObservationFilterTests.java @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import static org.assertj.core.api.Assertions.assertThat; @@ -46,7 +47,7 @@ void augmentContext() { var originalContext = ToolCallingObservationContext.builder() .toolDefinition(ToolDefinition.builder().name("toolA").description("description").inputSchema("{}").build()) .toolCallArguments("input") - .toolCallResult("result") + .toolCallResult(ToolCallResult.builder().content("result").build()) .build(); var augmentedContext = this.observationFilter.map(originalContext); @@ -61,7 +62,7 @@ void augmentContextWhenNullResult() { var originalContext = ToolCallingObservationContext.builder() .toolDefinition(ToolDefinition.builder().name("toolA").description("description").inputSchema("{}").build()) .toolCallArguments("input") - .toolCallResult("result") + .toolCallResult(ToolCallResult.builder().content("result").build()) .build(); var augmentedContext = this.observationFilter.map(originalContext); @@ -79,7 +80,7 @@ void whenToolCallArgumentsIsEmptyStringThenHighCardinalityKeyValueIsEmpty() { var originalContext = ToolCallingObservationContext.builder() .toolDefinition(ToolDefinition.builder().name("toolA").description("description").inputSchema("{}").build()) .toolCallArguments("") - .toolCallResult("result") + .toolCallResult(ToolCallResult.builder().content("result").build()) .build(); var augmentedContext = this.observationFilter.map(originalContext); @@ -94,7 +95,7 @@ void whenToolCallResultIsEmptyStringThenHighCardinalityKeyValueIsEmpty() { var originalContext = ToolCallingObservationContext.builder() .toolDefinition(ToolDefinition.builder().name("toolA").description("description").inputSchema("{}").build()) .toolCallArguments("input") - .toolCallResult("") + .toolCallResult(ToolCallResult.builder().content("").build()) .build(); var augmentedContext = this.observationFilter.map(originalContext); @@ -109,7 +110,7 @@ void whenFilterAppliedMultipleTimesThenIdempotent() { var originalContext = ToolCallingObservationContext.builder() .toolDefinition(ToolDefinition.builder().name("toolA").description("description").inputSchema("{}").build()) .toolCallArguments("input") - .toolCallResult("result") + .toolCallResult(ToolCallResult.builder().content("result").build()) .build(); var augmentedOnce = this.observationFilter.map(originalContext); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/tool/observation/ToolCallingObservationContextTests.java b/spring-ai-model/src/test/java/org/springframework/ai/tool/observation/ToolCallingObservationContextTests.java index abebcd7b419..9ef6eb3e0c7 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/tool/observation/ToolCallingObservationContextTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/tool/observation/ToolCallingObservationContextTests.java @@ -19,6 +19,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResult; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -98,10 +99,10 @@ void whenToolCallResultIsNullThenReturnNull() { void whenToolCallResultIsEmptyStringThenReturnEmptyString() { var observationContext = ToolCallingObservationContext.builder() .toolDefinition(ToolDefinition.builder().name("toolA").description("description").inputSchema("{}").build()) - .toolCallResult("") + .toolCallResult(ToolCallResult.builder().content("").build()) .build(); assertThat(observationContext).isNotNull(); - assertThat(observationContext.getToolCallResult()).isEqualTo(""); + assertThat(observationContext.getToolCallResult().content()).isEqualTo(""); } @Test