Skip to content
Merged
6 changes: 3 additions & 3 deletions ballerina/Ballerina.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ icon="icon.png"
name = "ai.azure"
org = "ballerinax"
repository = "https://github.com/ballerina-platform/module-ballerinax-ai.azure"
version = "1.3.0"
version = "1.4.0"

[platform.java21]
graalvmCompatible = true

[[platform.java21.dependency]]
groupId = "io.ballerina.lib"
artifactId = "ai.azure-native"
version = "1.3.0"
path = "../native/build/libs/ai.azure-native-1.3.0-SNAPSHOT.jar"
version = "1.4.0"
path = "../native/build/libs/ai.azure-native-1.4.0-SNAPSHOT.jar"
2 changes: 1 addition & 1 deletion ballerina/CompilerPlugin.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ id = "ai-azure-compiler-plugin"
class = "io.ballerina.lib.ai.azure.AiAzureCompilerPlugin"

[[dependency]]
path = "../compiler-plugin/build/libs/ai.azure-compiler-plugin-1.3.0-SNAPSHOT.jar"
path = "../compiler-plugin/build/libs/ai.azure-compiler-plugin-1.4.0-SNAPSHOT.jar"

[[dependency]]
path = "../compiler-plugin/build/libs/ballerina-to-openapi-2.3.0.jar"
14 changes: 8 additions & 6 deletions ballerina/Dependencies.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

[ballerina]
dependencies-toml-version = "2"
distribution-version = "2201.12.9"
distribution-version = "2201.12.0"

[[package]]
org = "ballerina"
name = "ai"
version = "1.5.4"
version = "1.7.0"
dependencies = [
{org = "ballerina", name = "constraint"},
{org = "ballerina", name = "data.jsondata"},
Expand All @@ -26,14 +26,16 @@ dependencies = [
{org = "ballerina", name = "math.vector"},
{org = "ballerina", name = "mcp"},
{org = "ballerina", name = "mime"},
{org = "ballerina", name = "observe"},
{org = "ballerina", name = "time"},
{org = "ballerina", name = "url"},
{org = "ballerina", name = "uuid"},
{org = "ballerina", name = "yaml"}
]
modules = [
{org = "ballerina", packageName = "ai", moduleName = "ai"},
{org = "ballerina", packageName = "ai", moduleName = "ai.intelligence"}
{org = "ballerina", packageName = "ai", moduleName = "ai.intelligence"},
{org = "ballerina", packageName = "ai", moduleName = "ai.observe"}
]

[[package]]
Expand Down Expand Up @@ -281,7 +283,7 @@ version = "1.2.0"
[[package]]
org = "ballerina"
name = "mcp"
version = "1.0.1"
version = "1.0.2"
dependencies = [
{org = "ballerina", name = "http"},
{org = "ballerina", name = "jballerina.java"},
Expand Down Expand Up @@ -315,7 +317,7 @@ dependencies = [
[[package]]
org = "ballerina"
name = "observe"
version = "1.5.1"
version = "1.6.0"
dependencies = [
{org = "ballerina", name = "jballerina.java"}
]
Expand Down Expand Up @@ -407,7 +409,7 @@ dependencies = [
[[package]]
org = "ballerinax"
name = "ai.azure"
version = "1.3.0"
version = "1.4.0"
dependencies = [
{org = "ballerina", name = "ai"},
{org = "ballerina", name = "constraint"},
Expand Down
50 changes: 42 additions & 8 deletions ballerina/embedding_provider.bal
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// under the License.

import ballerina/ai;
import ballerina/ai.observe;
import ballerinax/azure.openai.embeddings;

# EmbeddingProvider provides an interface for interacting with Azure OpenAI Embedding Models.
Expand Down Expand Up @@ -76,19 +77,39 @@ public distinct isolated client class EmbeddingProvider {
# + chunk - The `ai:Chunk` containing the content to embed
# + return - The resulting `ai:Embedding` on success; otherwise, returns an `ai:Error`
isolated remote function embed(ai:Chunk chunk) returns ai:Embedding|ai:Error {
observe:EmbeddingSpan span = observe:createEmbeddingSpan(self.deploymentId);
span.addProvider("azure.ai.openai");

if chunk !is ai:TextDocument|ai:TextChunk {
return error ai:Error("Unsupported document type. only 'ai:TextDocument' or 'ai:TextChunk' is supported");
ai:Error err = error ai:Error("Unsupported chunk type. only 'ai:TextDocument|ai:TextChunk' is supported");
span.close(err);
return err;
}

do {
span.addInputContent(chunk.content);
embeddings:Inline_response_200 response = check self.embeddingsClient->/deployments/[self.deploymentId]/embeddings.post(
apiVersion = self.apiVersion,
payload = {
input: chunk.content
}
);
return check response.data[0].embedding.cloneWithType();

span.addResponseModel(response.model);
span.addInputTokenCount(response.usage.prompt_tokens);
if response.data.length() == 0 {
ai:Error err = error("No embeddings generated for the provided chunk");
span.close(err);
return err;
}

ai:Embedding embedding = check response.data[0].embedding.cloneWithType();
span.close();
return embedding;
} on fail error e {
return error ai:Error("Unable to obtain embedding for the provided chunk", e);
ai:Error err = error ai:Error("Unable to obtain embedding for the provided chunk", e);
span.close(err);
return err;
}
}

Expand All @@ -97,10 +118,18 @@ public distinct isolated client class EmbeddingProvider {
# + chunks - The array of chunks to be converted into embeddings
# + return - An array of embeddings on success, or an `ai:Error`
isolated remote function batchEmbed(ai:Chunk[] chunks) returns ai:Embedding[]|ai:Error {
observe:EmbeddingSpan span = observe:createEmbeddingSpan(self.deploymentId);
span.addProvider("azure.ai.openai");

if !chunks.every(chunk => chunk is ai:TextChunk|ai:TextDocument) {
return error("Unsupported chunk type. only 'ai:TextChunk[]|ai:TextDocument[]' is supported");
ai:Error err = error("Unsupported chunk type. only 'ai:TextChunk[]|ai:TextDocument[]' is supported");
span.close(err);
return err;
}
do {
string[] input = chunks.map(chunk => chunk.content.toString());
span.addInputContent(input);

embeddings:InputItemsString[] inputItems = from ai:Chunk chunk in chunks
select check chunk.content.cloneWithType();
embeddings:Inline_response_200 response = check self.embeddingsClient->/deployments/[self.deploymentId]/embeddings.post(
Expand All @@ -109,11 +138,16 @@ public distinct isolated client class EmbeddingProvider {
input: inputItems
}
);
return
from embeddings:Inline_response_200_data data in response.data
select check data.embedding.cloneWithType();

span.addInputTokenCount(response.usage.prompt_tokens);
ai:Embedding[] embeddings = from embeddings:Inline_response_200_data data in response.data
select check data.embedding.cloneWithType();
span.close();
return embeddings;
} on fail error e {
return error ai:Error("Unable to obtain embedding for the provided document", e);
ai:Error err = error("Unable to obtain embedding for the provided document", e);
span.close(err);
return err;
}
}
}
72 changes: 65 additions & 7 deletions ballerina/model-provider.bal
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// under the License.

import ballerina/ai;
import ballerina/ai.observe;
import ballerina/jballerina.java;
import ballerinax/azure.openai.chat;

Expand Down Expand Up @@ -89,6 +90,21 @@ public isolated client class OpenAiModelProvider {
# + return - Function to be called, chat response or an error in-case of failures
isolated remote function chat(ai:ChatMessage[]|ai:ChatUserMessage messages, ai:ChatCompletionFunctions[] tools, string? stop = ())
returns ai:ChatAssistantMessage|ai:Error {
observe:ChatSpan span = observe:createChatSpan(self.deploymentId);
span.addProvider("azure.ai.openai");
span.addOutputType(observe:TEXT);
if stop is string {
span.addStopSequence(stop);
}
span.addTemperature(self.temperature);
json|ai:Error jsonMsg = check convertMessageToJson(messages);
if jsonMsg is ai:Error {
ai:Error err = error("Error while transforming input", jsonMsg);
span.close(err);
return err;
}
span.addInputMessages(jsonMsg);

chat:CreateChatCompletionRequest request = {
stop,
messages: check self.mapToChatCompletionRequestMessage(messages),
Expand All @@ -97,11 +113,14 @@ public isolated client class OpenAiModelProvider {
};
if tools.length() > 0 {
request.functions = tools;
span.addTools(tools);
}
chat:CreateChatCompletionResponse|error response =
self.llmClient->/deployments/[self.deploymentId]/chat/completions.post(self.apiVersion, request);
if response is error {
return error ai:LlmConnectionError("Error while connecting to the model", response);
ai:Error err = error ai:LlmConnectionError("Error while connecting to the model", response);
span.close(err);
return err;
}

record {|
Expand All @@ -113,24 +132,52 @@ public isolated client class OpenAiModelProvider {
|}[]? choices = response.choices;

if choices is () || choices.length() == 0 {
return error ai:LlmInvalidResponseError("Empty response from the model when using function call API");
ai:Error err = error ai:LlmInvalidResponseError("Empty response from the model when using function call API");
span.close(err);
return err;
}

string|int? responseId = response.id;
if responseId is string {
span.addResponseId(responseId);
}
int? inputTokens = response.usage?.prompt_tokens;
if inputTokens is int {
span.addInputTokenCount(inputTokens);
}
int? outputTokens = response.usage?.completion_tokens;
if outputTokens is int {
span.addOutputTokenCount(outputTokens);
}
string? finishReason = choices[0].finish_reason;
if finishReason is string {
span.addFinishReason(finishReason);
}

chat:ChatCompletionResponseMessage? message = choices[0].message;
ai:ChatAssistantMessage chatAssistantMessage = {role: ai:ASSISTANT, content: message?.content};
chat:ChatCompletionFunctionCall? functionCall = message?.function_call;
if functionCall is chat:ChatCompletionFunctionCall {
chatAssistantMessage.toolCalls = [check self.mapToFunctionCall(functionCall)];
if functionCall is () {
span.addOutputMessages(chatAssistantMessage);
span.close();
return chatAssistantMessage;
}
ai:FunctionCall|ai:Error toolCall = check self.mapToFunctionCall(functionCall);
if toolCall is ai:Error {
span.close(toolCall);
return toolCall;
}
chatAssistantMessage.toolCalls = [toolCall];
return chatAssistantMessage;
}

# Sends a chat request to the model and generates a value that belongs to the type
# corresponding to the type descriptor argument.
#
#
# + prompt - The prompt to use in the chat messages
# + td - Type descriptor specifying the expected return type format
# + return - Generates a value that belongs to the type, or an error if generation fails
isolated remote function generate(ai:Prompt prompt, @display {label: "Expected type"} typedesc<anydata> td = <>)
isolated remote function generate(ai:Prompt prompt, @display {label: "Expected type"} typedesc<anydata> td = <>)
returns td|ai:Error = @java:Method {
'class: "io.ballerina.lib.ai.azure.Generator"
} external;
Expand Down Expand Up @@ -158,7 +205,7 @@ public isolated client class OpenAiModelProvider {
assistantMessage["content"] = message?.content;
}
chatCompletionRequestMessages.push(assistantMessage);
} else if message is ai:ChatFunctionMessage {
} else {
chatCompletionRequestMessages.push(message);
}
}
Expand Down Expand Up @@ -233,3 +280,14 @@ isolated function getChatMessageStringContent(ai:Prompt|string prompt) returns s
}
return promptStr.trim();
}

# Converts chat message(s) to a JSON representation for recording on observability spans.
#
# User and system messages are normalized to a `{role, content, name}` mapping, with
# prompt content flattened to a string via `getChatMessageStringContent`; all other
# message kinds are passed through unchanged. Arrays are converted element-wise.
#
# + messages - A single `ai:ChatMessage` or an array of them
# + return - The JSON representation on success; otherwise, an `ai:Error` from content conversion
isolated function convertMessageToJson(ai:ChatMessage[]|ai:ChatMessage messages) returns json|ai:Error {
    if messages is ai:ChatMessage[] {
        return messages.'map(msg => msg is ai:ChatUserMessage|ai:ChatSystemMessage ? check convertMessageToJson(msg) : msg);
    }
    if messages is ai:ChatUserMessage|ai:ChatSystemMessage {
        // Flatten templated prompt content to a plain string for span recording.
        return {role: messages.role, content: check getChatMessageStringContent(messages.content), name: messages.name};
    }
    // Assistant/function messages are already plain records; record them as-is.
    return messages;
}
Loading
Loading