Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolCallResult;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
import org.springframework.ai.util.json.schema.JsonSchemaGenerator;
import org.springframework.boot.autoconfigure.AutoConfigurations;
Expand Down Expand Up @@ -210,8 +211,8 @@ public ToolDefinition getToolDefinition() {
}

@Override
public String call(String toolInput) {
return "~~ not implemented ~~";
public ToolCallResult call(String toolInput) {
return ToolCallResult.builder().content("~~ not implemented ~~").build();
}
} };
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolCallResult;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
Expand Down Expand Up @@ -184,8 +185,8 @@ void testMcpServerClientIntegrationWithIncompleteSchemaSyncTool() {
assertThat(schemaMap.get("additionalProperties")).isEqualTo(false);

// Test that the callback can be called successfully
String result = toolCallback.call("{}");
assertThat(result).isNotNull().contains("Current time:");
ToolCallResult result = toolCallback.call("{}");
assertThat(result.content()).isNotNull().contains("Current time:");
});

stopHttpServer(httpServer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
import org.springframework.ai.tool.execution.ToolCallResult;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
import org.springframework.ai.tool.function.FunctionToolCallback;
Expand Down Expand Up @@ -162,8 +163,8 @@ void throwExceptionOnErrorDefault() {
var exception = new ToolExecutionException(toolDefinition, cause);

// Default behavior should not throw exception
String result = toolExecutionExceptionProcessor.process(exception);
assertThat(result).isEqualTo("Test error");
ToolCallResult result = toolExecutionExceptionProcessor.process(exception);
assertThat(result.content()).isEqualTo("Test error");
});
}

Expand Down Expand Up @@ -364,8 +365,8 @@ public ToolExecutionExceptionProcessor toolExecutionExceptionProcessor() {
static class CustomToolExecutionExceptionProcessor implements ToolExecutionExceptionProcessor {

@Override
public String process(ToolExecutionException exception) {
return "Custom error handling";
public ToolCallResult process(ToolExecutionException exception) {
return ToolCallResult.builder().content("Custom error handling").build();
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolCallResult;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
Expand Down Expand Up @@ -100,12 +101,12 @@ public String getOriginalToolName() {
}

@Override
public String call(String toolCallInput) {
public ToolCallResult call(String toolCallInput) {
return this.call(toolCallInput, null);
}

@Override
public String call(String toolCallInput, @Nullable ToolContext toolContext) {
public ToolCallResult call(String toolCallInput, @Nullable ToolContext toolContext) {

// Handle the possible null parameter situation in streaming mode.
if (!StringUtils.hasText(toolCallInput)) {
Expand Down Expand Up @@ -143,7 +144,7 @@ public String call(String toolCallInput, @Nullable ToolContext toolContext) {
throw new ToolExecutionException(this.getToolDefinition(),
new IllegalStateException("Error calling tool: " + response.content()));
}
return ModelOptionsUtils.toJsonString(response.content());
return ToolCallResult.builder().content(ModelOptionsUtils.toJsonString(response.content())).build();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolCallResult;
import org.springframework.ai.util.json.schema.JsonSchemaUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;
Expand Down Expand Up @@ -257,14 +258,15 @@ private static SharedSyncToolSpecification toSharedSyncToolSpecification(ToolCal

return new SharedSyncToolSpecification(tool, (exchangeOrContext, request) -> {
try {
String callResult = toolCallback.call(ModelOptionsUtils.toJsonString(request.arguments()),
ToolCallResult callResult = toolCallback.call(ModelOptionsUtils.toJsonString(request.arguments()),
new ToolContext(Map.of(TOOL_CONTEXT_MCP_EXCHANGE_KEY, exchangeOrContext)));
if (mimeType != null && mimeType.toString().startsWith("image")) {
McpSchema.Annotations annotations = new McpSchema.Annotations(List.of(Role.ASSISTANT), null);
return new McpSchema.CallToolResult(
List.of(new McpSchema.ImageContent(annotations, callResult, mimeType.toString())), false);
List.of(new McpSchema.ImageContent(annotations, callResult.content(), mimeType.toString())),
false);
}
return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(callResult)), false);
return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(callResult.content())), false);
}
catch (Exception e) {
return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(e.getMessage())), true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolCallResult;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
Expand Down Expand Up @@ -101,12 +102,12 @@ public String getOriginalToolName() {
}

@Override
public String call(String toolCallInput) {
public ToolCallResult call(String toolCallInput) {
return this.call(toolCallInput, null);
}

@Override
public String call(String toolCallInput, @Nullable ToolContext toolContext) {
public ToolCallResult call(String toolCallInput, @Nullable ToolContext toolContext) {

// Handle the possible null parameter situation in streaming mode.
if (!StringUtils.hasText(toolCallInput)) {
Expand Down Expand Up @@ -142,7 +143,7 @@ public String call(String toolCallInput, @Nullable ToolContext toolContext) {
throw new ToolExecutionException(this.getToolDefinition(),
new IllegalStateException("Error calling tool: " + response.content()));
}
return ModelOptionsUtils.toJsonString(response.content());
return ToolCallResult.builder().content(ModelOptionsUtils.toJsonString(response.content())).build();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolCallResult;
import org.springframework.ai.tool.execution.ToolExecutionException;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -98,10 +99,10 @@ void callShouldSucceedWithValidInput() {
.prefixedToolName("prefixed_testTool")
.build();

String result = callback.call("{\"param\":\"value\"}");
ToolCallResult result = callback.call("{\"param\":\"value\"}");

// Assert
assertThat(result).contains("Success response");
assertThat(result.content()).contains("Success response");

// Verify the correct tool name was used in the request
ArgumentCaptor<McpSchema.CallToolRequest> requestCaptor = ArgumentCaptor
Expand All @@ -128,10 +129,10 @@ void callShouldHandleNullInput() {
.prefixedToolName("testTool")
.build();

String result = callback.call(null);
ToolCallResult result = callback.call(null);

// Assert
assertThat(result).contains("Success with empty input");
assertThat(result.content()).contains("Success with empty input");

// Verify empty JSON object was used
ArgumentCaptor<McpSchema.CallToolRequest> requestCaptor = ArgumentCaptor
Expand All @@ -156,10 +157,10 @@ void callShouldHandleEmptyInput() {
.prefixedToolName("testTool")
.build();

String result = callback.call("");
ToolCallResult result = callback.call("");

// Assert
assertThat(result).contains("Success with empty input");
assertThat(result.content()).contains("Success with empty input");

// Verify empty JSON object was used
ArgumentCaptor<McpSchema.CallToolRequest> requestCaptor = ArgumentCaptor
Expand Down Expand Up @@ -187,10 +188,10 @@ void callShouldIncludeToolContext() {
.prefixedToolName("testTool")
.build();

String result = callback.call("{\"param\":\"value\"}", toolContext);
ToolCallResult result = callback.call("{\"param\":\"value\"}", toolContext);

// Assert
assertThat(result).contains("Success with context");
assertThat(result.content()).contains("Success with context");

// Verify the context was included in the request
ArgumentCaptor<McpSchema.CallToolRequest> requestCaptor = ArgumentCaptor
Expand Down Expand Up @@ -323,11 +324,11 @@ void callShouldHandleComplexJsonResponse() {
.prefixedToolName("testTool")
.build();

String result = callback.call("{\"input\":\"test\"}");
ToolCallResult result = callback.call("{\"input\":\"test\"}");

// Assert
assertThat(result).contains("Part 1");
assertThat(result).contains("Part 2");
assertThat(result.content()).contains("Part 1");
assertThat(result.content()).contains("Part 2");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.mockito.junit.jupiter.MockitoExtension;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.execution.ToolCallResult;
import org.springframework.ai.tool.execution.ToolExecutionException;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -93,7 +94,7 @@ void callShouldHandleJsonInputAndOutput() {
.toolContextToMcpMetaConverter(ToolContextToMcpMetaConverter.defaultConverter())
.build();

String response = callback.call("{\"param\":\"value\"}");
ToolCallResult response = callback.call("{\"param\":\"value\"}");

assertThat(response).isNotNull();
}
Expand All @@ -111,7 +112,7 @@ void callShouldHandleToolContext() {
.toolContextToMcpMetaConverter(ToolContextToMcpMetaConverter.defaultConverter())
.build();

String response = callback.call("{\"param\":\"value\"}", new ToolContext(Map.of("foo", "bar")));
ToolCallResult response = callback.call("{\"param\":\"value\"}", new ToolContext(Map.of("foo", "bar")));

assertThat(response).isNotNull();
}
Expand All @@ -130,16 +131,16 @@ void callShouldHandleNullOrEmptyInput() {
.build();

// Test with null input
String responseNull = callback.call(null);
assertThat(responseNull).isEqualTo("[]");
ToolCallResult responseNull = callback.call(null);
assertThat(responseNull.content()).isEqualTo("[]");

// Test with empty string input
String responseEmpty = callback.call("");
assertThat(responseEmpty).isEqualTo("[]");
ToolCallResult responseEmpty = callback.call("");
assertThat(responseEmpty.content()).isEqualTo("[]");

// Test with whitespace-only input
String responseWhitespace = callback.call(" ");
assertThat(responseWhitespace).isEqualTo("[]");
ToolCallResult responseWhitespace = callback.call(" ");
assertThat(responseWhitespace.content()).isEqualTo("[]");
}

@Test
Expand Down Expand Up @@ -197,9 +198,9 @@ void callShouldHandleEmptyResponse() {
.toolContextToMcpMetaConverter(ToolContextToMcpMetaConverter.defaultConverter())
.build();

String response = callback.call("{\"param\":\"value\"}");
ToolCallResult response = callback.call("{\"param\":\"value\"}");

assertThat(response).isEqualTo("[]");
assertThat(response.content()).isEqualTo("[]");
}

@Test
Expand All @@ -218,10 +219,10 @@ void callShouldHandleMultipleContentItems() {
.toolContextToMcpMetaConverter(ToolContextToMcpMetaConverter.defaultConverter())
.build();

String response = callback.call("{\"param\":\"value\"}");
ToolCallResult response = callback.call("{\"param\":\"value\"}");

assertThat(response).isNotNull();
assertThat(response).isEqualTo("[{\"text\":\"First content\"},{\"text\":\"Second content\"}]");
assertThat(response.content()).isEqualTo("[{\"text\":\"First content\"},{\"text\":\"Second content\"}]");
}

@Test
Expand All @@ -239,10 +240,10 @@ void callShouldHandleNonTextContent() {
.toolContextToMcpMetaConverter(ToolContextToMcpMetaConverter.defaultConverter())
.build();

String response = callback.call("{\"param\":\"value\"}");
ToolCallResult response = callback.call("{\"param\":\"value\"}");

assertThat(response).isNotNull();
assertThat(response).isEqualTo("[{\"data\":\"base64data\",\"mimeType\":\"image/png\"}]");
assertThat(response.content()).isEqualTo("[{\"data\":\"base64data\",\"mimeType\":\"image/png\"}]");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolCallResult;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand Down Expand Up @@ -321,7 +322,7 @@ private ToolCallback createMockToolCallback(String name, String result) {
.inputSchema("{}")
.build();
when(callback.getToolDefinition()).thenReturn(definition);
when(callback.call(anyString(), any())).thenReturn(result);
when(callback.call(anyString(), any())).thenReturn(ToolCallResult.builder().content(result).build());
return callback;
}

Expand Down
Loading