Merged
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -15,5 +15,5 @@ jobs:
uses: ballerina-platform/ballerina-library/.github/workflows/build-connector-template.yml@main
secrets: inherit
with:
repo-name: module-ballerinax-ai.model.provider.ollama
repo-name: module-ballerinax-ai.ollama
publish-required: true
2 changes: 1 addition & 1 deletion .github/workflows/daily-build.yml
@@ -11,4 +11,4 @@ jobs:
uses: ballerina-platform/ballerina-library/.github/workflows/daily-build-connector-template.yml@main
secrets: inherit
with:
repo-name: module-ballerinax-ai.model.provider.ollama
repo-name: module-ballerinax-ai.ollama
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
@@ -12,5 +12,5 @@ jobs:
uses: ballerina-platform/ballerina-library/.github/workflows/release-package-connector-template.yml@main
secrets: inherit
with:
package-name: ai.model.provider.ollama
package-name: ai.ollama
package-org: ballerinax
4 changes: 2 additions & 2 deletions README.md
@@ -1,7 +1,7 @@
# Ballerina Ollama Model Provider Library

[![Build](https://github.com/ballerina-platform/module-ballerinax-ai.model.provider.ollama/workflows/CI/badge.svg)](https://github.com/ballerina-platform/module-ballerinax-ai.model.provider.ollama/actions?query=workflow%3ACI)
[![GitHub Last Commit](https://img.shields.io/github/last-commit/ballerina-platform/module-ballerinax-ai.model.provider.ollama.svg)](https://github.com/ballerina-platform/module-ballerinax-ai.model.provider.ollama/commits/master)
[![Build](https://github.com/ballerina-platform/module-ballerinax-ai.ollama/workflows/CI/badge.svg)](https://github.com/ballerina-platform/module-ballerinax-ai.ollama/actions?query=workflow%3ACI)
[![GitHub Last Commit](https://img.shields.io/github/last-commit/ballerina-platform/module-ballerinax-ai.ollama.svg)](https://github.com/ballerina-platform/module-ballerinax-ai.ollama/commits/master)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

## Overview
11 changes: 7 additions & 4 deletions ballerina/Ballerina.toml
@@ -3,10 +3,13 @@ authors = ["Ballerina"]
distribution = "2201.12.0"
keywords = ["AI", "Agent", "Ollama", "Model", "Provider"]
license = ["Apache-2.0"]
name = "ai.model.provider.ollama"
name = "ai.ollama"
org = "ballerinax"
repository = "https://github.com/ballerina-platform/module-ballerinax-ai.model.provider.ollama"
repository = "https://github.com/ballerina-platform/module-ballerinax-ai.ollama"
version = "1.0.0"

[platform.java21]
graalvmCompatible = true
[[platform.java21.dependency]]
groupId = "io.ballerina.lib"
artifactId = "ai.ollama-native"
version = "1.0.0"
path = "../native/build/libs/ai.ollama-native-1.0.0-SNAPSHOT.jar"
9 changes: 9 additions & 0 deletions ballerina/CompilerPlugin.toml
@@ -0,0 +1,9 @@
[plugin]
id = "ai-ollama-compiler-plugin"
class = "io.ballerina.lib.ai.ollama.AiOllamaCompilerPlugin"

[[dependency]]
path = "../compiler-plugin/build/libs/ai.ollama-compiler-plugin-1.0.0-SNAPSHOT.jar"

[[dependency]]
path = "../compiler-plugin/build/libs/ballerina-to-openapi-2.3.0.jar"
12 changes: 6 additions & 6 deletions ballerina/README.md
@@ -8,14 +8,14 @@ Ensure that your Ollama server is running locally before using this module in yo

## Quickstart

To use the `ai.model.provider.ollama` module in your Ballerina application, update the `.bal` file as follows:
To use the `ai.ollama` module in your Ballerina application, update the `.bal` file as follows:

### Step 1: Import the module

Import the `ai.model.provider.ollama;` module.
Import the `ai.ollama` module.

```ballerina
import ballerinax/ai.model.provider.ollama;
import ballerinax/ai.ollama;
```

### Step 2: Initialize the Model Provider
@@ -24,14 +24,14 @@ Here's how to initialize the Model Provider:

```ballerina
import ballerina/ai;
import ballerinax/ai.model.provider.ollama;
import ballerinax/ai.ollama;

final ai:ModelProvider ollamaModel = check new ollama:Provider("ollamaModelName");
final ai:ModelProvider ollamaModel = check new ollama:ModelProvider("ollamaModelName");
```

### Step 4: Invoke chat completion

```
```ballerina
ai:ChatMessage[] chatMessages = [{role: "user", content: "hi"}];
ai:ChatAssistantMessage response = check ollamaModel->chat(chatMessages, tools = []);

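For reference, the quickstart steps above combine into a minimal end-to-end program. This is a sketch against the renamed module: the `llama3` model name and the default local server URL are assumptions for illustration, so substitute a model you have actually pulled into your Ollama installation.

```ballerina
import ballerina/ai;
import ballerina/io;
import ballerinax/ai.ollama;

// Assumes a local Ollama server (http://localhost:11434) with the
// "llama3" model already pulled; adjust the name to your installation.
final ai:ModelProvider model = check new ollama:ModelProvider("llama3");

public function main() returns error? {
    ai:ChatMessage[] messages = [{role: "user", content: "Say hello in one word."}];
    ai:ChatAssistantMessage response = check model->chat(messages, tools = []);
    io:println(response?.content);
}
```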
13 changes: 11 additions & 2 deletions ballerina/build.gradle
@@ -22,13 +22,15 @@ plugins{
id 'io.ballerina.plugin'
}

description = 'Azure Model Provider - Ballerina'
description = 'Ollama Model Provider - Ballerina'

def packageName = "ai.model.provider.ollama"
def packageName = "ai.ollama"
def packageOrg = "ballerinax"
def tomlVersion = stripBallerinaExtensionVersion("${project.version}")
def ballerinaTomlFilePlaceHolder = new File("${project.rootDir}/build-config/resources/Ballerina.toml")
def compilerPluginTomlFilePlaceHolder = new File("${project.rootDir}/build-config/resources/CompilerPlugin.toml")
def ballerinaTomlFile = new File("$project.projectDir/Ballerina.toml")
def compilerPluginTomlFile = new File("$project.projectDir/CompilerPlugin.toml")

def stripBallerinaExtensionVersion(String extVersion) {
if (extVersion.matches(project.ext.timestampedVersionRegex)) {
@@ -57,6 +59,11 @@ task updateTomlFiles {
def newBallerinaToml = ballerinaTomlFilePlaceHolder.text.replace("@project.version@", project.version)
newBallerinaToml = newBallerinaToml.replace("@toml.version@", tomlVersion)
ballerinaTomlFile.text = newBallerinaToml

def ballerinaToOpenApiVersion = project.ballerinaToOpenApiVersion
def newCompilerPluginToml = compilerPluginTomlFilePlaceHolder.text.replace("@project.version@", project.version)
newCompilerPluginToml = newCompilerPluginToml.replace("@ballerinaToOpenApiVersion.version@", ballerinaToOpenApiVersion)
compilerPluginTomlFile.text = newCompilerPluginToml
}
}

@@ -95,6 +102,8 @@ clean {
delete 'build'
}

build.dependsOn ':ai.ollama-native:build'
build.dependsOn "generatePomFileForMavenPublication"
build.dependsOn ":ai.ollama-compiler-plugin:build"
publishToMavenLocal.dependsOn build
publish.dependsOn build
113 changes: 95 additions & 18 deletions ballerina/provider.bal
@@ -17,12 +17,13 @@
import ballerina/ai;
import ballerina/data.jsondata;
import ballerina/http;
import ballerina/jballerina.java;

const DEFAULT_OLLAMA_SERVICE_URL = "http://localhost:11434";
const TOOL_ROLE = "tool";

# Provider represents a client for interacting with Ollama language models.
public isolated client class Provider {
public isolated client class ModelProvider {
*ai:ModelProvider;
private final http:Client ollamaClient;
private final string modelType;
@@ -42,7 +43,7 @@
http:ClientConfiguration clientConfig = {...connectionConfig};
http:Client|error ollamaClient = new (serviceUrl, clientConfig);
if ollamaClient is error {
return error ai:Error("Error while connecting to the model", ollamaClient);
return error("Error while connecting to the model", ollamaClient);
}
self.modleParameters = check getModelParameterMap(modleParameters);
self.ollamaClient = ollamaClient;
@@ -51,47 +52,82 @@

# Sends a chat request to the Ollama model with the given messages and tools.
#
# + messages - List of chat messages
# + messages - List of chat messages or user message
# + tools - Tool definitions to be used for the tool call
# + stop - Stop sequence to stop the completion
# + return - Function to be called, chat response or an error in-case of failures
isolated remote function chat(ai:ChatMessage[] messages, ai:ChatCompletionFunctions[] tools = [], string? stop = ())
returns ai:ChatAssistantMessage|ai:LlmError {
isolated remote function chat(ai:ChatMessage[]|ai:ChatUserMessage messages, ai:ChatCompletionFunctions[] tools = [],
string? stop = ()) returns ai:ChatAssistantMessage|ai:Error {
// Ollama chat completion API reference: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
json requestPayload = self.prepareRequestPayload(messages, tools, stop);
json requestPayload = check self.prepareRequestPayload(messages, tools, stop);
OllamaResponse|error response = self.ollamaClient->/api/chat.post(requestPayload);
if response is error {
return error ai:LlmConnectionError("Error while connecting to ollama", response);
return error("Error while connecting to ollama", response);
}
return self.mapOllamaResponseToAssistantMessage(response);
}

private isolated function prepareRequestPayload(ai:ChatMessage[] messages, ai:ChatCompletionFunctions[] tools,
string? stop) returns json {
json[] transformedMessages = messages.'map(isolated function(ai:ChatMessage message) returns json {
if message is ai:ChatFunctionMessage {
return {role: TOOL_ROLE, content: message?.content};
}
return message;
});
# Sends a chat request to the model and generates a value that belongs to the type
# corresponding to the type descriptor argument.
#
# + prompt - The prompt to use in the chat messages
# + td - Type descriptor specifying the expected return type format
# + return - Generates a value that belongs to the type, or an error if generation fails
isolated remote function generate(ai:Prompt prompt, typedesc<anydata> td = <>) returns td|ai:Error = @java:Method {
'class: "io.ballerina.lib.ai.ollama.Generator"
} external;

private isolated function prepareRequestPayload(ai:ChatMessage[]|ai:ChatUserMessage messages,
ai:ChatCompletionFunctions[] tools, string? stop) returns json|ai:Error {
map<json> options = {...self.modleParameters};
if stop is string {
options["stop"] = [stop];
}

map<json> payload = {
model: self.modelType,
messages: transformedMessages,
messages: check self.mapToOllamaRequestMessage(messages),
'stream: false,
options
};
if tools.length() > 0 {
payload["tools"] = tools.'map(tool => {'type: ai:FUNCTION, 'function: tool});
payload["tools"] = tools.'map(tool => {'type: FUNCTION, 'function: tool});
}
return payload;
}

private isolated function mapToOllamaRequestMessage(ai:ChatMessage[]|ai:ChatUserMessage messages)
returns json[]|ai:Error {
json[] transformedMessages = [];
if messages is ai:ChatUserMessage {
transformedMessages.push({
role: TOOL_ROLE,
content: check getChatMessageStringContent(messages?.content)
});
return transformedMessages;
}
foreach ai:ChatMessage message in messages {
if message is ai:ChatFunctionMessage {
transformedMessages.push({role: TOOL_ROLE, content: message?.content});
} else if message is ai:ChatUserMessage {
transformedMessages.push({
role: ai:USER,
content: check getChatMessageStringContent(message.content)
});

} else if message is ai:ChatSystemMessage {
transformedMessages.push({
role: ai:SYSTEM,
content: check getChatMessageStringContent(message.content)
});
} else if message is ai:ChatAssistantMessage {
transformedMessages.push(message);
}
}
return transformedMessages;
}

private isolated function mapOllamaResponseToAssistantMessage(OllamaResponse response)
returns ai:ChatAssistantMessage {
OllamaToolCall[]? toolCalls = response.message?.tool_calls;
@@ -118,6 +154,47 @@
map<json> & readonly readonlyOptions = check options.cloneWithType();
return readonlyOptions;
} on fail error e {
return error ai:Error("Error while processing model parameters", e);
return error("Error while processing model parameters", e);
}
}

isolated function getChatMessageStringContent(ai:Prompt|string prompt) returns string|ai:Error {
if prompt is string {
return prompt;
}
string[] & readonly strings = prompt.strings;
anydata[] insertions = prompt.insertions;
string promptStr = strings[0];
foreach int i in 0 ..< insertions.length() {
string str = strings[i + 1];
anydata insertion = insertions[i];

if insertion is ai:TextDocument|ai:TextChunk {
promptStr += insertion.content + " " + str;
continue;
}

if insertion is ai:TextDocument[] {
foreach ai:TextDocument doc in insertion {
promptStr += doc.content + " ";
}
promptStr += str;
continue;
}

if insertion is ai:TextChunk[] {
foreach ai:TextChunk doc in insertion {
promptStr += doc.content + " ";
}
promptStr += str;
continue;
}

if insertion is ai:Document {
return error ai:Error("Only Text Documents are currently supported.");
}

promptStr += insertion.toString() + str;
}
return promptStr.trim();
}
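The loop in `getChatMessageStringContent` flattens a raw-template prompt by interleaving its string parts with its insertions, splicing in the content of `TextDocument`/`TextChunk` values (and arrays of them) and rejecting other document kinds. An illustrative sketch of the resulting behavior; the values are made up:

```ballerina
import ballerina/ai;

public function main() {
    ai:TextDocument doc = {content: "Ballerina is a cloud-native language."};
    ai:Prompt prompt = `Summarize this: ${doc}`;
    // getChatMessageStringContent(prompt) would produce, after trimming:
    // "Summarize this: Ballerina is a cloud-native language."
}
```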