Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ jobs:
uses: ballerina-platform/ballerina-library/.github/workflows/build-connector-template.yml@main
secrets: inherit
with:
repo-name: module-ballerinax-ai.model.provider.ollama
repo-name: module-ballerinax-ai.ollama
publish-required: true
2 changes: 1 addition & 1 deletion .github/workflows/daily-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ jobs:
uses: ballerina-platform/ballerina-library/.github/workflows/daily-build-connector-template.yml@main
secrets: inherit
with:
repo-name: module-ballerinax-ai.model.provider.ollama
repo-name: module-ballerinax-ai.ollama
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ jobs:
uses: ballerina-platform/ballerina-library/.github/workflows/release-package-connector-template.yml@main
secrets: inherit
with:
package-name: ai.model.provider.ollama
package-name: ai.ollama
package-org: ballerinax
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Ballerina Ollama Model Provider Library

[![Build](https://github.com/ballerina-platform/module-ballerinax-ai.model.provider.ollama/workflows/CI/badge.svg)](https://github.com/ballerina-platform/module-ballerinax-ai.model.provider.ollama/actions?query=workflow%3ACI)
[![GitHub Last Commit](https://img.shields.io/github/last-commit/ballerina-platform/module-ballerinax-ai.model.provider.ollama.svg)](https://github.com/ballerina-platform/module-ballerinax-ai.model.provider.ollama/commits/master)
[![Build](https://github.com/ballerina-platform/module-ballerinax-ai.ollama/workflows/CI/badge.svg)](https://github.com/ballerina-platform/module-ballerinax-ai.ollama/actions?query=workflow%3ACI)
[![GitHub Last Commit](https://img.shields.io/github/last-commit/ballerina-platform/module-ballerinax-ai.ollama.svg)](https://github.com/ballerina-platform/module-ballerinax-ai.ollama/commits/master)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

## Overview
Expand Down
4 changes: 2 additions & 2 deletions ballerina/Ballerina.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ authors = ["Ballerina"]
distribution = "2201.12.0"
keywords = ["AI", "Agent", "Ollama", "Model", "Provider"]
license = ["Apache-2.0"]
name = "ai.model.provider.ollama"
name = "ai.ollama"
org = "ballerinax"
repository = "https://github.com/ballerina-platform/module-ballerinax-ai.model.provider.ollama"
repository = "https://github.com/ballerina-platform/module-ballerinax-ai.ollama"
version = "1.0.0"

[platform.java21]
Expand Down
12 changes: 6 additions & 6 deletions ballerina/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ Ensure that your Ollama server is running locally before using this module in yo

## Quickstart

To use the `ai.model.provider.ollama` module in your Ballerina application, update the `.bal` file as follows:
To use the `ai.ollama` module in your Ballerina application, update the `.bal` file as follows:

### Step 1: Import the module

Import the `ai.model.provider.ollama;` module.
Import the `ai.ollama` module.

```ballerina
import ballerinax/ai.model.provider.ollama;
import ballerinax/ai.ollama;
```

### Step 2: Initialize the Model Provider
Expand All @@ -24,14 +24,14 @@ Here's how to initialize the Model Provider:

```ballerina
import ballerina/ai;
import ballerinax/ai.model.provider.ollama;
import ballerinax/ai.ollama;

final ai:ModelProvider ollamaModel = check new ollama:Provider("ollamaModelName");
final ai:ModelProvider ollamaModel = check new ollama:ModelProvider("ollamaModelName");
```

### Step 3: Invoke chat completion

```
```ballerina
ai:ChatMessage[] chatMessages = [{role: "user", content: "hi"}];
ai:ChatAssistantMessage response = check ollamaModel->chat(chatMessages, tools = []);

Expand Down
2 changes: 1 addition & 1 deletion ballerina/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ plugins{

description = 'Ollama Model Provider - Ballerina'

def packageName = "ai.model.provider.ollama"
def packageName = "ai.ollama"
def packageOrg = "ballerinax"
def tomlVersion = stripBallerinaExtensionVersion("${project.version}")
def ballerinaTomlFilePlaceHolder = new File("${project.rootDir}/build-config/resources/Ballerina.toml")
Expand Down
101 changes: 85 additions & 16 deletions ballerina/provider.bal
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const DEFAULT_OLLAMA_SERVICE_URL = "http://localhost:11434";
const TOOL_ROLE = "tool";

# ModelProvider represents a client for interacting with Ollama language models.
public isolated client class Provider {
public isolated client class ModelProvider {
*ai:ModelProvider;
private final http:Client ollamaClient;
private final string modelType;
Expand Down Expand Up @@ -51,47 +51,72 @@ public isolated client class Provider {

# Sends a chat request to the Ollama model with the given messages and tools.
#
# + messages - List of chat messages
# + messages - List of chat messages or user message
# + tools - Tool definitions to be used for the tool call
# + stop - Stop sequence to stop the completion
# + return - Function to be called, chat response or an error in-case of failures
isolated remote function chat(ai:ChatMessage[] messages, ai:ChatCompletionFunctions[] tools = [], string? stop = ())
returns ai:ChatAssistantMessage|ai:LlmError {
isolated remote function chat(ai:ChatMessage[]|ai:ChatUserMessage messages, ai:ChatCompletionFunctions[] tools = [],
string? stop = ()) returns ai:ChatAssistantMessage|ai:Error {
// Ollama chat completion API reference: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
json requestPayload = self.prepareRequestPayload(messages, tools, stop);
json requestPayload = check self.prepareRequestPayload(messages, tools, stop);
OllamaResponse|error response = self.ollamaClient->/api/chat.post(requestPayload);
if response is error {
return error ai:LlmConnectionError("Error while connecting to ollama", response);
}
return self.mapOllamaResponseToAssistantMessage(response);
}

private isolated function prepareRequestPayload(ai:ChatMessage[] messages, ai:ChatCompletionFunctions[] tools,
string? stop) returns json {
json[] transformedMessages = messages.'map(isolated function(ai:ChatMessage message) returns json {
if message is ai:ChatFunctionMessage {
return {role: TOOL_ROLE, content: message?.content};
}
return message;
});

private isolated function prepareRequestPayload(ai:ChatMessage[]|ai:ChatUserMessage messages,
ai:ChatCompletionFunctions[] tools, string? stop) returns json|ai:Error {
map<json> options = {...self.modleParameters};
if stop is string {
options["stop"] = [stop];
}

map<json> payload = {
model: self.modelType,
messages: transformedMessages,
messages: check self.mapToOllamaRequestMessage(messages),
'stream: false,
options
};
if tools.length() > 0 {
payload["tools"] = tools.'map(tool => {'type: ai:FUNCTION, 'function: tool});
payload["tools"] = tools.'map(tool => {'type: FUNCTION, 'function: tool});
}
return payload;
}

private isolated function mapToOllamaRequestMessage(ai:ChatMessage[]|ai:ChatUserMessage messages)
returns json[]|ai:Error {
json[] transformedMessages = [];
if messages is ai:ChatUserMessage {
transformedMessages.push({
role: TOOL_ROLE,
content: check getChatMessageStringContent(messages?.content)
});
return transformedMessages;
}
foreach ai:ChatMessage message in messages {
if message is ai:ChatFunctionMessage {
transformedMessages.push({role: TOOL_ROLE, content: message?.content});

} else if message is ai:ChatUserMessage {
transformedMessages.push({
role: ai:USER,
content: check getChatMessageStringContent(message.content)
});

} else if message is ai:ChatSystemMessage {
transformedMessages.push({
role: ai:SYSTEM,
content: check getChatMessageStringContent(message.content)
});
} else if message is ai:ChatAssistantMessage {
transformedMessages.push(message);
}
}
return transformedMessages;
}

private isolated function mapOllamaResponseToAssistantMessage(OllamaResponse response)
returns ai:ChatAssistantMessage {
OllamaToolCall[]? toolCalls = response.message?.tool_calls;
Expand All @@ -110,6 +135,9 @@ public isolated client class Provider {
};
return {role: ai:ASSISTANT, toolCalls};
}

// TODO
isolated remote function generate(ai:Prompt prompt, typedesc<anydata> td = <>) returns td|ai:Error = external;
}

isolated function getModelParameterMap(OllamaModelParameters modleParameters) returns readonly & map<json>|ai:Error {
Expand All @@ -121,3 +149,44 @@ isolated function getModelParameterMap(OllamaModelParameters modleParameters) re
return error ai:Error("Error while processing model parameters", e);
}
}

isolated function getChatMessageStringContent(ai:Prompt|string prompt) returns string|ai:Error {
if prompt is string {
return prompt;
}
string[] & readonly strings = prompt.strings;
anydata[] insertions = prompt.insertions;
string promptStr = strings[0];
foreach int i in 0 ..< insertions.length() {
string str = strings[i + 1];
anydata insertion = insertions[i];

if insertion is ai:TextDocument|ai:TextChunk {
promptStr += insertion.content + " " + str;
continue;
}

if insertion is ai:TextDocument[] {
foreach ai:TextDocument doc in insertion {
promptStr += doc.content + " ";
}
promptStr += str;
continue;
}

if insertion is ai:TextChunk[] {
foreach ai:TextChunk doc in insertion {
promptStr += doc.content + " ";
}
promptStr += str;
continue;
}

if insertion is ai:Document {
return error ai:Error("Only Text Documents are currently supported.");
}

promptStr += insertion.toString() + str;
}
return promptStr.trim();
}
2 changes: 2 additions & 0 deletions ballerina/types.bal
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,5 @@ type OllamaFunction record {
string name;
map<json> arguments;
};

const FUNCTION = "function";
4 changes: 2 additions & 2 deletions build-config/resources/Ballerina.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ authors = ["Ballerina"]
distribution = "2201.12.0"
keywords = ["AI", "Agent", "Ollama", "Model", "Provider"]
license = ["Apache-2.0"]
name = "ai.model.provider.ollama"
name = "ai.ollama"
org = "ballerinax"
repository = "https://github.com/ballerina-platform/module-ballerinax-ai.model.provider.ollama"
repository = "https://github.com/ballerina-platform/module-ballerinax-ai.ollama"
version = "@toml.version@"

[platform.java21]
Expand Down
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ allprojects {
def moduleVersion = project.version.replace("-SNAPSHOT", "")

task build {
dependsOn(':ai.model.provider.ollama-ballerina:build')
dependsOn(':ai.ollama-ballerina:build')
}

release {
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
org.gradle.caching=true
group=io.ballerina.lib
version=1.0.1-SNAPSHOT
version=1.0.0-SNAPSHOT
ballerinaLangVersion=2201.12.0

shadowJarPluginVersion=8.1.1
Expand Down
6 changes: 3 additions & 3 deletions settings.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ plugins {
id "com.gradle.enterprise" version "3.2"
}

rootProject.name = 'module-ballerinax-ai.model.provider.ollama'
rootProject.name = 'module-ballerinax-ai.ollama'

include ':ai.model.provider.ollama-ballerina'
include ':ai.ollama-ballerina'

project(':ai.model.provider.ollama-ballerina').projectDir = file("ballerina")
project(':ai.ollama-ballerina').projectDir = file("ballerina")

gradleEnterprise {
buildScan {
Expand Down
Loading