Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions ballerina/Ballerina.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ license = ["Apache-2.0"]
name = "ai.model.provider.ollama"
org = "ballerinax"
repository = "https://github.com/ballerina-platform/module-ballerinax-ai.model.provider.ollama"
version = "1.0.0"
version = "1.0.1"

[platform.java21]
graalvmCompatible = true
[[platform.java21.dependency]]
groupId = "io.ballerina.lib"
artifactId = "ai.model.provider.ollama-native"
version = "1.0.1"
path = "../native/build/libs/ai.model.provider.ollama-native-1.0.1-SNAPSHOT.jar"
9 changes: 9 additions & 0 deletions ballerina/CompilerPlugin.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[plugin]
id = "ai-compiler-plugin"
class = "io.ballerina.lib.ai.ollama.AiOllamaCompilerPlugin"

[[dependency]]
path = "../compiler-plugin/build/libs/ai.model.provider.ollama-compiler-plugin-1.0.1-SNAPSHOT.jar"

[[dependency]]
path = "../compiler-plugin/build/libs/ballerina-to-openapi-2.3.0.jar"
11 changes: 10 additions & 1 deletion ballerina/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ plugins{
id 'io.ballerina.plugin'
}

description = 'Azure Model Provider - Ballerina'
description = 'Ollama Model Provider - Ballerina'

def packageName = "ai.model.provider.ollama"
def packageOrg = "ballerinax"
def tomlVersion = stripBallerinaExtensionVersion("${project.version}")
def ballerinaTomlFilePlaceHolder = new File("${project.rootDir}/build-config/resources/Ballerina.toml")
def compilerPluginTomlFilePlaceHolder = new File("${project.rootDir}/build-config/resources/CompilerPlugin.toml")
def ballerinaTomlFile = new File("$project.projectDir/Ballerina.toml")
def compilerPluginTomlFile = new File("$project.projectDir/CompilerPlugin.toml")

def stripBallerinaExtensionVersion(String extVersion) {
if (extVersion.matches(project.ext.timestampedVersionRegex)) {
Expand Down Expand Up @@ -57,6 +59,11 @@ task updateTomlFiles {
def newBallerinaToml = ballerinaTomlFilePlaceHolder.text.replace("@project.version@", project.version)
newBallerinaToml = newBallerinaToml.replace("@toml.version@", tomlVersion)
ballerinaTomlFile.text = newBallerinaToml

def ballerinaToOpenApiVersion = project.ballerinaToOpenApiVersion
def newCompilerPluginToml = compilerPluginTomlFilePlaceHolder.text.replace("@project.version@", project.version)
newCompilerPluginToml = newCompilerPluginToml.replace("@ballerinaToOpenApiVersion.version@", ballerinaToOpenApiVersion)
compilerPluginTomlFile.text = newCompilerPluginToml
}
}

Expand Down Expand Up @@ -95,6 +102,8 @@ clean {
delete 'build'
}

build.dependsOn ':ai.model.provider.ollama-native:build'
build.dependsOn "generatePomFileForMavenPublication"
build.dependsOn ":ai.model.provider.ollama-compiler-plugin:build"
publishToMavenLocal.dependsOn build
publish.dependsOn build
17 changes: 14 additions & 3 deletions ballerina/provider.bal
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import ballerina/ai;
import ballerina/data.jsondata;
import ballerina/http;
import ballerina/jballerina.java;

const DEFAULT_OLLAMA_SERVICE_URL = "http://localhost:11434";
const TOOL_ROLE = "tool";
Expand All @@ -42,7 +43,7 @@ public isolated client class Provider {
http:ClientConfiguration clientConfig = {...connectionConfig};
http:Client|error ollamaClient = new (serviceUrl, clientConfig);
if ollamaClient is error {
return error ai:Error("Error while connecting to the model", ollamaClient);
return error("Error while connecting to the model", ollamaClient);
}
self.modleParameters = check getModelParameterMap(modleParameters);
self.ollamaClient = ollamaClient;
Expand All @@ -61,11 +62,21 @@ public isolated client class Provider {
json requestPayload = self.prepareRequestPayload(messages, tools, stop);
OllamaResponse|error response = self.ollamaClient->/api/chat.post(requestPayload);
if response is error {
return error ai:LlmConnectionError("Error while connecting to ollama", response);
return error("Error while connecting to ollama", response);
}
return self.mapOllamaResponseToAssistantMessage(response);
}

# Sends a chat request to the model and generates a value that belongs to the type
# corresponding to the type descriptor argument.
#
# + prompt - The prompt to use in the chat messages
# + td - Type descriptor specifying the expected return type format
# + return - Generates a value that belongs to the type, or an error if generation fails
isolated remote function generate(ai:Prompt prompt, typedesc<anydata> td = <>) returns td|ai:Error = @java:Method {
'class: "io.ballerina.lib.ai.ollama.Generator"
} external;

private isolated function prepareRequestPayload(ai:ChatMessage[] messages, ai:ChatCompletionFunctions[] tools,
string? stop) returns json {
json[] transformedMessages = messages.'map(isolated function(ai:ChatMessage message) returns json {
Expand Down Expand Up @@ -118,6 +129,6 @@ isolated function getModelParameterMap(OllamaModelParameters modleParameters) re
map<json> & readonly readonlyOptions = check options.cloneWithType();
return readonlyOptions;
} on fail error e {
return error ai:Error("Error while processing model parameters", e);
return error("Error while processing model parameters", e);
}
}
187 changes: 187 additions & 0 deletions ballerina/provider_utils.bal
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
// Copyright (c) 2025 WSO2 LLC (http://www.wso2.com).
//
// WSO2 LLC. licenses this file to you under the Apache License,
// Version 2.0 (the "License"); you may not use this file except
// in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

import ballerina/ai;
import ballerina/http;

# Carries the JSON schema sent to the LLM together with a flag indicating
# whether the user-requested type was originally a JSON object or had to be
# wrapped under the `RESULT` key.
type ResponseSchema record {|
    # JSON schema describing the response expected from the LLM.
    map<json> schema;
    # `false` when the expected type was wrapped in an object under `RESULT`
    # by `generateJsonObjectSchema`, requiring unwrapping on parse.
    boolean isOriginallyJsonObject = true;
|};

// Error identifier produced by `fromJsonStringWithType` failures.
const JSON_CONVERSION_ERROR = "FromJsonStringError";
// Error identifier produced by `fromJsonWithType`/`cloneWithType` failures.
const CONVERSION_ERROR = "ConversionError";
// User-facing message for responses that cannot be bound to the expected type.
const ERROR_MESSAGE = "Error occurred while attempting to parse the response from the " +
"LLM as the expected type. Retrying and/or validating the prompt could fix the response.";
// Key under which non-object expected types are wrapped in the schema.
const RESULT = "result";
// Name of the mandatory tool the LLM must call to return its answer.
const GET_RESULTS_TOOL = "getResults";
// Tool type discriminator used in the Ollama tools payload.
const FUNCTION = "function";
const NO_RELEVANT_RESPONSE_FROM_THE_LLM = "No relevant response from the LLM";

# Ensures the given JSON schema describes an object.
# A non-object schema is wrapped in an object schema whose single property,
# keyed by `RESULT`, holds the original schema; recognized metadata fields
# are lifted onto the wrapper.
#
# + schema - The JSON schema generated for the expected type
# + return - The (possibly wrapped) object schema with the wrapping flag
isolated function generateJsonObjectSchema(map<json> schema) returns ResponseSchema {
    if schema["type"] == "object" {
        // Already an object schema; nothing to wrap.
        return {schema};
    }

    final string[] metadataFields = ["$schema", "$id", "$anchor", "$comment", "title", "description"];

    // Partition the entries: metadata stays on the wrapper, everything else
    // becomes the schema of the `RESULT` property.
    map<json> wrapperSchema = {};
    map<json> innerSchema = {};
    foreach [string, json] [key, value] in schema.entries() {
        if metadataFields.indexOf(key) is int {
            wrapperSchema[key] = value;
        } else {
            innerSchema[key] = value;
        }
    }

    wrapperSchema["type"] = "object";
    wrapperSchema["properties"] = {[RESULT]: innerSchema};

    return {schema: wrapperSchema, isOriginallyJsonObject: false};
}

# Parses the LLM's response string into the expected type, unwrapping the
# `RESULT` envelope when the expected type was not originally a JSON object.
#
# + resp - Raw JSON string returned by the LLM
# + expectedResponseTypedesc - Type descriptor to bind the value to
# + isOriginallyJsonObject - Whether the expected type was an object schema
# + return - The bound value, or an error describing the conversion failure
isolated function parseResponseAsType(string resp,
        typedesc<anydata> expectedResponseTypedesc, boolean isOriginallyJsonObject) returns anydata|error {
    anydata|error parsed;
    if isOriginallyJsonObject {
        parsed = resp.fromJsonStringWithType(expectedResponseTypedesc);
    } else {
        // The value was wrapped under `RESULT` by `generateJsonObjectSchema`;
        // unwrap it before binding to the expected type.
        map<json> envelope = check resp.fromJsonStringWithType();
        parsed = trap envelope[RESULT].fromJsonWithType(expectedResponseTypedesc);
    }
    if parsed is error {
        return handleParseResponseError(parsed);
    }
    return parsed;
}

# Derives the response schema for the given expected type descriptor.
#
# + expectedResponseTypedesc - Type descriptor of the caller's expected type
# + return - The object response schema, or an `ai:Error` on schema generation failure
isolated function getExpectedResponseSchema(typedesc<anydata> expectedResponseTypedesc) returns ResponseSchema|ai:Error {
    // Restricted at compile-time for now, so this conversion cannot fail.
    typedesc<json> jsonTypedesc = checkpanic expectedResponseTypedesc.ensureType();
    map<json> rawSchema = check generateJsonSchemaForTypedescAsJson(jsonTypedesc);
    return generateJsonObjectSchema(rawSchema);
}

# Builds the mandatory `getResults` tool definition sent to the LLM; the LLM
# returns its answer by calling this tool with arguments matching `parameters`.
#
# + parameters - JSON schema the tool-call arguments must conform to
# + return - A single-element tool list in the Ollama tools payload format
isolated function getGetResultsTool(map<json> parameters) returns map<json>[]|error =>
    [
    {
    'type: FUNCTION,
    'function: {
    name: GET_RESULTS_TOOL,
    parameters: parameters,
    description: string `Required Tool to call with the response from a large language model (LLM) for a user prompt.
This tool is mandatory for the LLM to return a response.`
    }
    }
    ];

# Renders an `ai:Prompt` into a single chat message string, interleaving the
# template strings with the stringified insertions and appending the tool
# directive at the end.
#
# + prompt - The raw-template prompt with its insertions
# + return - The rendered prompt text, or an `ai:Error` for unsupported documents
isolated function generateChatCreationContent(ai:Prompt prompt) returns string|ai:Error {
    string[] & readonly templateParts = prompt.strings;
    anydata[] values = prompt.insertions;
    string rendered = templateParts[0];

    foreach int idx in 0 ..< values.length() {
        anydata value = values[idx];
        string trailingPart = templateParts[idx + 1];

        if value is ai:TextDocument {
            rendered += value.content + " " + trailingPart;
        } else if value is ai:TextDocument[] {
            foreach ai:TextDocument doc in value {
                rendered += doc.content + " ";
            }
            rendered += trailingPart;
        } else if value is ai:Document {
            // Non-text documents (images, audio, etc.) are not supported yet.
            return error ai:Error("Only Text Documents are currently supported.");
        } else {
            rendered += value.toString() + trailingPart;
        }
    }

    rendered += addToolDirective();
    return rendered.trim();
}

# Returns the directive appended to every prompt instructing the LLM to
# answer via the mandatory `getResults` tool call.
isolated function addToolDirective() returns string =>
    "\nYou must call the `getResults` tool to obtain the correct answer.";

# Translates JSON conversion failures into a user-actionable error; any other
# error is passed through unchanged.
#
# + chatResponseError - The error raised while binding the LLM response
# + return - A friendlier error for conversion failures, else the original error
isolated function handleParseResponseError(error chatResponseError) returns error {
    string errorText = chatResponseError.message();
    boolean isConversionFailure =
        errorText.includes(JSON_CONVERSION_ERROR) || errorText.includes(CONVERSION_ERROR);
    if isConversionFailure {
        return error(string `${ERROR_MESSAGE}`, detail = chatResponseError);
    }
    return chatResponseError;
}

# Sends a structured-output chat request to the Ollama service and binds the
# tool-call arguments from the response to the expected type.
#
# + llmClient - HTTP client pointed at the Ollama service
# + modelType - Name of the Ollama model to invoke
# + modleParameters - Read-only model options forwarded as the `options` field
# + prompt - The prompt to render into the user message
# + expectedResponseTypedesc - Type descriptor the response must conform to
# + return - The value bound to the expected type, or an `ai:Error` on failure
isolated function generateLlmResponse(http:Client llmClient, string modelType,
        readonly & map<json> modleParameters, ai:Prompt prompt,
        typedesc<json> expectedResponseTypedesc) returns anydata|ai:Error {
    string content = check generateChatCreationContent(prompt);
    // Renamed from `ResponseSchema` — the original shadowed the type of the
    // same name, which is legal but misleading.
    ResponseSchema responseSchema = check getExpectedResponseSchema(expectedResponseTypedesc);
    map<json>[]|error tools = getGetResultsTool(responseSchema.schema);
    if tools is error {
        return error("Error while generating the tool: " + tools.message());
    }

    map<json> request = {
        messages: [
            {
                role: ai:USER,
                "content": content
            }
        ],
        tools,
        model: modelType,
        'stream: false,
        options: {...modleParameters}
    };

    OllamaResponse|error response = llmClient->/api/chat.post(request);
    if response is error {
        return error("Error while connecting to ollama", response);
    }

    OllamaToolCall[]? toolCalls = response.message?.tool_calls;
    if toolCalls is () || toolCalls.length() == 0 {
        // The model ignored the mandatory tool; nothing to bind.
        return error(NO_RELEVANT_RESPONSE_FROM_THE_LLM);
    }

    OllamaToolCall tool = toolCalls[0];
    map<json> arguments = tool.'function.arguments;

    anydata|error res = parseResponseAsType(arguments.toJsonString(), expectedResponseTypedesc,
            responseSchema.isOriginallyJsonObject);
    if res is error {
        return error(string `Invalid value returned from the LLM Client, expected: '${
            expectedResponseTypedesc.toBalString()}', found '${res.toBalString()}'`);
    }

    anydata|error result = res.ensureType(expectedResponseTypedesc);
    if result is error {
        // Bug fix: the original reported `(typeof response).toBalString()` —
        // the raw HTTP response type — instead of the parsed value that
        // actually failed the type binding.
        return error(string `Invalid value returned from the LLM Client, expected: '${
            expectedResponseTypedesc.toBalString()}', found '${res.toBalString()}'`);
    }
    return result;
}
41 changes: 41 additions & 0 deletions ballerina/tests/test_services.bal
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) 2025 WSO2 LLC. (http://www.wso2.org).
//
// WSO2 Inc. licenses this file to you under the Apache License,
// Version 2.0 (the "License"); you may not use this file except
// in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

import ballerina/http;
import ballerina/test;

// Mock Ollama endpoint used by the tests: validates the structure of each
// incoming chat request, then returns a canned response for the prompt.
service /llm on new http:Listener(8080) {
    resource function post api/chat(map<json> payload)
            returns OllamaResponse|error {
        test:assertEquals(payload.model, "llama2");

        json[] requestMessages = check payload.messages.ensureType();
        json userMessage = requestMessages[0];

        string promptContent = check userMessage.content.ensureType();
        test:assertEquals(promptContent, getExpectedPrompt(promptContent));
        test:assertEquals(userMessage.role, "user");

        json[] requestTools = check payload.tools.ensureType();
        if requestTools.length() == 0 {
            test:assertFail("No tools in the payload");
        }

        json firstTool = check requestTools[0].ensureType();
        map<json> toolParameters = check (firstTool.'function?.parameters).ensureType();

        test:assertEquals(toolParameters, getExpectedParameterSchema(promptContent), string `Test failed for prompt:- ${promptContent}`);
        return getTestServiceResponse(promptContent);
    }
}
Loading
Loading