From 80e9671cfce50912d3c13d8738b233b48ed49385 Mon Sep 17 00:00:00 2001 From: Thevakumar-Luheerathan Date: Thu, 24 Jul 2025 14:17:26 +0530 Subject: [PATCH 1/3] [Automated] Update the toml files --- ballerina/Dependencies.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index db97044..7222cb0 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -5,12 +5,12 @@ [ballerina] dependencies-toml-version = "2" -distribution-version = "2201.12.0" +distribution-version = "2201.12.7" [[package]] org = "ballerina" name = "ai" -version = "1.1.0" +version = "1.1.1" dependencies = [ {org = "ballerina", name = "constraint"}, {org = "ballerina", name = "data.jsondata"}, From 056405108989968f522a4e07cda667236f886b9b Mon Sep 17 00:00:00 2001 From: Thevakumar-Luheerathan Date: Thu, 24 Jul 2025 14:35:29 +0530 Subject: [PATCH 2/3] Improve support for runtime JSON schema generation --- ballerina/tests/test_utils.bal | 118 +++++++++++++ ballerina/tests/test_values.bal | 165 ++++++++++++++++++ ballerina/tests/tests.bal | 86 ++++++++- .../GenerateMethodModificationTask.java | 52 +++--- .../java/io/ballerina/lib/ollama/Native.java | 52 ++++++ 5 files changed, 444 insertions(+), 29 deletions(-) diff --git a/ballerina/tests/test_utils.bal b/ballerina/tests/test_utils.bal index 6a0a9ab..d91fb76 100644 --- a/ballerina/tests/test_utils.bal +++ b/ballerina/tests/test_utils.bal @@ -91,6 +91,38 @@ isolated function getExpectedParameterSchema(string message) returns map { }; } + if message.startsWith("Give me a random joke about cricketers") { + return expectedParameterSchemaForRecUnionBasicType; + } + + if message.startsWith("Give me a random joke") { + return {"type":"object","properties":{"result":{"anyOf":[{"type":"string"},{"type":"null"}]}}}; + } + + if message.startsWith("Name a random world class cricketer in India") { + return expectedParameterSchemaForRecUnionNull; + } + + if message.startsWith("Name 10 world class cricketers in India") { + return expectedParameterSchemaForArrayOnly; + } + + if message.startsWith("Name 10 world class cricketers as string") { + return expectedParameterSchemaForArrayUnionBasicType; + } + + if message.startsWith("Name top 10 world class cricketers") { + return expectedParameterSchemaForArrayUnionRec; + } + + if message.startsWith("Name a random world class cricketer") { + return expectedParameterSchemaForArrayUnionRec; + } + + if message.startsWith("Name 10 world class cricketers") { + return expectedParamSchemaForArrayUnionNull; + } + return {}; } @@ -167,6 +199,52 @@ isolated function getTheMockLLMResult(string message) returns map { } } + if message.startsWith("Name a random world class cricketer in India") { + return {"result": {"name": "Sanga"}}; + } + + if message.startsWith("Name a random world class cricketer") { + return {"result": {"name": "Sanga"}}; + } + + if message.startsWith("Name 10 world class cricketers") { + return { + "result": [ + {"name": "Virat Kohli"}, + {"name": "Joe Root"}, + {"name": "Steve Smith"}, + {"name": "Kane Williamson"}, + {"name": "Babar Azam"}, + {"name": "Ben Stokes"}, + {"name": "Jasprit Bumrah"}, + {"name": "Pat Cummins"}, + {"name": "Shaheen Afridi"}, + {"name": "Rashid Khan"} + ] + }; + } + + if message.startsWith("Name top 10 world class cricketers") { + return { + "result": [ + {"name": "Virat Kohli"}, + {"name": "Joe Root"}, + {"name": "Steve Smith"}, + {"name": "Kane Williamson"}, + {"name": "Babar Azam"}, + {"name": "Ben Stokes"}, + {"name": "Jasprit Bumrah"}, + {"name": "Pat Cummins"}, + {"name": "Shaheen Afridi"}, + {"name": "Rashid Khan"} + ] + }; + } + + if message.startsWith("Give me a random joke") { + return {"result": "This is a random joke"}; + } + return {}; } @@ -246,5 +324,45 @@ isolated function getExpectedPrompt(string message) returns string { their name?`; } + if message.startsWith("Name 10 world class cricketers in India") { + return "Name 10 world class cricketers in India\nYou must call the `getResults`" + + " tool to obtain the correct answer."; + } + + if message.startsWith("Name 10 world class cricketers as string") { + return "Name 10 world class cricketers as string\nYou must call the `getResults`" + + " tool to obtain the correct answer."; + } + + if message.startsWith("Name 10 world class cricketers") { + return "Name 10 world class cricketers\nYou must call the `getResults`" + + " tool to obtain the correct answer."; + } + + if message.startsWith("Name top 10 world class cricketers") { + return "Name top 10 world class cricketers\nYou must call the `getResults`" + + " tool to obtain the correct answer."; + } + + if message.startsWith("Name a random world class cricketer in India") { + return "Name a random world class cricketer in India\nYou must call the `getResults`" + + " tool to obtain the correct answer."; + } + + if message.startsWith("Name a random world class cricketer") { + return "Name a random world class cricketer\nYou must call the `getResults`" + + " tool to obtain the correct answer."; + } + + if message.startsWith("Give me a random joke about cricketers") { + return "Give me a random joke about cricketers\nYou must call the `getResults`" + + " tool to obtain the correct answer."; + } + + if message.startsWith("Give me a random joke") { + return "Give me a random joke\nYou must call the `getResults`" + + " tool to obtain the correct answer."; + } + return "INVALID"; } diff --git a/ballerina/tests/test_values.bal b/ballerina/tests/test_values.bal index cfafc89..0a9c5dd 100644 --- a/ballerina/tests/test_values.bal +++ b/ballerina/tests/test_values.bal @@ -167,3 +167,168 @@ const expectedParamterSchemaStringForBalProgram = const expectedParamterSchemaStringForCountry = {"type": "object", "properties": {"result": {"type": "string"}}}; + + + +const expectedParamSchemaForArrayUnionNull = + { + "type": "object", + "properties": { + "result": { + "anyOf": [ + { + "type": "array", + "items": { + "required": [ + "name" + ], + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + } + }, + { + "type": "null" + } + ] + } + } + }; + +const expectedParameterSchemaForArrayUnionRec = + { + "type": "object", + "properties": { + "result": { + "anyOf": [ + { + "type": "array", + "items": { + "required": [ + "name" + ], + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + } + }, + { + "required": [ + "name" + ], + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + } + ] + } + } + }; + + const expectedParameterSchemaForArrayUnionBasicType = + { + "type": "object", + "properties": { + "result": { + "anyOf": [ + { + "type": "array", + "items": { + "required": [ + "name" + ], + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + } + }, + { + "type": "string" + } + ] + } + } + }; + +const expectedParameterSchemaForArrayOnly = + { + "type": "object", + "properties": { + "result": { + "type": "array", + "items": { + "required": [ + "name" + ], + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + } + } + } + }; + +const expectedParameterSchemaForRecUnionBasicType = + { + "type": "object", + "properties": { + "result": { + "anyOf": [ + { + "required": [ + "name" + ], + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + }, + { + "type": "string" + } + ] + } + } + }; + +const expectedParameterSchemaForRecUnionNull = + { + "type": "object", + "properties": { + "result": { + "anyOf": [ + { + "required": [ + "name" + ], + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + }, + { + "type": "null" + } + ] + } + } + }; diff --git a/ballerina/tests/tests.bal b/ballerina/tests/tests.bal index 99e4691..9307a59 100644 --- a/ballerina/tests/tests.bal +++ b/ballerina/tests/tests.bal @@ -136,7 +136,7 @@ type ProductName record {| @test:Config function testGenerateMethodWithInvalidRecordType() returns ai:Error? { - ProductName[]|error rating = trap ollamaProvider->generate( + ProductName[]|map|error rating = trap ollamaProvider->generate( `Tell me name and the age of the top 10 world class cricketers`); string msg = (rating).message(); test:assertTrue(rating is error); @@ -153,3 +153,87 @@ function testGenerateMethodWithInvalidRecordArrayType2() returns ai:Error? { test:assertTrue(rating is error); test:assertTrue((rating).message().includes(ERROR_MESSAGE)); } + +type Cricketers record {| + string name; +|}; + +type Cricketers1 record {| + string name; +|}; + +type Cricketers2 record {| + string name; +|}; + +type Cricketers3 record {| + string name; +|}; + +type Cricketers4 record {| + string name; +|}; + +type Cricketers5 record {| + string name; +|}; + +type Cricketers6 record {| + string name; +|}; + +type Cricketers7 record {| + string name; +|}; + +type Cricketers8 record {| + string name; +|}; + +@test:Config +function testGenerateMethodWithStringUnionNull() returns error? { + string? result = check ollamaProvider->generate(`Give me a random joke`); + test:assertTrue(result is string); +} + +@test:Config +function testGenerateMethodWithRecUnionBasicType() returns error? { + Cricketers|string result = check ollamaProvider->generate(`Give me a random joke about cricketers`); + test:assertTrue(result is string); +} + +@test:Config +function testGenerateMethodWithRecUnionNull() returns error? { + Cricketers1? result = check ollamaProvider->generate(`Name a random world class cricketer in India`); + test:assertTrue(result is Cricketers1); +} + +@test:Config +function testGenerateMethodWithArrayOnly() returns error? { + Cricketers2[] result = check ollamaProvider->generate(`Name 10 world class cricketers in India`); + test:assertTrue(result is Cricketers2[]); +} + +@test:Config +function testGenerateMethodWithArrayUnionBasicType() returns error? { + Cricketers3[]|string result = check ollamaProvider->generate(`Name 10 world class cricketers as string`); + test:assertTrue(result is Cricketers3[]); +} + +@test:Config +function testGenerateMethodWithArrayUnionNull() returns error? { + Cricketers4[]? result = check ollamaProvider->generate(`Name 10 world class cricketers`); + test:assertTrue(result is Cricketers4[]); +} + +@test:Config +function testGenerateMethodWithArrayUnionRecord() returns ai:Error? { + Cricketers5[]|Cricketers6|error result = ollamaProvider->generate(`Name top 10 world class cricketers`); + test:assertTrue(result is Cricketers5[]); +} + +@test:Config +function testGenerateMethodWithArrayUnionRecord2() returns ai:Error? { + Cricketers7[]|Cricketers8|error result = ollamaProvider->generate(`Name a random world class cricketer`); + test:assertTrue(result is Cricketers8); +} diff --git a/compiler-plugin/src/main/java/io/ballerina/lib/ai/ollama/GenerateMethodModificationTask.java b/compiler-plugin/src/main/java/io/ballerina/lib/ai/ollama/GenerateMethodModificationTask.java index e29ac2d..9a25180 100644 --- a/compiler-plugin/src/main/java/io/ballerina/lib/ai/ollama/GenerateMethodModificationTask.java +++ b/compiler-plugin/src/main/java/io/ballerina/lib/ai/ollama/GenerateMethodModificationTask.java @@ -20,10 +20,13 @@ import io.ballerina.compiler.api.ModuleID; import io.ballerina.compiler.api.SemanticModel; +import io.ballerina.compiler.api.symbols.ArrayTypeSymbol; import io.ballerina.compiler.api.symbols.ClassSymbol; import io.ballerina.compiler.api.symbols.ErrorTypeSymbol; import io.ballerina.compiler.api.symbols.ModuleSymbol; +import io.ballerina.compiler.api.symbols.RecordTypeSymbol; import io.ballerina.compiler.api.symbols.Symbol; +import io.ballerina.compiler.api.symbols.TupleTypeSymbol; import io.ballerina.compiler.api.symbols.TypeReferenceTypeSymbol; import io.ballerina.compiler.api.symbols.TypeSymbol; import io.ballerina.compiler.api.symbols.UnionTypeSymbol; @@ -239,38 +242,31 @@ public void visit(RemoteMethodCallActionNode remoteMethodCallActionNode) { } private void updateTypeSchemaForTypeDef(RemoteMethodCallActionNode remoteMethodCallActionNode) { - semanticModel.typeOf(remoteMethodCallActionNode).ifPresent(expTypeSymbol -> { - updateTypeSchema(expTypeSymbol); - }); + semanticModel.typeOf(remoteMethodCallActionNode).ifPresent(symbol -> populateTypeSchema(symbol, + this.typeMapper, modifierData.typeSchemas, this.semanticModel.types().ANYDATA)); } - private void updateTypeSchema(TypeSymbol expTypeSymbol) { - if (expTypeSymbol instanceof UnionTypeSymbol expTypeUnionSymbol) { - TypeSymbol nonErrorTypeSymbol = null; - TypeSymbol typeRefTypeSymbol = null; - List memberTypeSymbols = expTypeUnionSymbol.memberTypeDescriptors(); - for (TypeSymbol memberTypeSymbol: memberTypeSymbols) { - if (memberTypeSymbol instanceof TypeReferenceTypeSymbol typeReferenceTypeSymbol) { - typeRefTypeSymbol = typeReferenceTypeSymbol.typeDescriptor(); - } - - if (!(typeRefTypeSymbol instanceof ErrorTypeSymbol)) { - nonErrorTypeSymbol = memberTypeSymbol; + private static void populateTypeSchema(TypeSymbol memberType, TypeMapper typeMapper, + Map typeSchemas, TypeSymbol anydataType) { + switch (memberType) { + case TypeReferenceTypeSymbol typeReference -> { + if (!typeReference.subtypeOf(anydataType)) { + return; } + typeSchemas.put(typeReference.definition().getName().get(), + getJsonSchema(typeMapper.getSchema(typeReference))); } - - if (!(nonErrorTypeSymbol instanceof TypeReferenceTypeSymbol)) { - return; - } - populateTypeSchema(nonErrorTypeSymbol, typeMapper, modifierData.typeSchemas); - } - } - - private static void populateTypeSchema(TypeSymbol memberType, TypeMapper typeMapper, - Map typeSchemas) { - if (Objects.requireNonNull(memberType) instanceof TypeReferenceTypeSymbol typeReference) { - typeSchemas.put(typeReference.definition().getName().get(), - getJsonSchema(typeMapper.getSchema(typeReference))); + case ArrayTypeSymbol arrayType -> + populateTypeSchema(arrayType.memberTypeDescriptor(), typeMapper, typeSchemas, anydataType); + case TupleTypeSymbol tupleType -> + tupleType.members().forEach(member -> + populateTypeSchema(member.typeDescriptor(), typeMapper, typeSchemas, anydataType)); + case RecordTypeSymbol recordType -> + recordType.fieldDescriptors().values().forEach(field -> + populateTypeSchema(field.typeDescriptor(), typeMapper, typeSchemas, anydataType)); + case UnionTypeSymbol unionTypeSymbol -> unionTypeSymbol.memberTypeDescriptors().forEach(member -> + populateTypeSchema(member, typeMapper, typeSchemas, anydataType)); + default -> { } } } diff --git a/native/src/main/java/io/ballerina/lib/ollama/Native.java b/native/src/main/java/io/ballerina/lib/ollama/Native.java index 0844438..7323d12 100644 --- a/native/src/main/java/io/ballerina/lib/ollama/Native.java +++ b/native/src/main/java/io/ballerina/lib/ollama/Native.java @@ -19,17 +19,24 @@ import io.ballerina.runtime.api.creators.ErrorCreator; import io.ballerina.runtime.api.creators.TypeCreator; import io.ballerina.runtime.api.creators.ValueCreator; +import io.ballerina.runtime.api.types.AnnotatableType; import io.ballerina.runtime.api.types.ArrayType; import io.ballerina.runtime.api.types.JsonType; import io.ballerina.runtime.api.types.PredefinedTypes; +import io.ballerina.runtime.api.types.ReferenceType; import io.ballerina.runtime.api.types.Type; +import io.ballerina.runtime.api.types.TypeTags; +import io.ballerina.runtime.api.types.UnionType; import io.ballerina.runtime.api.utils.StringUtils; import io.ballerina.runtime.api.utils.TypeUtils; +import io.ballerina.runtime.api.values.BArray; import io.ballerina.runtime.api.values.BError; import io.ballerina.runtime.api.values.BMap; import io.ballerina.runtime.api.values.BString; import io.ballerina.runtime.api.values.BTypedesc; +import java.util.List; + import static io.ballerina.runtime.api.creators.ValueCreator.createMapValue; /** @@ -38,6 +45,10 @@ * @since 1.0.0 */ public class Native { + public static final String ANY_OF = "anyOf"; + public static final String BALLERINA_AI = "ballerina/ai"; + public static final String JSON_SCHEMA = "JsonSchema"; + public static Object generateJsonSchemaForTypedescNative(BTypedesc td) { SchemaGenerationContext schemaGenerationContext = new SchemaGenerationContext(); try { @@ -58,6 +69,9 @@ private static Object generateJsonSchemaForType(Type t, SchemaGenerationContext return switch (impliedType) { case JsonType ignored -> generateJsonSchemaForJson(); case ArrayType arrayType -> generateJsonSchemaForArrayType(arrayType, schemaGenerationContext); + case UnionType unionType -> generateUnionTypeSchema(unionType, schemaGenerationContext); + case ReferenceType referenceType -> getJsonSchemaFromAnnotatableType(referenceType, + schemaGenerationContext); default -> throw ErrorCreator.createError(StringUtils.fromString( "Runtime schema generation is not yet supported for type: " + impliedType.getName())); }; @@ -91,7 +105,45 @@ private static boolean isSimpleType(Type type) { return type.getBasicType().all() <= 0b100000; } + private static Object generateUnionTypeSchema(UnionType unionType, + SchemaGenerationContext schemaGenerationContext) { + BMap schemaMap = createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON)); + List memberTypes = unionType.getMemberTypes(); + BArray schemas = ValueCreator.createArrayValue( + TypeCreator.createArrayType(PredefinedTypes.TYPE_JSON)); + for (Type memberType : memberTypes) { + Object schema = generateJsonSchemaForType(memberType, schemaGenerationContext); + schemas.append(schema); + } + if (schemas.size() == 1) { + return schemas.get(0); + } + schemaMap.put(StringUtils.fromString(ANY_OF), schemas); + return schemaMap; + } + + private static Object getJsonSchemaFromAnnotatableType(ReferenceType referenceType, + SchemaGenerationContext schemaGenerationContext) { + Type referredType = referenceType.getReferredType(); + if (referredType instanceof AnnotatableType annotatableType) { + BMap annotations = annotatableType.getAnnotations(); + for (BString key : annotations.getKeys()) { + if (key.getValue().startsWith(BALLERINA_AI) && key.getValue().endsWith(JSON_SCHEMA)) { + Object schema = annotations.get(key); + if (schema instanceof BMap) { + return schema; + } + } + } + } + throw ErrorCreator.createError(StringUtils.fromString( + "Runtime schema generation is not yet supported for type: " + referenceType.getName())); + } + private static String getStringRepresentation(Type type) { + if (type.getTag() == TypeTags.NULL_TAG) { + return "null"; + } return switch (type.getBasicType().all()) { case 0b000000 -> "null"; case 0b000010 -> "boolean"; From f75fcb0050c847d758c92aa67785b2bc19bc26c7 Mon Sep 17 00:00:00 2001 From: Thevakumar-Luheerathan Date: Thu, 24 Jul 2025 16:50:07 +0530 Subject: [PATCH 3/3] Fix build failure --- ballerina/tests/test_values.bal | 2 +- .../ballerina/lib/ai/ollama/GenerateMethodModificationTask.java | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/ballerina/tests/test_values.bal b/ballerina/tests/test_values.bal index 0a9c5dd..0dfec83 100644 --- a/ballerina/tests/test_values.bal +++ b/ballerina/tests/test_values.bal @@ -234,7 +234,7 @@ const expectedParameterSchemaForArrayUnionRec = } }; - const expectedParameterSchemaForArrayUnionBasicType = +const expectedParameterSchemaForArrayUnionBasicType = { "type": "object", "properties": { diff --git a/compiler-plugin/src/main/java/io/ballerina/lib/ai/ollama/GenerateMethodModificationTask.java b/compiler-plugin/src/main/java/io/ballerina/lib/ai/ollama/GenerateMethodModificationTask.java index 9a25180..be2802c 100644 --- a/compiler-plugin/src/main/java/io/ballerina/lib/ai/ollama/GenerateMethodModificationTask.java +++ b/compiler-plugin/src/main/java/io/ballerina/lib/ai/ollama/GenerateMethodModificationTask.java @@ -22,7 +22,6 @@ import io.ballerina.compiler.api.SemanticModel; import io.ballerina.compiler.api.symbols.ArrayTypeSymbol; import io.ballerina.compiler.api.symbols.ClassSymbol; -import io.ballerina.compiler.api.symbols.ErrorTypeSymbol; import io.ballerina.compiler.api.symbols.ModuleSymbol; import io.ballerina.compiler.api.symbols.RecordTypeSymbol; import io.ballerina.compiler.api.symbols.Symbol; @@ -69,7 +68,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; import static io.ballerina.projects.util.ProjectConstants.EMPTY_STRING;