diff --git a/ballerina/main.bal b/ballerina/main.bal index bac4f8e..bf60ef5 100644 --- a/ballerina/main.bal +++ b/ballerina/main.bal @@ -91,10 +91,11 @@ isolated function getPromptWithExpectedResponseSchema(string prompt, map e Schema: ${expectedResponseSchema.toJsonString()}`; -isolated function callLlmGeneric(Prompt prompt, Context context, typedesc targetType) returns json|error { +isolated function callLlmGeneric(Prompt prompt, Context context, typedesc targetType, + map? jsonSchema) returns json|error { Model model = context.model; - json resp = - check model->call(buildPromptString(prompt), generateJsonSchemaForTypedescAsJson(targetType)); + json resp = + check model->call(buildPromptString(prompt), jsonSchema ?: generateJsonSchemaForTypedescAsJson(targetType)); return parseResponseAsType(resp, targetType); } diff --git a/ballerina/tests/tests.bal b/ballerina/tests/tests.bal index 9b3ced8..13e1784 100644 --- a/ballerina/tests/tests.bal +++ b/ballerina/tests/tests.bal @@ -47,11 +47,10 @@ function testJsonConversionError() { test:assertTrue(( rating).message().includes(ERROR_MESSAGE)); } +type Foo record{| string name; |}; @test:Config function testJsonConversionError2() { - record{| - string name; - |}[]|error rating = callLlm(`Tell me name and the age of the top 10 world class cricketers`); + Foo[]|error rating = callLlm(`Tell me name and the age of the top 10 world class cricketers`); test:assertTrue(rating is error); test:assertTrue(( rating).message().includes(ERROR_MESSAGE)); } diff --git a/compiler-plugin/src/main/java/io/ballerina/lib/np/compilerplugin/PromptAsCodeCodeModificationTask.java b/compiler-plugin/src/main/java/io/ballerina/lib/np/compilerplugin/PromptAsCodeCodeModificationTask.java index 1cb01ff..8c0476b 100644 --- a/compiler-plugin/src/main/java/io/ballerina/lib/np/compilerplugin/PromptAsCodeCodeModificationTask.java +++ b/compiler-plugin/src/main/java/io/ballerina/lib/np/compilerplugin/PromptAsCodeCodeModificationTask.java @@ -110,7 +110,7 @@ public class PromptAsCodeCodeModificationTask implements ModifierTask typeSymbol = + semanticModel.expectedType(document, functionCallExpressionNode.lineRange().startLine()); + if (typeSymbol.isEmpty()) { + return functionCallExpressionNode; + } + getTypeSchema(typeSymbol.get(), analysisData.typeMapper, modifierData.typeSchemas); + return functionCallExpressionNode; + } + } + private static class FunctionModifier extends TreeModifier { private final ModifierData modifierData; @@ -437,22 +476,21 @@ private void extractAndStoreSchemas(SemanticModel semanticModel, FunctionDefinit } } - private void getTypeSchema(TypeSymbol memberType, TypeMapper typeMapper, Map typeSchemas) { + private static void getTypeSchema(TypeSymbol memberType, TypeMapper typeMapper, Map typeSchemas) { switch (memberType) { case TypeReferenceTypeSymbol typeReference -> typeSchemas.put(typeReference.definition().getName().get(), getJsonSchema(typeMapper.getSchema(typeReference))); - case ArrayTypeSymbol arrayType -> - getTypeSchema(arrayType.memberTypeDescriptor(), typeMapper, typeSchemas); - + getTypeSchema(arrayType.memberTypeDescriptor(), typeMapper, typeSchemas); case TupleTypeSymbol tupleType -> tupleType.members().forEach(member -> getTypeSchema(member.typeDescriptor(), typeMapper, typeSchemas)); - case RecordTypeSymbol recordType -> recordType.fieldDescriptors().values().forEach(field -> getTypeSchema(field.typeDescriptor(), typeMapper, typeSchemas)); + case UnionTypeSymbol unionTypeSymbol -> unionTypeSymbol.memberTypeDescriptors().forEach(member -> + getTypeSchema(member, typeMapper, typeSchemas)); default -> { } } } diff --git a/native/src/main/java/io/ballerina/lib/np/Native.java b/native/src/main/java/io/ballerina/lib/np/Native.java index 1adea1b..f8e1d60 100644 --- a/native/src/main/java/io/ballerina/lib/np/Native.java +++ b/native/src/main/java/io/ballerina/lib/np/Native.java @@ -17,11 +17,15 @@ import io.ballerina.runtime.api.Environment; import io.ballerina.runtime.api.Module; +import io.ballerina.runtime.api.creators.TypeCreator; import io.ballerina.runtime.api.creators.ValueCreator; import io.ballerina.runtime.api.flags.SymbolFlags; import io.ballerina.runtime.api.types.ArrayType; import io.ballerina.runtime.api.types.Field; +import io.ballerina.runtime.api.types.JsonType; +import io.ballerina.runtime.api.types.PredefinedTypes; import io.ballerina.runtime.api.types.RecordType; +import io.ballerina.runtime.api.types.TupleType; import io.ballerina.runtime.api.types.Type; import io.ballerina.runtime.api.types.TypeTags; import io.ballerina.runtime.api.types.UnionType; @@ -30,17 +34,122 @@ import io.ballerina.runtime.api.values.BArray; import io.ballerina.runtime.api.values.BMap; import io.ballerina.runtime.api.values.BObject; +import io.ballerina.runtime.api.values.BString; import io.ballerina.runtime.api.values.BTypedesc; +import java.util.Map; + +import static io.ballerina.runtime.api.creators.ValueCreator.createMapValue; + /** * Native implementation of natural programming functions. * * @since 0.3.0 */ public class Native { + + static Boolean isSchemaGeneratedAtCompileTime; + public static Object callLlm(Environment env, BObject prompt, BMap context, BTypedesc targetType) { + isSchemaGeneratedAtCompileTime = true; + Object jsonSchema = generateJsonSchemaForType(targetType.getDescribingType()); return env.getRuntime().callFunction( - new Module("ballerinax", "np", "0"), "callLlmGeneric", null, prompt, context, targetType); + new Module("ballerinax", "np", "0"), "callLlmGeneric", null, prompt, context, targetType, + isSchemaGeneratedAtCompileTime ? jsonSchema : null); + } + + public static Object generateJsonSchemaForType(Type td) { + Type type = TypeUtils.getReferredType(td); + if (isSimpleType(type)) { + return createSimpleTypeSchema(type); + } + + return switch (type) { + case RecordType recordType -> generateJsonSchemaForRecordType(recordType); + case JsonType ignored -> generateJsonSchemaForJson(); + case ArrayType arrayType -> generateJsonSchemaForArrayType(arrayType); + case TupleType tupleType -> generateJsonSchemaForTupleType(tupleType); + case UnionType unionType -> generateJsonSchemaForUnionType(unionType); + default -> null; + }; + } + + private static BMap createSimpleTypeSchema(Type type) { + BMap schemaMap = createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON)); + schemaMap.put(StringUtils.fromString("type"), StringUtils.fromString(getStringRepresentation(type))); + return schemaMap; + } + + private static BMap generateJsonSchemaForJson() { + BString[] bStringValues = new BString[6]; + bStringValues[0] = StringUtils.fromString("object"); + bStringValues[1] = StringUtils.fromString("array"); + bStringValues[2] = StringUtils.fromString("string"); + bStringValues[3] = StringUtils.fromString("number"); + bStringValues[4] = StringUtils.fromString("boolean"); + bStringValues[5] = StringUtils.fromString("null"); + BMap schemaMap = createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON)); + schemaMap.put(StringUtils.fromString("type"), ValueCreator.createArrayValue(bStringValues)); + return schemaMap; + } + + private static Object generateJsonSchemaForArrayType(ArrayType arrayType) { + BMap schemaMap = createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON)); + Type elementType = TypeUtils.getReferredType(arrayType.getElementType()); + schemaMap.put(StringUtils.fromString("type"), StringUtils.fromString("array")); + schemaMap.put(StringUtils.fromString("items"), generateJsonSchemaForType(elementType)); + return schemaMap; + } + + private static Object generateJsonSchemaForTupleType(TupleType tupleType) { + BMap schemaMap = createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON)); + schemaMap.put(StringUtils.fromString("type"), StringUtils.fromString("array")); + BArray annotationArray = ValueCreator.createArrayValue(TypeCreator.createArrayType(PredefinedTypes.TYPE_JSON)); + int index = 0; + for (Type type : tupleType.getTupleTypes()) { + annotationArray.add(index++, generateJsonSchemaForType(type)); + } + schemaMap.put(StringUtils.fromString("items"), annotationArray); + return schemaMap; + } + + private static boolean isSimpleType(Type type) { + return type.getBasicType().all() <= 0b100000; + } + + private static String getStringRepresentation(Type type) { + return switch (type.getBasicType().all()) { + case 0b000000 -> "null"; + case 0b000010 -> "boolean"; + case 0b000100 -> "integer"; + case 0b001000, 0b010000 -> "number"; + case 0b100000 -> "string"; + default -> null; + }; + } + + private static Object generateJsonSchemaForRecordType(RecordType recordType) { + for (Map.Entry entry : recordType.getAnnotations().entrySet()) { + if ("ballerinax/np:0:Schema".equals(entry.getKey().getValue())) { + return entry.getValue(); + } + } + isSchemaGeneratedAtCompileTime = false; + return null; + } + + private static Object generateJsonSchemaForUnionType(UnionType unionType) { + BMap schemaMap = createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON)); + schemaMap.put(StringUtils.fromString("type"), StringUtils.fromString("object")); + BArray annotationArray = ValueCreator.createArrayValue(TypeCreator.createArrayType(PredefinedTypes.TYPE_JSON)); + + int index = 0; + for (Type type : unionType.getMemberTypes()) { + annotationArray.add(index++, generateJsonSchemaForType(type)); + } + + schemaMap.put(StringUtils.fromString("anyOf"), annotationArray); + return schemaMap; } // Simple, simple, SIMPLE implementation for now.