Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ballerina/Ballerina.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ icon="icon.png"
name = "ai.deepseek"
org = "ballerinax"
repository = "https://github.com/ballerina-platform/module-ballerinax-ai.deepseek"
version = "1.0.4"
version = "1.1.0"

[platform.java21]
graalvmCompatible = true

[[platform.java21.dependency]]
groupId = "io.ballerina.lib"
artifactId = "ai.deepseek-native"
version = "1.0.4"
path = "../native/build/libs/ai.deepseek-native-1.0.4.jar"
version = "1.1.0"
path = "../native/build/libs/ai.deepseek-native-1.1.0-SNAPSHOT.jar"
2 changes: 1 addition & 1 deletion ballerina/CompilerPlugin.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ id = "ai-deepseek-compiler-plugin"
class = "io.ballerina.lib.ai.deepseek.AiDeepseekCompilerPlugin"

[[dependency]]
path = "../compiler-plugin/build/libs/ai.deepseek-compiler-plugin-1.0.4.jar"
path = "../compiler-plugin/build/libs/ai.deepseek-compiler-plugin-1.1.0-SNAPSHOT.jar"

[[dependency]]
path = "../compiler-plugin/build/libs/ballerina-to-openapi-2.3.0.jar"
22 changes: 12 additions & 10 deletions ballerina/Dependencies.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ distribution-version = "2201.12.0"
[[package]]
org = "ballerina"
name = "ai"
version = "1.5.4"
version = "1.7.0"
dependencies = [
{org = "ballerina", name = "constraint"},
{org = "ballerina", name = "data.jsondata"},
Expand All @@ -26,14 +26,16 @@ dependencies = [
{org = "ballerina", name = "math.vector"},
{org = "ballerina", name = "mcp"},
{org = "ballerina", name = "mime"},
{org = "ballerina", name = "observe"},
{org = "ballerina", name = "time"},
{org = "ballerina", name = "url"},
{org = "ballerina", name = "uuid"},
{org = "ballerina", name = "yaml"}
]
modules = [
{org = "ballerina", packageName = "ai", moduleName = "ai"},
{org = "ballerina", packageName = "ai", moduleName = "ai.intelligence"}
{org = "ballerina", packageName = "ai", moduleName = "ai.intelligence"},
{org = "ballerina", packageName = "ai", moduleName = "ai.observe"}
]

[[package]]
Expand Down Expand Up @@ -70,7 +72,7 @@ dependencies = [
[[package]]
org = "ballerina"
name = "crypto"
version = "2.9.1"
version = "2.9.2"
dependencies = [
{org = "ballerina", name = "jballerina.java"},
{org = "ballerina", name = "time"}
Expand All @@ -79,7 +81,7 @@ dependencies = [
[[package]]
org = "ballerina"
name = "data.jsondata"
version = "1.1.2"
version = "1.1.3"
dependencies = [
{org = "ballerina", name = "jballerina.java"},
{org = "ballerina", name = "lang.object"}
Expand All @@ -88,7 +90,7 @@ dependencies = [
[[package]]
org = "ballerina"
name = "data.xmldata"
version = "1.5.0"
version = "1.5.2"
dependencies = [
{org = "ballerina", name = "jballerina.java"},
{org = "ballerina", name = "lang.object"}
Expand All @@ -108,7 +110,7 @@ dependencies = [
[[package]]
org = "ballerina"
name = "http"
version = "2.14.6"
version = "2.14.7"
dependencies = [
{org = "ballerina", name = "auth"},
{org = "ballerina", name = "cache"},
Expand Down Expand Up @@ -256,7 +258,7 @@ dependencies = [
[[package]]
org = "ballerina"
name = "log"
version = "2.13.0"
version = "2.14.0"
dependencies = [
{org = "ballerina", name = "io"},
{org = "ballerina", name = "jballerina.java"},
Expand All @@ -272,7 +274,7 @@ version = "1.2.0"
[[package]]
org = "ballerina"
name = "mcp"
version = "1.0.0"
version = "1.0.2"
dependencies = [
{org = "ballerina", name = "http"},
{org = "ballerina", name = "jballerina.java"},
Expand Down Expand Up @@ -306,7 +308,7 @@ dependencies = [
[[package]]
org = "ballerina"
name = "observe"
version = "1.5.0"
version = "1.6.0"
dependencies = [
{org = "ballerina", name = "jballerina.java"}
]
Expand Down Expand Up @@ -389,7 +391,7 @@ dependencies = [
[[package]]
org = "ballerinax"
name = "ai.deepseek"
version = "1.0.4"
version = "1.1.0"
dependencies = [
{org = "ballerina", name = "ai"},
{org = "ballerina", name = "http"},
Expand Down
64 changes: 59 additions & 5 deletions ballerina/model-provider.bal
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// under the License.

import ballerina/ai;
import ballerina/ai.observe;
import ballerina/http;
import ballerina/jballerina.java;
import ballerina/time;
Expand Down Expand Up @@ -86,7 +87,21 @@ public isolated client class ModelProvider {
# + return - Function to be called, chat response or an error in-case of failures
isolated remote function chat(ai:ChatMessage[]|ai:ChatUserMessage messages, ai:ChatCompletionFunctions[] tools, string? stop = ())
returns ai:ChatAssistantMessage|ai:Error {
DeepSeekChatRequestMessages[] deepseekPayloadMessages = check self.prepareDeepseekRequestMessages(messages);
observe:ChatSpan span = observe:createChatSpan(self.modelType);
span.addProvider("deepseek");
if stop is string {
span.addStopSequence(stop);
}
span.addTemperature(self.temperature);
json|ai:Error inputMessage = convertMessageToJson(messages);
if inputMessage is json {
span.addInputMessages(inputMessage);
}
DeepSeekChatRequestMessages[]|ai:Error deepseekPayloadMessages = self.prepareDeepseekRequestMessages(messages);
if deepseekPayloadMessages is ai:Error {
span.close(deepseekPayloadMessages);
return deepseekPayloadMessages;
}

DeepSeekChatCompletionRequest request = {
temperature: self.temperature,
Expand All @@ -97,6 +112,7 @@ public isolated client class ModelProvider {
};

if tools.length() > 0 {
span.addTools(tools);
DeepseekFunction[] deepseekFunctions = [];
foreach ai:ChatCompletionFunctions toolFunction in tools {
map<json>? parameters = toolFunction.parameters;
Expand All @@ -118,18 +134,45 @@ public isolated client class ModelProvider {

DeepSeekChatCompletionResponse|error response = self.llmClient->/chat/completions.post(request);
if response is error {
return error ai:LlmConnectionError("Error while connecting to the model", response);
ai:Error err = error ai:LlmConnectionError("Error while connecting to the model", response);
span.close(err);
return err;
}

span.addResponseId(response.id);
int? inputTokens = response.usage?.prompt_tokens;
if inputTokens is int {
span.addInputTokenCount(inputTokens);
}
int? outputTokens = response.usage?.completion_tokens;
if outputTokens is int {
span.addOutputTokenCount(outputTokens);
}

DeepseekChatResponseChoice[]? choices = response.choices;
string? finishReason = choices !is () && choices.length() > 0 ? choices[0].finish_reason : ();
if finishReason is string {
span.addFinishReason(finishReason);
}
return self.getAssistantMessages(response);
ai:ChatAssistantMessage|ai:Error result = self.getAssistantMessages(response);
if result is ai:Error {
span.close(result);
return result;
}

span.addOutputMessages(result);
span.addOutputType(observe:TEXT);
span.close();
return result;
}

# Sends a chat request to the model and generates a value that belongs to the type
# corresponding to the type descriptor argument.
#
#
# + prompt - The prompt to use in the chat messages
# + td - Type descriptor specifying the expected return type format
# + return - Generates a value that belongs to the type, or an error if generation fails
isolated remote function generate(ai:Prompt prompt, @display {label: "Expected type"} typedesc<anydata> td = <>)
isolated remote function generate(ai:Prompt prompt, @display {label: "Expected type"} typedesc<anydata> td = <>)
returns td|ai:Error = @java:Method {
'class: "io.ballerina.lib.ai.deepseek.Generator"
} external;
Expand Down Expand Up @@ -307,3 +350,14 @@ isolated function getChatMessageStringContent(ai:Prompt|string prompt) returns s
}
return promptStr.trim();
}

isolated function convertMessageToJson(ai:ChatMessage[]|ai:ChatMessage messages) returns json|ai:Error {
if messages is ai:ChatMessage[] {
return messages.'map(msg => msg is ai:ChatUserMessage|ai:ChatSystemMessage ? check convertMessageToJson(msg) : msg);
}
if messages is ai:ChatUserMessage|ai:ChatSystemMessage {

}
return messages !is ai:ChatUserMessage|ai:ChatSystemMessage ? messages :
{role: messages.role, content: check getChatMessageStringContent(messages.content), name: messages.name};
}
91 changes: 60 additions & 31 deletions ballerina/provider_utils.bal
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// under the License.

import ballerina/ai;
import ballerina/ai.observe;
import ballerina/http;

type ResponseSchema record {|
Expand Down Expand Up @@ -83,17 +84,17 @@ isolated function getGetResultsToolChoice() returns DeepSeekToolChoice => {
}
};

isolated function getGetResultsTool(map<json> parameters) returns DeepseekTool[]|error =>
isolated function getGetResultsTool(map<json> parameters) returns DeepseekTool[] =>
[
{
'type: FUNCTION,
'function: {
name: GET_RESULTS_TOOL,
parameters: parameters,
description: "Tool to call with the response from a large language model (LLM) for a user prompt."
}
{
'type: FUNCTION,
'function: {
name: GET_RESULTS_TOOL,
parameters: parameters,
description: "Tool to call with the response from a large language model (LLM) for a user prompt."
}
];
}
];

isolated function generateChatCreationContent(ai:Prompt prompt) returns string|ai:Error {
string[] & readonly strings = prompt.strings;
Expand All @@ -110,8 +111,8 @@ isolated function generateChatCreationContent(ai:Prompt prompt) returns string|a

if insertion is ai:TextDocument[] {
foreach ai:TextDocument doc in insertion {
promptStr += doc.content + " ";
promptStr += doc.content + " ";

}
promptStr += str;
continue;
Expand All @@ -136,18 +137,21 @@ isolated function handleParseResponseError(error chatResponseError) returns erro

isolated function generateLlmResponse(http:Client llmClient, int maxTokens, DEEPSEEK_MODEL_NAMES modelType,
decimal temperature, ai:Prompt prompt, typedesc<json> expectedResponseTypedesc) returns anydata|ai:Error {
string content = check generateChatCreationContent(prompt);
ResponseSchema ResponseSchema = check getExpectedResponseSchema(expectedResponseTypedesc);
DeepseekTool[]|error tools = getGetResultsTool(ResponseSchema.schema);
if tools is error {
return error("Error in generated schema: " + tools.message());
observe:GenerateContentSpan span = observe:createGenerateContentSpan(modelType);
span.addProvider("deepseek");

string content;
ResponseSchema responseSchema;
do {
content = check generateChatCreationContent(prompt);
responseSchema = check getExpectedResponseSchema(expectedResponseTypedesc);
} on fail ai:Error err {
span.close(err);
return err;
}


DeepseekChatUserMessage[] messages = [{
role: ai:USER,
content
}];
DeepseekTool[] tools = getGetResultsTool(responseSchema.schema);
DeepseekChatUserMessage[] messages = [{role: ai:USER, content}];
DeepSeekChatCompletionRequest request = {
messages,
model: modelType,
Expand All @@ -156,42 +160,67 @@ isolated function generateLlmResponse(http:Client llmClient, int maxTokens, DEEP
tools,
toolChoice: getGetResultsToolChoice()
};
span.addInputMessages(messages);

DeepSeekChatCompletionResponse|error response = llmClient->/chat/completions.post(request);
if response is error {
return error ai:LlmConnectionError("Error while connecting to the model", response);
ai:Error err = error ai:LlmConnectionError("Error while connecting to the model", response);
span.close(err);
return err;
}

DeepseekChatResponseChoice[]? choices = response.choices;
span.addResponseId(response.id);
int? inputTokens = response.usage?.prompt_tokens;
if inputTokens is int {
span.addInputTokenCount(inputTokens);
}
int? outputTokens = response.usage?.completion_tokens;
if outputTokens is int {
span.addOutputTokenCount(outputTokens);
}

DeepseekChatResponseChoice[]? choices = response.choices;
if choices is () || choices.length() == 0 {
return error("No completion choices");
ai:Error err = error("No completion choices");
span.close(err);
return err;
}

DeepseekChatResponseMessage message = choices[0].message;
DeepseekChatResponseToolCall[]? toolCalls = message?.tool_calls;
if toolCalls is () || toolCalls.length() == 0 {
return error(NO_RELEVANT_RESPONSE_FROM_THE_LLM);
if toolCalls is () || toolCalls.length() == 0 {
ai:Error err = error(NO_RELEVANT_RESPONSE_FROM_THE_LLM);
span.close(err);
return err;
}

DeepseekChatResponseToolCall toolCall = toolCalls[0];
map<json>|error arguments = toolCall.'function.arguments.fromJsonStringWithType();
if arguments is error {
return error(NO_RELEVANT_RESPONSE_FROM_THE_LLM);
ai:Error err = error(NO_RELEVANT_RESPONSE_FROM_THE_LLM);
span.close(err);
return err;
}

anydata|error res = parseResponseAsType(arguments.toJsonString(), expectedResponseTypedesc,
ResponseSchema.isOriginallyJsonObject);
responseSchema.isOriginallyJsonObject);
if res is error {
return error ai:LlmInvalidGenerationError(string `Invalid value returned from the LLM Client, expected: '${
ai:Error err = error ai:LlmInvalidGenerationError(string `Invalid value returned from the LLM Client, expected: '${
expectedResponseTypedesc.toBalString()}', found '${res.toBalString()}'`);
span.close(err);
return err;
}

anydata|error result = res.ensureType(expectedResponseTypedesc);

if result is error {
return error ai:LlmInvalidGenerationError(string `Invalid value returned from the LLM Client, expected: '${
ai:Error err = error ai:LlmInvalidGenerationError(string `Invalid value returned from the LLM Client, expected: '${
expectedResponseTypedesc.toBalString()}', found '${(typeof response).toBalString()}'`);
span.close(err);
return err;
}

span.addOutputMessages(result.toJson());
span.addOutputType(observe:JSON);
span.close();
return result;
}
Loading
Loading