Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions ballerina/Ballerina.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
distribution = "2201.12.0"
org = "ballerina"
name = "ai"
version = "1.9.0"
version = "1.9.1"
license = ["Apache-2.0"]
authors = ["Ballerina"]
keywords = ["AI/Agent", "Cost/Freemium", "Agent", "Vendor/N/A", "Area/AI", "Type/Connector"]
Expand All @@ -19,14 +19,14 @@ graalvmCompatible = true
[[platform.java21.dependency]]
groupId = "io.ballerina.stdlib"
artifactId = "ai-native"
version = "1.9.0"
path = "../native/build/libs/ai-native-1.9.0.jar"
version = "1.9.1"
path = "../native/build/libs/ai-native-1.9.1-SNAPSHOT.jar"

[[platform.java21.dependency]]
groupId = "io.ballerina.stdlib"
artifactId = "ai-native"
version = "1.9.0"
path = "../native/build/libs/ai-native-1.9.0-tests.jar"
version = "1.9.1"
path = "../native/build/libs/ai-native-1.9.1-SNAPSHOT-tests.jar"
testOnly = true

[[platform.java21.dependency]]
Expand Down
2 changes: 1 addition & 1 deletion ballerina/CompilerPlugin.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ id = "ai-compiler-plugin"
class = "io.ballerina.stdlib.ai.plugin.AiCompilerPlugin"

[[dependency]]
path = "../compiler-plugin/build/libs/ai-compiler-plugin-1.9.0.jar"
path = "../compiler-plugin/build/libs/ai-compiler-plugin-1.9.1-SNAPSHOT.jar"

[[dependency]]
path = "../compiler-plugin/build/libs/ballerina-to-openapi-2.3.2.jar"
33 changes: 31 additions & 2 deletions ballerina/Dependencies.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,28 @@ distribution-version = "2201.12.0"
[[package]]
org = "ballerina"
name = "ai"
version = "1.9.0"
version = "1.9.1"
dependencies = [
{org = "ballerina", name = "cache"},
{org = "ballerina", name = "constraint"},
{org = "ballerina", name = "crypto"},
{org = "ballerina", name = "data.jsondata"},
{org = "ballerina", name = "data.xmldata"},
{org = "ballerina", name = "file"},
{org = "ballerina", name = "http"},
{org = "ballerina", name = "io"},
{org = "ballerina", name = "jballerina.java"},
{org = "ballerina", name = "jwt"},
{org = "ballerina", name = "lang.array"},
{org = "ballerina", name = "lang.regexp"},
{org = "ballerina", name = "lang.runtime"},
{org = "ballerina", name = "lang.string"},
{org = "ballerina", name = "log"},
{org = "ballerina", name = "math.vector"},
{org = "ballerina", name = "mcp"},
{org = "ballerina", name = "mime"},
{org = "ballerina", name = "observe"},
{org = "ballerina", name = "random"},
{org = "ballerina", name = "test"},
{org = "ballerina", name = "time"},
{org = "ballerina", name = "url"},
Expand Down Expand Up @@ -61,6 +66,9 @@ dependencies = [
{org = "ballerina", name = "task"},
{org = "ballerina", name = "time"}
]
modules = [
{org = "ballerina", packageName = "cache", moduleName = "cache"}
]

[[package]]
org = "ballerina"
Expand All @@ -81,6 +89,9 @@ dependencies = [
{org = "ballerina", name = "jballerina.java"},
{org = "ballerina", name = "time"}
]
modules = [
{org = "ballerina", packageName = "crypto", moduleName = "crypto"}
]

[[package]]
org = "ballerina"
Expand Down Expand Up @@ -187,6 +198,9 @@ dependencies = [
{org = "ballerina", name = "log"},
{org = "ballerina", name = "time"}
]
modules = [
{org = "ballerina", packageName = "jwt", moduleName = "jwt"}
]

[[package]]
org = "ballerina"
Expand Down Expand Up @@ -271,6 +285,9 @@ dependencies = [
{org = "ballerina", name = "jballerina.java"},
{org = "ballerina", name = "lang.regexp"}
]
modules = [
{org = "ballerina", packageName = "lang.string", moduleName = "lang.string"}
]

[[package]]
org = "ballerina"
Expand Down Expand Up @@ -305,7 +322,7 @@ modules = [
[[package]]
org = "ballerina"
name = "mcp"
version = "1.0.2"
version = "1.0.3"
dependencies = [
{org = "ballerina", name = "http"},
{org = "ballerina", name = "jballerina.java"},
Expand Down Expand Up @@ -363,6 +380,18 @@ dependencies = [
{org = "ballerina", name = "jballerina.java"}
]

[[package]]
org = "ballerina"
name = "random"
version = "1.7.0"
dependencies = [
{org = "ballerina", name = "jballerina.java"},
{org = "ballerina", name = "time"}
]
modules = [
{org = "ballerina", packageName = "random", moduleName = "random"}
]

[[package]]
org = "ballerina"
name = "task"
Expand Down
149 changes: 102 additions & 47 deletions ballerina/agent-utils.bal
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import ai.observe;

import ballerina/cache;
import ballerina/io;
import ballerina/log;
import ballerina/time;
Expand Down Expand Up @@ -85,9 +86,11 @@ public type ToolOutput record {|

type BaseAgent distinct isolated object {
ModelProvider model;
ToolStore toolStore;
ToolManager toolManager;
Memory memory;
boolean stateless;
cache:Cache tokenManager;
AuthConfig? auth;

# Parse the llm response and extract the tool to be executed.
#
Expand Down Expand Up @@ -115,6 +118,9 @@ class Executor {
private final BaseAgent agent;
# Contains the current execution progress for the agent and the query
public ExecutionProgress progress;
private string agentId = "";
private boolean & readonly isAuthEnabled = false;


# Initialize the executor with the agent and the query.
#
Expand All @@ -126,6 +132,11 @@ class Executor {
self.sessionId = sessionId;
self.agent = agent;
self.progress = progress;
AuthConfig? auth = agent.auth;
if auth is AuthConfig {
self.isAuthEnabled = true;
self.agentId = auth.agentId;
}
}

# Checks whether agent has more steps to execute.
Expand All @@ -140,6 +151,12 @@ class Executor {
# + return - generated LLM response during the reasoning or an error if the reasoning fails
public isolated function reason() returns json|Error {
if self.isCompleted {
if self.isAuthEnabled {
log:printError("Task is already completed. No more reasoning is needed.",
executionId = self.progress.executionId,
agentId = self.agentId
);
}
return error TaskCompletedError("Task is already completed. No more reasoning is needed.");
}
log:printDebug("LLM reasoning started",
Expand Down Expand Up @@ -181,56 +198,94 @@ class Executor {
if toolCallId is string {
span.addId(toolCallId);
}
string? toolDescription = self.agent.toolStore.getToolDescription(toolName);
ToolManager toolManager = self.agent.toolManager;
string? toolDescription = toolManager.getToolDescription(toolName);
if toolDescription is string {
span.addDescription(toolDescription);

}
span.addType(self.agent.toolStore.isMcpTool(toolName) ? observe:EXTENTION : observe:FUNCTION);
boolean isMcpTool = toolManager.isMcpTool(toolName);
span.addType(isMcpTool ? observe:EXTENTION : observe:FUNCTION);
span.addArguments(parsedOutput.arguments);

ToolOutput|ToolExecutionError|LlmInvalidGenerationError output = self.agent.toolStore.execute(parsedOutput,
self.progress.context);
if output is Error {
if output is ToolNotFoundError {
observation = "Tool is not found. Please check the tool name and retry.";
} else if output is ToolInvalidInputError {
observation = "Tool execution failed due to invalid inputs. Retry with correct inputs.";
} else {
observation = "Tool execution failed. Retry with correct inputs.";
}
observation = string `${observation.toString()} <detail>${output.toString()}</detail>`;
LlmInvalidGenerationError|ToolExecutionError? validateRes = toolManager.
validateTool(parsedOutput, self.agent.auth, self.agent.tokenManager,
self.progress.context, isMcpTool);
if validateRes is Error {
log:printError("Tool validation failed",
executionId = self.progress.executionId,
agentId = self.agentId,
sessionId = self.sessionId,
toolName = toolName,
'error = validateRes
);
observation = "Tool extraction failed due to tool validation";
executionResult = {
llmResponse,
'error: output,
observation: observation.toString()
'error: validateRes,
observation: "Tool extraction failed due to tool validation"
};

log:printDebug("Tool execution resulted in error",
executionId = self.progress.executionId,
observation = observation.toString(),
sessionId = self.sessionId,
toolName = toolName
);

Error toolExecutionError = error Error(observation.toString(), details = {parsedOutput});
span.close(toolExecutionError);
} else {
anydata|error value = output.value;
observation = value is error ? value.toString() : value;
log:printDebug("Tool execution successful",
executionId = self.progress.executionId,
sessionId = self.sessionId,
toolName = toolName,
output = observation
);
executionResult = {
tool: parsedOutput,
observation: value
};

span.addOutput(observation);
span.close();
ToolOutput|ToolExecutionError|LlmInvalidGenerationError output = toolManager.execute(parsedOutput,
self.progress.context);
if output is Error {
if output is ToolNotFoundError {
observation = "Tool is not found. Please check the tool name and retry.";
} else if output is ToolInvalidInputError {
observation = "Tool execution failed due to invalid inputs. Retry with correct inputs.";
} else {
observation = "Tool execution failed. Retry with correct inputs.";
}
observation = string `${observation.toString()} <detail>${output.toString()}</detail>`;
executionResult = {
llmResponse,
'error: output,
observation: observation.toString()
};

if self.isAuthEnabled {
log:printError("Tool execution resulted in error",
executionId = self.progress.executionId,
agentId = self.agentId,
observation = observation.toString(),
sessionId = self.sessionId,
toolName = toolName
);
} else {
log:printDebug("Tool execution resulted in error",
executionId = self.progress.executionId,
observation = observation.toString(),
sessionId = self.sessionId,
toolName = toolName
);
}

Error toolExecutionError = error Error(observation.toString(), details = {parsedOutput});
span.close(toolExecutionError);
} else {
anydata|error value = output.value;
observation = value is error ? value.toString() : value;
if self.isAuthEnabled {
log:printInfo("Tool execution successful",
executionId = self.progress.executionId,
agentId = self.agentId,
sessionId = self.sessionId,
toolName = toolName
);
} else {
log:printDebug("Tool execution successful",
executionId = self.progress.executionId,
sessionId = self.sessionId,
toolName = toolName,
output = observation
);
}
executionResult = {
tool: parsedOutput,
observation: value
};

span.addOutput(observation);
span.close();
}
}
} else {
log:printDebug("Failed to parse LLM response as valid tool or chat",
Expand Down Expand Up @@ -303,9 +358,9 @@ isolated function run(BaseAgent agent, string instruction, string query, int max
executionId = executionId,
sessionId = sessionId,
maxIterations = maxIter,
tools = agent.toolStore.tools.toString(),
tools = agent.toolManager.tools.toString(),
isStateless = agent.stateless
);
);

(ExecutionResult|ExecutionError|Error)[] steps = [];

Expand Down Expand Up @@ -466,7 +521,7 @@ isolated function getOutputOfIteration(ExecutionResult|LlmChatResponse|Execution
}

isolated function buildCurrentIterationHistory(ExecutionProgress progress,
ChatMessage[] conversationHistoryUpToCurrentUserQuery) returns ChatMessage[] {
ChatMessage[] conversationHistoryUpToCurrentUserQuery) returns ChatMessage[] {
ChatMessage[] messages = createFunctionCallMessages(progress);
messages.unshift(...conversationHistoryUpToCurrentUserQuery);
return messages;
Expand All @@ -493,7 +548,7 @@ isolated function getObservationString(anydata|error observation) returns string
#
# + agent - Agent instance
# + return - Array of tools registered with the agent
public isolated function getTools(Agent agent) returns Tool[] => agent.functionCallAgent.toolStore.tools.toArray();
public isolated function getTools(Agent agent) returns Tool[] => agent.functionCallAgent.toolManager.tools.toArray();

isolated function updateMemory(Memory memory, string sessionId, ChatMessage[] messages) {
error? updationStation = memory.update(sessionId, messages);
Expand Down
Loading