diff --git a/ballerina/Ballerina.toml b/ballerina/Ballerina.toml index 4a980a6..2e6d8cc 100644 --- a/ballerina/Ballerina.toml +++ b/ballerina/Ballerina.toml @@ -7,7 +7,7 @@ icon="icon.png" name = "ai.azure" org = "ballerinax" repository = "https://github.com/ballerina-platform/module-ballerinax-ai.azure" -version = "1.3.0" +version = "1.4.0" [platform.java21] graalvmCompatible = true @@ -15,5 +15,5 @@ graalvmCompatible = true [[platform.java21.dependency]] groupId = "io.ballerina.lib" artifactId = "ai.azure-native" -version = "1.3.0" -path = "../native/build/libs/ai.azure-native-1.3.0-SNAPSHOT.jar" +version = "1.4.0" +path = "../native/build/libs/ai.azure-native-1.4.0-SNAPSHOT.jar" diff --git a/ballerina/CompilerPlugin.toml b/ballerina/CompilerPlugin.toml index a9d5849..f3c4f2f 100644 --- a/ballerina/CompilerPlugin.toml +++ b/ballerina/CompilerPlugin.toml @@ -3,7 +3,7 @@ id = "ai-azure-compiler-plugin" class = "io.ballerina.lib.ai.azure.AiAzureCompilerPlugin" [[dependency]] -path = "../compiler-plugin/build/libs/ai.azure-compiler-plugin-1.3.0-SNAPSHOT.jar" +path = "../compiler-plugin/build/libs/ai.azure-compiler-plugin-1.4.0-SNAPSHOT.jar" [[dependency]] path = "../compiler-plugin/build/libs/ballerina-to-openapi-2.3.0.jar" diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index 54b2eb1..b519f0c 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -5,12 +5,12 @@ [ballerina] dependencies-toml-version = "2" -distribution-version = "2201.12.9" +distribution-version = "2201.12.0" [[package]] org = "ballerina" name = "ai" -version = "1.5.4" +version = "1.7.0" dependencies = [ {org = "ballerina", name = "constraint"}, {org = "ballerina", name = "data.jsondata"}, @@ -26,6 +26,7 @@ dependencies = [ {org = "ballerina", name = "math.vector"}, {org = "ballerina", name = "mcp"}, {org = "ballerina", name = "mime"}, + {org = "ballerina", name = "observe"}, {org = "ballerina", name = "time"}, {org = "ballerina", name = "url"}, {org = "ballerina", name = "uuid"}, @@ -33,7 +34,8 @@ dependencies = [ ] modules = [ {org = "ballerina", packageName = "ai", moduleName = "ai"}, - {org = "ballerina", packageName = "ai", moduleName = "ai.intelligence"} + {org = "ballerina", packageName = "ai", moduleName = "ai.intelligence"}, + {org = "ballerina", packageName = "ai", moduleName = "ai.observe"} ] [[package]] @@ -281,7 +283,7 @@ version = "1.2.0" [[package]] org = "ballerina" name = "mcp" -version = "1.0.1" +version = "1.0.2" dependencies = [ {org = "ballerina", name = "http"}, {org = "ballerina", name = "jballerina.java"}, @@ -315,7 +317,7 @@ dependencies = [ [[package]] org = "ballerina" name = "observe" -version = "1.5.1" +version = "1.6.0" dependencies = [ {org = "ballerina", name = "jballerina.java"} ] @@ -407,7 +409,7 @@ dependencies = [ [[package]] org = "ballerinax" name = "ai.azure" -version = "1.3.0" +version = "1.4.0" dependencies = [ {org = "ballerina", name = "ai"}, {org = "ballerina", name = "constraint"}, diff --git a/ballerina/embedding_provider.bal b/ballerina/embedding_provider.bal index c241ddb..ab1966f 100644 --- a/ballerina/embedding_provider.bal +++ b/ballerina/embedding_provider.bal @@ -15,6 +15,7 @@ // under the License. import ballerina/ai; +import ballerina/ai.observe; import ballerinax/azure.openai.embeddings; # EmbeddingProvider provides an interface for interacting with Azure OpenAI Embedding Models. @@ -76,19 +77,39 @@ public distinct isolated client class EmbeddingProvider { # + chunk - The `ai:Chunk` containing the content to embed # + return - The resulting `ai:Embedding` on success; otherwise, returns an `ai:Error` isolated remote function embed(ai:Chunk chunk) returns ai:Embedding|ai:Error { + observe:EmbeddingSpan span = observe:createEmbeddingSpan(self.deploymentId); + span.addProvider("azure.ai.openai"); + if chunk !is ai:TextDocument|ai:TextChunk { - return error ai:Error("Unsupported document type. only 'ai:TextDocument' or 'ai:TextChunk' is supported"); + ai:Error err = error ai:Error("Unsupported chunk type. only 'ai:TextDocument|ai:TextChunk' is supported"); + span.close(err); + return err; } + do { + span.addInputContent(chunk.content); embeddings:Inline_response_200 response = check self.embeddingsClient->/deployments/[self.deploymentId]/embeddings.post( apiVersion = self.apiVersion, payload = { input: chunk.content } ); - return check response.data[0].embedding.cloneWithType(); + + span.addResponseModel(response.model); + span.addInputTokenCount(response.usage.prompt_tokens); + if response.data.length() == 0 { + ai:Error err = error("No embeddings generated for the provided chunk"); + span.close(err); + return err; + } + + ai:Embedding embedding = check response.data[0].embedding.cloneWithType(); + span.close(); + return embedding; } on fail error e { - return error ai:Error("Unable to obtain embedding for the provided chunk", e); + ai:Error err = error ai:Error("Unable to obtain embedding for the provided chunk", e); + span.close(err); + return err; } } @@ -97,10 +118,18 @@ public distinct isolated client class EmbeddingProvider { # + chunks - The array of chunks to be converted into embeddings # + return - An array of embeddings on success, or an `ai:Error` isolated remote function batchEmbed(ai:Chunk[] chunks) returns ai:Embedding[]|ai:Error { + observe:EmbeddingSpan span = observe:createEmbeddingSpan(self.deploymentId); + span.addProvider("azure.ai.openai"); + if !chunks.every(chunk => chunk is ai:TextChunk|ai:TextDocument) { - return error("Unsupported chunk type. only 'ai:TextChunk[]|ai:TextDocument[]' is supported"); + ai:Error err = error("Unsupported chunk type. only 'ai:TextChunk[]|ai:TextDocument[]' is supported"); + span.close(err); + return err; } do { + string[] input = chunks.map(chunk => chunk.content.toString()); + span.addInputContent(input); + embeddings:InputItemsString[] inputItems = from ai:Chunk chunk in chunks select check chunk.content.cloneWithType(); embeddings:Inline_response_200 response = check self.embeddingsClient->/deployments/[self.deploymentId]/embeddings.post( @@ -109,11 +138,16 @@ public distinct isolated client class EmbeddingProvider { input: inputItems } ); - return - from embeddings:Inline_response_200_data data in response.data - select check data.embedding.cloneWithType(); + + span.addInputTokenCount(response.usage.prompt_tokens); + ai:Embedding[] embeddings = from embeddings:Inline_response_200_data data in response.data + select check data.embedding.cloneWithType(); + span.close(); + return embeddings; } on fail error e { - return error ai:Error("Unable to obtain embedding for the provided document", e); + ai:Error err = error("Unable to obtain embedding for the provided document", e); + span.close(err); + return err; } } } diff --git a/ballerina/model-provider.bal b/ballerina/model-provider.bal index 7b5cd0f..4eea2d4 100644 --- a/ballerina/model-provider.bal +++ b/ballerina/model-provider.bal @@ -15,6 +15,7 @@ // under the License. import ballerina/ai; +import ballerina/ai.observe; import ballerina/jballerina.java; import ballerinax/azure.openai.chat; @@ -89,6 +90,21 @@ public isolated client class OpenAiModelProvider { # + return - Function to be called, chat response or an error in-case of failures isolated remote function chat(ai:ChatMessage[]|ai:ChatUserMessage messages, ai:ChatCompletionFunctions[] tools, string? stop = ()) returns ai:ChatAssistantMessage|ai:Error { + observe:ChatSpan span = observe:createChatSpan(self.deploymentId); + span.addProvider("azure.ai.openai"); + span.addOutputType(observe:TEXT); + if stop is string { + span.addStopSequence(stop); + } + span.addTemperature(self.temperature); + json|ai:Error jsonMsg = check convertMessageToJson(messages); + if jsonMsg is ai:Error { + ai:Error err = error("Error while transforming input", jsonMsg); + span.close(err); + return err; + } + span.addInputMessages(jsonMsg); + chat:CreateChatCompletionRequest request = { stop, messages: check self.mapToChatCompletionRequestMessage(messages), @@ -97,11 +113,14 @@ public isolated client class OpenAiModelProvider { }; if tools.length() > 0 { request.functions = tools; + span.addTools(tools); } chat:CreateChatCompletionResponse|error response = self.llmClient->/deployments/[self.deploymentId]/chat/completions.post(self.apiVersion, request); if response is error { - return error ai:LlmConnectionError("Error while connecting to the model", response); + ai:Error err = error ai:LlmConnectionError("Error while connecting to the model", response); + span.close(err); + return err; } record {| @@ -113,24 +132,52 @@ public isolated client class OpenAiModelProvider { |}[]? choices = response.choices; if choices is () || choices.length() == 0 { - return error ai:LlmInvalidResponseError("Empty response from the model when using function call API"); + ai:Error err = error ai:LlmInvalidResponseError("Empty response from the model when using function call API"); + span.close(err); + return err; + } + + string|int? responseId = response.id; + if responseId is string { + span.addResponseId(responseId); + } + int? inputTokens = response.usage?.prompt_tokens; + if inputTokens is int { + span.addInputTokenCount(inputTokens); + } + int? outputTokens = response.usage?.completion_tokens; + if outputTokens is int { + span.addOutputTokenCount(outputTokens); + } + string? finishReason = choices[0].finish_reason; + if finishReason is string { + span.addFinishReason(finishReason); } + chat:ChatCompletionResponseMessage? message = choices[0].message; ai:ChatAssistantMessage chatAssistantMessage = {role: ai:ASSISTANT, content: message?.content}; chat:ChatCompletionFunctionCall? functionCall = message?.function_call; - if functionCall is chat:ChatCompletionFunctionCall { - chatAssistantMessage.toolCalls = [check self.mapToFunctionCall(functionCall)]; + if functionCall is () { + span.addOutputMessages(chatAssistantMessage); + span.close(); + return chatAssistantMessage; + } + ai:FunctionCall|ai:Error toolCall = check self.mapToFunctionCall(functionCall); + if toolCall is ai:Error { + span.close(toolCall); + return toolCall; } + chatAssistantMessage.toolCalls = [toolCall]; return chatAssistantMessage; } # Sends a chat request to the model and generates a value that belongs to the type # corresponding to the type descriptor argument. - # + # # + prompt - The prompt to use in the chat messages # + td - Type descriptor specifying the expected return type format # + return - Generates a value that belongs to the type, or an error if generation fails - isolated remote function generate(ai:Prompt prompt, @display {label: "Expected type"} typedesc td = <>) + isolated remote function generate(ai:Prompt prompt, @display {label: "Expected type"} typedesc td = <>) returns td|ai:Error = @java:Method { 'class: "io.ballerina.lib.ai.azure.Generator" } external; @@ -158,7 +205,7 @@ public isolated client class OpenAiModelProvider { assistantMessage["content"] = message?.content; } chatCompletionRequestMessages.push(assistantMessage); - } else if message is ai:ChatFunctionMessage { + } else { chatCompletionRequestMessages.push(message); } } @@ -233,3 +280,14 @@ isolated function getChatMessageStringContent(ai:Prompt|string prompt) returns s } return promptStr.trim(); } + +isolated function convertMessageToJson(ai:ChatMessage[]|ai:ChatMessage messages) returns json|ai:Error { + if messages is ai:ChatMessage[] { + return messages.'map(msg => msg is ai:ChatUserMessage|ai:ChatSystemMessage ? check convertMessageToJson(msg) : msg); + } + if messages is ai:ChatUserMessage|ai:ChatSystemMessage { + + } + return messages !is ai:ChatUserMessage|ai:ChatSystemMessage ? messages : + {role: messages.role, content: check getChatMessageStringContent(messages.content), name: messages.name}; +} diff --git a/ballerina/provider_utils.bal b/ballerina/provider_utils.bal index 811fd90..4187542 100644 --- a/ballerina/provider_utils.bal +++ b/ballerina/provider_utils.bal @@ -15,6 +15,7 @@ // under the License. import ballerina/ai; +import ballerina/ai.observe; import ballerina/constraint; import ballerina/lang.array; import ballerinax/azure.openai.chat; @@ -105,17 +106,22 @@ isolated function getGetResultsToolChoice() returns chat:ChatCompletionNamedTool } }; -isolated function getGetResultsTool(map parameters) returns chat:ChatCompletionTool[]|error => - [ - { - 'type: FUNCTION, - 'function: { - name: GET_RESULTS_TOOL, - parameters: check parameters.cloneWithType(), - description: "Tool to call with the response from a large language model (LLM) for a user prompt." - } +isolated function getGetResultsTool(map parameters) returns chat:ChatCompletionTool[]|ai:Error { + chat:ChatCompletionFunctionParameters|error toolParam = parameters.ensureType(); + if toolParam is error { + return error("Error in generated schema: " + toolParam.message()); } -]; + return [ + { + 'type: FUNCTION, + 'function: { + name: GET_RESULTS_TOOL, + parameters: toolParam, + description: "Tool to call with the response from a large language model (LLM) for a user prompt." + } + } + ]; +} isolated function generateChatCreationContent(ai:Prompt prompt) returns DocumentContentPart[]|ai:Error { string[] & readonly strings = prompt.strings; @@ -234,11 +240,20 @@ isolated function handleParseResponseError(error chatResponseError) returns erro isolated function generateLlmResponse(chat:Client llmClient, string deploymentId, string apiVersion, decimal temperature, int maxTokens, ai:Prompt prompt, typedesc expectedResponseTypedesc) returns anydata|ai:Error { - DocumentContentPart[] content = check generateChatCreationContent(prompt); - ResponseSchema ResponseSchema = check getExpectedResponseSchema(expectedResponseTypedesc); - chat:ChatCompletionTool[]|error tools = getGetResultsTool(ResponseSchema.schema); - if tools is error { - return error("Error in generated schema: " + tools.message()); + observe:GenerateContentSpan span = observe:createGenerateContentSpan(deploymentId); + span.addTemperature(temperature); + span.addProvider("azure.ai.openai"); + + DocumentContentPart[] content; + ResponseSchema responseSchema; + chat:ChatCompletionTool[] tools; + do { + content = check generateChatCreationContent(prompt); + responseSchema = check getExpectedResponseSchema(expectedResponseTypedesc); + tools = check getGetResultsTool(responseSchema.schema); + } on fail ai:Error err { + span.close(err); + return err; } chat:CreateChatCompletionRequest request = { @@ -253,13 +268,44 @@ isolated function generateLlmResponse(chat:Client llmClient, string deploymentId max_tokens: maxTokens, tool_choice: getGetResultsToolChoice() }; + span.addInputMessages(request.messages.toJson()); chat:CreateChatCompletionResponse|error response = llmClient->/deployments/[deploymentId]/chat/completions.post(apiVersion, request); if response is error { - return error("LLM call failed: " + response.message(), cause = response.cause(), detail = response.detail()); + ai:Error err = error("LLM call failed: " + response.message(), cause = response.cause(), detail = response.detail()); + span.close(err); + return err; } + string? responseId = response.id; + if responseId is string { + span.addResponseId(responseId); + } + int? inputTokens = response.usage?.prompt_tokens; + if inputTokens is int { + span.addInputTokenCount(inputTokens); + } + int? outputTokens = response.usage?.completion_tokens; + if outputTokens is int { + span.addOutputTokenCount(outputTokens); + } + + anydata|ai:Error result = ensureAnydataResult(response, expectedResponseTypedesc, + responseSchema.isOriginallyJsonObject, span); + if result is ai:Error { + span.close(result); + return result; + } + span.addOutputMessages(result.toJson()); + span.addOutputType(observe:JSON); + span.close(); + return result; +} + +isolated function ensureAnydataResult(chat:CreateChatCompletionResponse response, + typedesc expectedResponseTypedesc, boolean isOriginallyJsonObject, + observe:GenerateContentSpan span) returns anydata|ai:Error { record { chat:ChatCompletionResponseMessage message?; chat:ContentFilterChoiceResults content_filter_results?; @@ -276,6 +322,10 @@ isolated function generateLlmResponse(chat:Client llmClient, string deploymentId if toolCalls is () || toolCalls.length() == 0 { return error(NO_RELEVANT_RESPONSE_FROM_THE_LLM); } + string? finishReason = choices[0].finish_reason; + if finishReason is string { + span.addFinishReason(finishReason); + } chat:ChatCompletionMessageToolCall tool = toolCalls[0]; map|error arguments = tool.'function.arguments.fromJsonStringWithType(); @@ -283,8 +333,7 @@ isolated function generateLlmResponse(chat:Client llmClient, string deploymentId return error(NO_RELEVANT_RESPONSE_FROM_THE_LLM); } - anydata|error res = parseResponseAsType(arguments.toJsonString(), expectedResponseTypedesc, - ResponseSchema.isOriginallyJsonObject); + anydata|error res = parseResponseAsType(arguments.toJsonString(), expectedResponseTypedesc, isOriginallyJsonObject); if res is error { return error ai:LlmInvalidGenerationError(string `Invalid value returned from the LLM Client, expected: '${ expectedResponseTypedesc.toBalString()}', found '${res.toBalString()}'`); diff --git a/gradle.properties b/gradle.properties index 6dd0237..87b86b4 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,6 +1,6 @@ org.gradle.caching=true group=io.ballerina.lib -version=1.3.0-SNAPSHOT +version=1.4.0-SNAPSHOT ballerinaLangVersion=2201.12.0 shadowJarPluginVersion=8.1.1