diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index 7222cb0..8f6f92a 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -5,7 +5,7 @@ [ballerina] dependencies-toml-version = "2" -distribution-version = "2201.12.7" +distribution-version = "2201.12.0" [[package]] org = "ballerina" @@ -65,6 +65,9 @@ version = "1.7.0" dependencies = [ {org = "ballerina", name = "jballerina.java"} ] +modules = [ + {org = "ballerina", packageName = "constraint", moduleName = "constraint"} +] [[package]] org = "ballerina" @@ -189,6 +192,9 @@ dependencies = [ {org = "ballerina", name = "jballerina.java"}, {org = "ballerina", name = "lang.__internal"} ] +modules = [ + {org = "ballerina", packageName = "lang.array", moduleName = "lang.array"} +] [[package]] org = "ballerina" @@ -391,9 +397,11 @@ name = "ai.ollama" version = "1.0.1" dependencies = [ {org = "ballerina", name = "ai"}, + {org = "ballerina", name = "constraint"}, {org = "ballerina", name = "data.jsondata"}, {org = "ballerina", name = "http"}, {org = "ballerina", name = "jballerina.java"}, + {org = "ballerina", name = "lang.array"}, {org = "ballerina", name = "test"} ] modules = [ diff --git a/ballerina/provider_utils.bal b/ballerina/provider_utils.bal index 46bcf90..d858d8b 100644 --- a/ballerina/provider_utils.bal +++ b/ballerina/provider_utils.bal @@ -15,6 +15,7 @@ // under the License. import ballerina/ai; +import ballerina/constraint; import ballerina/http; type ResponseSchema record {| @@ -92,35 +93,57 @@ isolated function getGetResultsTool(map parameters) returns map[]|er isolated function generateChatCreationContent(ai:Prompt prompt) returns string|ai:Error { string[] & readonly strings = prompt.strings; anydata[] insertions = prompt.insertions; - string promptStr = strings[0]; - foreach int i in 0 ..< insertions.length() { - string str = strings[i + 1]; - anydata insertion = insertions[i]; + string promptStr = ""; - if insertion is ai:TextDocument { - promptStr += insertion.content + " " + str; - continue; - } + if strings.length() > 0 { + promptStr += strings[0]; + } - if insertion is ai:TextDocument[] { - foreach ai:TextDocument doc in insertion { - promptStr += doc.content + " "; - - } - promptStr += str; - continue; - } + foreach int i in 0 ..< insertions.length() { + anydata insertion = insertions[i]; + string str = strings[i + 1]; if insertion is ai:Document { - return error ai:Error("Only Text Documents are currently supported."); + if insertion is ai:TextDocument { + promptStr += insertion.content + " "; + } else if insertion is ai:ImageDocument { + promptStr += check addImageContentPart(insertion); + } else { + return error ai:Error("Only Text and Image Documents are currently supported."); + } + } else if insertion is ai:Document[] { + foreach ai:Document doc in insertion { + if doc is ai:TextDocument { + promptStr += doc.content + " "; + } else if doc is ai:ImageDocument { + promptStr += check addImageContentPart(doc); + } else { + return error ai:Error("Only Text and Image Documents are currently supported."); + } + } + } else { + promptStr += insertion.toString(); } - - promptStr += insertion.toString() + str; + promptStr += str; } + promptStr += addToolDirective(); return promptStr.trim(); } +isolated function addImageContentPart(ai:ImageDocument doc) returns string|ai:Error { + ai:Url|byte[] content = doc.content; + if content is ai:Url { + ai:Url|constraint:Error validationRes = constraint:validate(content); + if validationRes is error { + return error(validationRes.message(), validationRes.cause()); + } + return string ` ${content} `; + } + + return string ` ${content.toBase64()} `; +} + isolated function addToolDirective() returns string { return "\nYou must call the `getResults` tool to obtain the correct answer."; } diff --git a/ballerina/tests/test_utils.bal b/ballerina/tests/test_utils.bal index d91fb76..b6155d3 100644 --- a/ballerina/tests/test_utils.bal +++ b/ballerina/tests/test_utils.bal @@ -51,6 +51,18 @@ isolated function getExpectedParameterSchema(string message) returns map { return expectedParameterSchemaStringForRateBlog2; } + if message.startsWith("Describe the following image.") { + return expectedParameterSchemaStringForRateBlog9; + } + + if message.startsWith("Describe this image.") { + return expectedParameterSchemaStringForRateBlog9; + } + + if message.startsWith("Describe these images.") { + return expectedParameterSchemaStringForRateBlog8; + } + if message.startsWith("How do you rate this blog") { return expectedParameterSchemaStringForRateBlog7; } @@ -199,6 +211,18 @@ isolated function getTheMockLLMResult(string message) returns map { } } + if message.startsWith("Describe the following image.") { + return {"result": "This is a sample image description."}; + } + + if message.startsWith("Describe this image.") { + return {"result": "This is a sample image description."}; + } + + if message.startsWith("Describe these images.") { + return {"result": ["This is a sample image description.", "This is a sample image description."]}; + } + if message.startsWith("Name a random world class cricketer in India") { return {"result": {"name": "Sanga"}}; } @@ -324,6 +348,18 @@ isolated function getExpectedPrompt(string message) returns string { their name?`; } + if message.startsWith("Describe the following image.") { + return string `Describe the following image. ${sampleStringData} .${"\n"}You must call the ${"`"}getResults${"`"} tool to obtain the correct answer.`; + } + + if message.startsWith("Describe this image.") { + return "Describe this image. https://example.com/sample-image.jpg .\nYou must call the `getResults` tool to obtain the correct answer."; + } + + if message.startsWith("Describe these images.") { + return string `Describe these images. ${sampleStringData} https://example.com/sample-image.jpg .${"\n"}You must call the ${"`"}getResults${"`"} tool to obtain the correct answer.`; + } + if message.startsWith("Name 10 world class cricketers in India") { return "Name 10 world class cricketers in India\nYou must call the `getResults`" + " tool to obtain the correct answer."; diff --git a/ballerina/tests/test_values.bal b/ballerina/tests/test_values.bal index 0dfec83..065e042 100644 --- a/ballerina/tests/test_values.bal +++ b/ballerina/tests/test_values.bal @@ -13,6 +13,7 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +import ballerina/lang.array; type Blog record { string title; @@ -24,6 +25,9 @@ type Review record {| string comment; |}; +final readonly & byte[] sampleBinaryData = [0x01, 0x02, 0x03, 0x04, 0x05]; +final readonly & string sampleStringData = array:toBase64(sampleBinaryData); + const blog1 = { // Generated. title: "Tips for Growing a Beautiful Garden", @@ -70,7 +74,7 @@ final string expectedPromptStringForRateBlog7 = final string expectedPromptStringForRateBlog8 = string `How would you rate this text blog out of 10, Title: ${blog1.title} Content: ${blog1.content} .${"\n"}You must call the ${"`"}getResults${"`"} tool to obtain the correct answer.`; -final string expectedPromptStringForRateBlog9 = string +final string expectedPromptStringForRateBlog9 = string `How would you rate this text blogs out of 10. Title: ${blog1.title} Content: ${blog1.content} Title: ${blog1.title} Content: ${blog1.content} . Thank you!${"\n"}You must call the ${"`"}getResults${"`"} tool to obtain the correct answer.`; final string expectedPromptStringForRateBlog10 = string `Evaluate this blogs out of 10. @@ -101,7 +105,7 @@ const expectedParameterSchemaStringForRateBlog = {"type": "object", "properties": {"result": {"type": "integer"}}}; const expectedParameterSchemaStringForRateBlog7 = - {"type":"object","properties":{"result":{"type":["integer", "null"]}}}; + {"type": "object", "properties": {"result": {"type": ["integer", "null"]}}}; const expectedParameterSchemaStringForRateBlog2 = { @@ -162,6 +166,22 @@ const expectedParameterSchemaStringForRateBlog6 = } }; +const expectedParameterSchemaStringForRateBlog8 = + { + "type": "object", + "properties": { + "result": { + "type": "array", + "items": { + "type": "string" + } + } + } +}; + +const expectedParameterSchemaStringForRateBlog9 = + {"type": "object", "properties": {"result": {"type": "string"}}}; + const expectedParamterSchemaStringForBalProgram = {"type": "object", "properties": {"result": {"type": "integer"}}}; diff --git a/ballerina/tests/tests.bal b/ballerina/tests/tests.bal index 9307a59..0742804 100644 --- a/ballerina/tests/tests.bal +++ b/ballerina/tests/tests.bal @@ -123,6 +123,49 @@ function testGenerateMethodWithRecordArrayReturnType() returns error? { test:assertEquals(result, [r, r]); } +@test:Config +function testGenerateMethodWithImageDocument() returns ai:Error? { + ai:ImageDocument img = { + content: sampleBinaryData + }; + + ai:ImageDocument img2 = { + content: "https://example.com/sample-image.jpg" + }; + + ai:ImageDocument img3 = { + content: "" + }; + + string|error description = ollamaProvider->generate(`Describe the following image.${img}.`); + test:assertEquals(description, "This is a sample image description."); + + description = ollamaProvider->generate(`Describe this image.${img2}.`); + test:assertEquals(description, "This is a sample image description."); + + string[]|error descriptions = ollamaProvider->generate(`Describe these images.${[img, img2]}.`); + test:assertEquals(descriptions, ["This is a sample image description.", "This is a sample image description."]); + + description = ollamaProvider->generate(`Describe this image. ${img3}.`); + if description is string { + test:assertFail(); + } + test:assertEquals(description.message(), "Must be a valid URL."); +} + +@test:Config +function testGenerateMethodWithInvalidDocument() returns ai:Error? { + ai:AudioDocument aud = { + content: sampleBinaryData + }; + + string|error description = ollamaProvider->generate(`Describe this image. ${aud}.`); + if description is string { + test:assertFail(); + } + test:assertEquals(description.message(), "Only Text and Image Documents are currently supported."); +} + @test:Config function testGenerateMethodWithInvalidBasicType() returns ai:Error? { boolean|error rating = ollamaProvider->generate(`What is ${1} + ${1}?`);