diff --git a/ballerina/Ballerina.toml b/ballerina/Ballerina.toml index 5652317..ee6e8d0 100644 --- a/ballerina/Ballerina.toml +++ b/ballerina/Ballerina.toml @@ -7,7 +7,7 @@ icon="icon.png" name = "ai.deepseek" org = "ballerinax" repository = "https://github.com/ballerina-platform/module-ballerinax-ai.deepseek" -version = "1.0.4" +version = "1.1.0" [platform.java21] graalvmCompatible = true @@ -15,5 +15,5 @@ graalvmCompatible = true [[platform.java21.dependency]] groupId = "io.ballerina.lib" artifactId = "ai.deepseek-native" -version = "1.0.4" -path = "../native/build/libs/ai.deepseek-native-1.0.4.jar" +version = "1.1.0" +path = "../native/build/libs/ai.deepseek-native-1.1.0-SNAPSHOT.jar" diff --git a/ballerina/CompilerPlugin.toml b/ballerina/CompilerPlugin.toml index 7c8e2c4..936cad5 100644 --- a/ballerina/CompilerPlugin.toml +++ b/ballerina/CompilerPlugin.toml @@ -3,7 +3,7 @@ id = "ai-deepseek-compiler-plugin" class = "io.ballerina.lib.ai.deepseek.AiDeepseekCompilerPlugin" [[dependency]] -path = "../compiler-plugin/build/libs/ai.deepseek-compiler-plugin-1.0.4.jar" +path = "../compiler-plugin/build/libs/ai.deepseek-compiler-plugin-1.1.0-SNAPSHOT.jar" [[dependency]] path = "../compiler-plugin/build/libs/ballerina-to-openapi-2.3.0.jar" diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index e1efd7a..886b3bc 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -10,7 +10,7 @@ distribution-version = "2201.12.0" [[package]] org = "ballerina" name = "ai" -version = "1.5.4" +version = "1.7.0" dependencies = [ {org = "ballerina", name = "constraint"}, {org = "ballerina", name = "data.jsondata"}, @@ -26,6 +26,7 @@ dependencies = [ {org = "ballerina", name = "math.vector"}, {org = "ballerina", name = "mcp"}, {org = "ballerina", name = "mime"}, + {org = "ballerina", name = "observe"}, {org = "ballerina", name = "time"}, {org = "ballerina", name = "url"}, {org = "ballerina", name = "uuid"}, @@ -33,7 +34,8 @@ dependencies = [ ] modules = [ {org = "ballerina", packageName = "ai", moduleName = "ai"}, - {org = "ballerina", packageName = "ai", moduleName = "ai.intelligence"} + {org = "ballerina", packageName = "ai", moduleName = "ai.intelligence"}, + {org = "ballerina", packageName = "ai", moduleName = "ai.observe"} ] [[package]] @@ -70,7 +72,7 @@ dependencies = [ [[package]] org = "ballerina" name = "crypto" -version = "2.9.1" +version = "2.9.2" dependencies = [ {org = "ballerina", name = "jballerina.java"}, {org = "ballerina", name = "time"} @@ -79,7 +81,7 @@ dependencies = [ [[package]] org = "ballerina" name = "data.jsondata" -version = "1.1.2" +version = "1.1.3" dependencies = [ {org = "ballerina", name = "jballerina.java"}, {org = "ballerina", name = "lang.object"} @@ -88,7 +90,7 @@ dependencies = [ [[package]] org = "ballerina" name = "data.xmldata" -version = "1.5.0" +version = "1.5.2" dependencies = [ {org = "ballerina", name = "jballerina.java"}, {org = "ballerina", name = "lang.object"} @@ -108,7 +110,7 @@ dependencies = [ [[package]] org = "ballerina" name = "http" -version = "2.14.6" +version = "2.14.7" dependencies = [ {org = "ballerina", name = "auth"}, {org = "ballerina", name = "cache"}, @@ -256,7 +258,7 @@ dependencies = [ [[package]] org = "ballerina" name = "log" -version = "2.13.0" +version = "2.14.0" dependencies = [ {org = "ballerina", name = "io"}, {org = "ballerina", name = "jballerina.java"}, @@ -272,7 +274,7 @@ version = "1.2.0" [[package]] org = "ballerina" name = "mcp" -version = "1.0.0" +version = "1.0.2" dependencies = [ {org = "ballerina", name = "http"}, {org = "ballerina", name = "jballerina.java"}, @@ -306,7 +308,7 @@ dependencies = [ [[package]] org = "ballerina" name = "observe" -version = "1.5.0" +version = "1.6.0" dependencies = [ {org = "ballerina", name = "jballerina.java"} ] @@ -389,7 +391,7 @@ dependencies = [ [[package]] org = "ballerinax" name = "ai.deepseek" -version = "1.0.4" +version = "1.1.0" dependencies = [ {org = "ballerina", name = "ai"}, {org = "ballerina", name = "http"}, diff --git a/ballerina/model-provider.bal b/ballerina/model-provider.bal index bb3b365..7c0280a 100644 --- a/ballerina/model-provider.bal +++ b/ballerina/model-provider.bal @@ -15,6 +15,7 @@ // under the License. import ballerina/ai; +import ballerina/ai.observe; import ballerina/http; import ballerina/jballerina.java; import ballerina/time; @@ -86,7 +87,21 @@ public isolated client class ModelProvider { # + return - Function to be called, chat response or an error in-case of failures isolated remote function chat(ai:ChatMessage[]|ai:ChatUserMessage messages, ai:ChatCompletionFunctions[] tools, string? stop = ()) returns ai:ChatAssistantMessage|ai:Error { - DeepSeekChatRequestMessages[] deepseekPayloadMessages = check self.prepareDeepseekRequestMessages(messages); + observe:ChatSpan span = observe:createChatSpan(self.modelType); + span.addProvider("deepseek"); + if stop is string { + span.addStopSequence(stop); + } + span.addTemperature(self.temperature); + json|ai:Error inputMessage = convertMessageToJson(messages); + if inputMessage is json { + span.addInputMessages(inputMessage); + } + DeepSeekChatRequestMessages[]|ai:Error deepseekPayloadMessages = self.prepareDeepseekRequestMessages(messages); + if deepseekPayloadMessages is ai:Error { + span.close(deepseekPayloadMessages); + return deepseekPayloadMessages; + } DeepSeekChatCompletionRequest request = { temperature: self.temperature, @@ -97,6 +112,7 @@ public isolated client class ModelProvider { }; if tools.length() > 0 { + span.addTools(tools); DeepseekFunction[] deepseekFunctions = []; foreach ai:ChatCompletionFunctions toolFunction in tools { map? parameters = toolFunction.parameters; @@ -118,18 +134,45 @@ public isolated client class ModelProvider { DeepSeekChatCompletionResponse|error response = self.llmClient->/chat/completions.post(request); if response is error { - return error ai:LlmConnectionError("Error while connecting to the model", response); + ai:Error err = error ai:LlmConnectionError("Error while connecting to the model", response); + span.close(err); + return err; + } + + span.addResponseId(response.id); + int? inputTokens = response.usage?.prompt_tokens; + if inputTokens is int { + span.addInputTokenCount(inputTokens); + } + int? outputTokens = response.usage?.completion_tokens; + if outputTokens is int { + span.addOutputTokenCount(outputTokens); + } + + DeepseekChatResponseChoice[]? choices = response.choices; + string? finishReason = choices !is () && choices.length() > 0 ? choices[0].finish_reason : (); + if finishReason is string { + span.addFinishReason(finishReason); } - return self.getAssistantMessages(response); + ai:ChatAssistantMessage|ai:Error result = self.getAssistantMessages(response); + if result is ai:Error { + span.close(result); + return result; + } + + span.addOutputMessages(result); + span.addOutputType(observe:TEXT); + span.close(); + return result; } # Sends a chat request to the model and generates a value that belongs to the type # corresponding to the type descriptor argument. - # + # # + prompt - The prompt to use in the chat messages # + td - Type descriptor specifying the expected return type format # + return - Generates a value that belongs to the type, or an error if generation fails - isolated remote function generate(ai:Prompt prompt, @display {label: "Expected type"} typedesc td = <>) + isolated remote function generate(ai:Prompt prompt, @display {label: "Expected type"} typedesc td = <>) returns td|ai:Error = @java:Method { 'class: "io.ballerina.lib.ai.deepseek.Generator" } external; @@ -307,3 +350,14 @@ isolated function getChatMessageStringContent(ai:Prompt|string prompt) returns s } return promptStr.trim(); } + +isolated function convertMessageToJson(ai:ChatMessage[]|ai:ChatMessage messages) returns json|ai:Error { + if messages is ai:ChatMessage[] { + return messages.'map(msg => msg is ai:ChatUserMessage|ai:ChatSystemMessage ? check convertMessageToJson(msg) : msg); + } + if messages is ai:ChatUserMessage|ai:ChatSystemMessage { + + } + return messages !is ai:ChatUserMessage|ai:ChatSystemMessage ? messages : + {role: messages.role, content: check getChatMessageStringContent(messages.content), name: messages.name}; +} diff --git a/ballerina/provider_utils.bal b/ballerina/provider_utils.bal index 99edc8f..ff91e2f 100644 --- a/ballerina/provider_utils.bal +++ b/ballerina/provider_utils.bal @@ -15,6 +15,7 @@ // under the License. import ballerina/ai; +import ballerina/ai.observe; import ballerina/http; type ResponseSchema record {| @@ -83,17 +84,17 @@ isolated function getGetResultsToolChoice() returns DeepSeekToolChoice => { } }; -isolated function getGetResultsTool(map parameters) returns DeepseekTool[]|error => +isolated function getGetResultsTool(map parameters) returns DeepseekTool[] => [ - { - 'type: FUNCTION, - 'function: { - name: GET_RESULTS_TOOL, - parameters: parameters, - description: "Tool to call with the response from a large language model (LLM) for a user prompt." - } + { + 'type: FUNCTION, + 'function: { + name: GET_RESULTS_TOOL, + parameters: parameters, + description: "Tool to call with the response from a large language model (LLM) for a user prompt." } - ]; + } +]; isolated function generateChatCreationContent(ai:Prompt prompt) returns string|ai:Error { string[] & readonly strings = prompt.strings; @@ -110,8 +111,8 @@ isolated function generateChatCreationContent(ai:Prompt prompt) returns string|a if insertion is ai:TextDocument[] { foreach ai:TextDocument doc in insertion { - promptStr += doc.content + " "; - + promptStr += doc.content + " "; + } promptStr += str; continue; @@ -136,18 +137,21 @@ isolated function handleParseResponseError(error chatResponseError) returns erro isolated function generateLlmResponse(http:Client llmClient, int maxTokens, DEEPSEEK_MODEL_NAMES modelType, decimal temperature, ai:Prompt prompt, typedesc expectedResponseTypedesc) returns anydata|ai:Error { - string content = check generateChatCreationContent(prompt); - ResponseSchema ResponseSchema = check getExpectedResponseSchema(expectedResponseTypedesc); - DeepseekTool[]|error tools = getGetResultsTool(ResponseSchema.schema); - if tools is error { - return error("Error in generated schema: " + tools.message()); + observe:GenerateContentSpan span = observe:createGenerateContentSpan(modelType); + span.addProvider("deepseek"); + + string content; + ResponseSchema responseSchema; + do { + content = check generateChatCreationContent(prompt); + responseSchema = check getExpectedResponseSchema(expectedResponseTypedesc); + } on fail ai:Error err { + span.close(err); + return err; } - - DeepseekChatUserMessage[] messages = [{ - role: ai:USER, - content - }]; + DeepseekTool[] tools = getGetResultsTool(responseSchema.schema); + DeepseekChatUserMessage[] messages = [{role: ai:USER, content}]; DeepSeekChatCompletionRequest request = { messages, model: modelType, @@ -156,42 +160,67 @@ isolated function generateLlmResponse(http:Client llmClient, int maxTokens, DEEP tools, toolChoice: getGetResultsToolChoice() }; + span.addInputMessages(messages); DeepSeekChatCompletionResponse|error response = llmClient->/chat/completions.post(request); if response is error { - return error ai:LlmConnectionError("Error while connecting to the model", response); + ai:Error err = error ai:LlmConnectionError("Error while connecting to the model", response); + span.close(err); + return err; } - DeepseekChatResponseChoice[]? choices = response.choices; + span.addResponseId(response.id); + int? inputTokens = response.usage?.prompt_tokens; + if inputTokens is int { + span.addInputTokenCount(inputTokens); + } + int? outputTokens = response.usage?.completion_tokens; + if outputTokens is int { + span.addOutputTokenCount(outputTokens); + } + DeepseekChatResponseChoice[]? choices = response.choices; if choices is () || choices.length() == 0 { - return error("No completion choices"); + ai:Error err = error("No completion choices"); + span.close(err); + return err; } DeepseekChatResponseMessage message = choices[0].message; DeepseekChatResponseToolCall[]? toolCalls = message?.tool_calls; - if toolCalls is () || toolCalls.length() == 0 { - return error(NO_RELEVANT_RESPONSE_FROM_THE_LLM); + if toolCalls is () || toolCalls.length() == 0 { + ai:Error err = error(NO_RELEVANT_RESPONSE_FROM_THE_LLM); + span.close(err); + return err; } DeepseekChatResponseToolCall toolCall = toolCalls[0]; map|error arguments = toolCall.'function.arguments.fromJsonStringWithType(); if arguments is error { - return error(NO_RELEVANT_RESPONSE_FROM_THE_LLM); + ai:Error err = error(NO_RELEVANT_RESPONSE_FROM_THE_LLM); + span.close(err); + return err; } anydata|error res = parseResponseAsType(arguments.toJsonString(), expectedResponseTypedesc, - ResponseSchema.isOriginallyJsonObject); + responseSchema.isOriginallyJsonObject); if res is error { - return error ai:LlmInvalidGenerationError(string `Invalid value returned from the LLM Client, expected: '${ + ai:Error err = error ai:LlmInvalidGenerationError(string `Invalid value returned from the LLM Client, expected: '${ expectedResponseTypedesc.toBalString()}', found '${res.toBalString()}'`); + span.close(err); + return err; } anydata|error result = res.ensureType(expectedResponseTypedesc); - if result is error { - return error ai:LlmInvalidGenerationError(string `Invalid value returned from the LLM Client, expected: '${ + ai:Error err = error ai:LlmInvalidGenerationError(string `Invalid value returned from the LLM Client, expected: '${ expectedResponseTypedesc.toBalString()}', found '${(typeof response).toBalString()}'`); + span.close(err); + return err; } + + span.addOutputMessages(result.toJson()); + span.addOutputType(observe:JSON); + span.close(); return result; } diff --git a/ballerina/types.bal b/ballerina/types.bal index 2f932a8..548eac0 100644 --- a/ballerina/types.bal +++ b/ballerina/types.bal @@ -102,11 +102,19 @@ type DeepseekChatResponseMessage record { type DeepseekChatResponseChoice record { DeepseekChatResponseMessage message; + string finish_reason?; }; +// https://api-docs.deepseek.com/api/create-chat-completion#responses type DeepSeekChatCompletionResponse record { string id; DeepseekChatResponseChoice[] choices; + DeepSeekUsage usage?; +}; + +type DeepSeekUsage record { + int prompt_tokens; + int completion_tokens; }; type DeepseekChatSystemMessage record {| diff --git a/gradle.properties b/gradle.properties index 3a2f8c5..5bbc77c 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,6 +1,6 @@ org.gradle.caching=true group=io.ballerina.lib -version=1.0.5-SNAPSHOT +version=1.1.0-SNAPSHOT ballerinaLangVersion=2201.12.0 shadowJarPluginVersion=8.1.1