Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 44 additions & 19 deletions ballerina/provider_utils.bal
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
// under the License.

import ballerina/ai;
import ballerina/constraint;
import ballerina/http;

type ResponseSchema record {|
map<json> schema;
boolean isOriginallyJsonObject = true;
|};

type ImageContent ai:Url|byte[];

const JSON_CONVERSION_ERROR = "FromJsonStringError";
const CONVERSION_ERROR = "ConversionError";
const ERROR_MESSAGE = "Error occurred while attempting to parse the response from the " +
Expand Down Expand Up @@ -92,35 +95,57 @@ isolated function getGetResultsTool(map<json> parameters) returns map<json>[]|er
isolated function generateChatCreationContent(ai:Prompt prompt) returns string|ai:Error {
string[] & readonly strings = prompt.strings;
anydata[] insertions = prompt.insertions;
string promptStr = strings[0];
foreach int i in 0 ..< insertions.length() {
string str = strings[i + 1];
anydata insertion = insertions[i];
string promptStr = "";

if insertion is ai:TextDocument {
promptStr += insertion.content + " " + str;
continue;
}
if strings.length() > 0 {
promptStr += strings[0];
}

if insertion is ai:TextDocument[] {
foreach ai:TextDocument doc in insertion {
promptStr += doc.content + " ";

}
promptStr += str;
continue;
}
foreach int i in 0 ..< insertions.length() {
anydata insertion = insertions[i];
string str = strings[i + 1];

if insertion is ai:Document {
return error ai:Error("Only Text Documents are currently supported.");
if insertion is ai:TextDocument {
promptStr += insertion.content + " ";
} else if insertion is ai:ImageDocument {
promptStr += check addImageContentPart(insertion);
} else {
return error ai:Error("Only Text and Image Documents are currently supported.");
}
} else if insertion is ai:Document[] {
foreach ai:Document doc in insertion {
if doc is ai:TextDocument {
promptStr += doc.content + " ";
} else if doc is ai:ImageDocument {
promptStr += check addImageContentPart(doc);
} else {
return error ai:Error("Only Text and Image Documents are currently supported.");
}
}
} else {
promptStr += insertion.toString();
}

promptStr += insertion.toString() + str;
promptStr += str;
}

promptStr += addToolDirective();
return promptStr.trim();
}

isolated function addImageContentPart(ai:ImageDocument doc) returns string|ai:Error {
ai:Url|byte[] content = doc.content;
if content is ai:Url {
ai:Url|constraint:Error validationRes = constraint:validate(content);
if validationRes is error {
return error(validationRes.message(), validationRes.cause());
}
return string ` ${content} `;
}

return string ` ${content.toBase64()} `;
}

isolated function addToolDirective() returns string {
return "\nYou must call the `getResults` tool to obtain the correct answer.";
}
Expand Down
36 changes: 36 additions & 0 deletions ballerina/tests/test_utils.bal
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ isolated function getExpectedParameterSchema(string message) returns map<json> {
return expectedParameterSchemaStringForRateBlog2;
}

if message.startsWith("Describe the following image.") {
return expectedParameterSchemaStringForRateBlog9;
}

if message.startsWith("Describe this image.") {
return expectedParameterSchemaStringForRateBlog9;
}

if message.startsWith("Describe these images.") {
return expectedParameterSchemaStringForRateBlog8;
}

if message.startsWith("How do you rate this blog") {
return expectedParameterSchemaStringForRateBlog7;
}
Expand Down Expand Up @@ -167,6 +179,18 @@ isolated function getTheMockLLMResult(string message) returns map<json> {
}
}

if message.startsWith("Describe the following image.") {
return {"result": "This is a sample image description."};
}

if message.startsWith("Describe this image.") {
return {"result": "This is a sample image description."};
}

if message.startsWith("Describe these images.") {
return {"result": ["This is a sample image description.", "This is a sample image description."]};
}

return {};
}

Expand Down Expand Up @@ -246,5 +270,17 @@ isolated function getExpectedPrompt(string message) returns string {
their name?`;
}

if message.startsWith("Describe the following image.") {
return string `Describe the following image. ${sampleStringData} .${"\n"}You must call the ${"`"}getResults${"`"} tool to obtain the correct answer.`;
}

if message.startsWith("Describe this image.") {
return "Describe this image. https://example.com/sample-image.jpg .\nYou must call the `getResults` tool to obtain the correct answer.";
}

if message.startsWith("Describe these images.") {
return string `Describe these images. ${sampleStringData} https://example.com/sample-image.jpg .${"\n"}You must call the ${"`"}getResults${"`"} tool to obtain the correct answer.`;
}

return "INVALID";
}
24 changes: 22 additions & 2 deletions ballerina/tests/test_values.bal
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
import ballerina/lang.array;

type Blog record {
string title;
Expand All @@ -24,6 +25,9 @@ type Review record {|
string comment;
|};

final readonly & byte[] sampleBinaryData = [0x01, 0x02, 0x03, 0x04, 0x05];
final readonly & string sampleStringData = array:toBase64(sampleBinaryData);

const blog1 = {
// Generated.
title: "Tips for Growing a Beautiful Garden",
Expand Down Expand Up @@ -70,7 +74,7 @@ final string expectedPromptStringForRateBlog7 =
final string expectedPromptStringForRateBlog8 =
string `How would you rate this text blog out of 10, Title: ${blog1.title} Content: ${blog1.content} .${"\n"}You must call the ${"`"}getResults${"`"} tool to obtain the correct answer.`;

final string expectedPromptStringForRateBlog9 = string
final string expectedPromptStringForRateBlog9 = string
`How would you rate this text blogs out of 10. Title: ${blog1.title} Content: ${blog1.content} Title: ${blog1.title} Content: ${blog1.content} . Thank you!${"\n"}You must call the ${"`"}getResults${"`"} tool to obtain the correct answer.`;

final string expectedPromptStringForRateBlog10 = string `Evaluate this blogs out of 10.
Expand Down Expand Up @@ -101,7 +105,7 @@ const expectedParameterSchemaStringForRateBlog =
{"type": "object", "properties": {"result": {"type": "integer"}}};

const expectedParameterSchemaStringForRateBlog7 =
{"type":"object","properties":{"result":{"type":["integer", "null"]}}};
{"type": "object", "properties": {"result": {"type": ["integer", "null"]}}};

const expectedParameterSchemaStringForRateBlog2 =
{
Expand Down Expand Up @@ -162,6 +166,22 @@ const expectedParameterSchemaStringForRateBlog6 =
}
};

const expectedParameterSchemaStringForRateBlog8 =
{
"type": "object",
"properties": {
"result": {
"type": "array",
"items": {
"type": "string"
}
}
}
};

const expectedParameterSchemaStringForRateBlog9 =
{"type": "object", "properties": {"result": {"type": "string"}}};

const expectedParamterSchemaStringForBalProgram =
{"type": "object", "properties": {"result": {"type": "integer"}}};

Expand Down
30 changes: 30 additions & 0 deletions ballerina/tests/tests.bal
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,36 @@ function testGenerateMethodWithRecordArrayReturnType() returns error? {
test:assertEquals(result, [r, r]);
}

@test:Config
function testGenerateMethodWithImageDocument() returns ai:Error? {
ai:ImageDocument img = {
content: sampleBinaryData
};

ai:ImageDocument img2 = {
content: "https://example.com/sample-image.jpg"
};

ai:ImageDocument img3 = {
content: "<invalid-url>"
};

string|error description = ollamaProvider->generate(`Describe the following image.${img}.`);
test:assertEquals(description, "This is a sample image description.");

description = ollamaProvider->generate(`Describe this image.${img2}.`);
test:assertEquals(description, "This is a sample image description.");

string[]|error descriptions = ollamaProvider->generate(`Describe these images.${<ai:ImageDocument[]>[img, img2]}.`);
test:assertEquals(descriptions, ["This is a sample image description.", "This is a sample image description."]);

description = ollamaProvider->generate(`Describe this image. ${img3}.`);
if description is string {
test:assertFail();
}
test:assertEquals(description.message(), "Must be a valid URL.");
}

@test:Config
function testGenerateMethodWithInvalidBasicType() returns ai:Error? {
boolean|error rating = ollamaProvider->generate(`What is ${1} + ${1}?`);
Expand Down
Loading