-
Notifications
You must be signed in to change notification settings - Fork 4
Add runtime Json schema generation #39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b15b83c
b535d5f
4f0bea7
ad335a9
c78b2c0
dae5175
bdfa730
a1d8dde
d39a920
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -110,7 +110,7 @@ public class PromptAsCodeCodeModificationTask implements ModifierTask<SourceModi | |
| NodeFactory.createSimpleNameReferenceNode(NodeFactory.createIdentifierToken(CONTEXT_VAR)); | ||
|
|
||
| private final ModifierData modifierData; | ||
| private final CodeModifier.AnalysisData analysisData; | ||
| private static CodeModifier.AnalysisData analysisData; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why was this changed? Why can't this be used similar to modifierData? Note the warnings also - Static member 'io.ballerina.lib.np.compilerplugin.PromptAsCodeCodeModificationTask.analysisData' accessed via instance reference |
||
|
|
||
| PromptAsCodeCodeModificationTask(CodeModifier.AnalysisData analysisData) { | ||
| this.modifierData = new ModifierData(); | ||
|
|
@@ -136,12 +136,20 @@ public void modify(SourceModifierContext modifierContext) { | |
|
|
||
| for (DocumentId documentId: module.documentIds()) { | ||
| Document document = module.document(documentId); | ||
| modifierContext.modifySourceFile(modifyDocument(document, modifierData), documentId); | ||
| modifierContext.modifySourceFile(modifyDocument(document, modifierData, modifierContext, moduleId), | ||
| documentId); | ||
| } | ||
|
|
||
| for (DocumentId documentId: module.testDocumentIds()) { | ||
| Document document = module.document(documentId); | ||
| modifierContext.modifyTestSourceFile(modifyDocument(document, modifierData), documentId); | ||
| processImportDeclarations(document, modifierData); | ||
| processExternalFunctions(document, module, modifierData, modifierContext); | ||
| } | ||
|
|
||
| for (DocumentId documentId: module.testDocumentIds()) { | ||
| Document document = module.document(documentId); | ||
| modifierContext.modifyTestSourceFile(modifyDocument(document, modifierData, modifierContext, moduleId), | ||
| documentId); | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -171,12 +179,16 @@ private static void processImportDeclarations(Document document, ModifierData mo | |
| modulePartNode.apply(importDeclarationModifier); | ||
| } | ||
|
|
||
| private static TextDocument modifyDocument(Document document, ModifierData modifierData) { | ||
| private static TextDocument modifyDocument(Document document, ModifierData modifierData, | ||
| SourceModifierContext modifierContext, ModuleId moduleId) { | ||
| ModulePartNode modulePartNode = document.syntaxTree().rootNode(); | ||
| FunctionModifier functionModifier = new FunctionModifier(modifierData); | ||
| FunctionCallModifier functionCallModifier = | ||
| new FunctionCallModifier(modifierData, modifierContext, moduleId, document); | ||
| TypeDefinitionModifier typeDefinitionModifier = new TypeDefinitionModifier(modifierData.typeSchemas, | ||
| modifierData); | ||
|
|
||
| modulePartNode.apply(functionCallModifier); | ||
| ModulePartNode modifiedRoot = (ModulePartNode) modulePartNode.apply(functionModifier); | ||
| modifiedRoot = modifiedRoot.modify(modifiedRoot.imports(), modifiedRoot.members(), modifiedRoot.eofToken()); | ||
|
|
||
|
|
@@ -214,6 +226,33 @@ public ImportDeclarationNode transform(ImportDeclarationNode importDeclarationNo | |
| } | ||
| } | ||
|
|
||
| private static class FunctionCallModifier extends TreeModifier { | ||
| private final ModifierData modifierData; | ||
| private final SourceModifierContext modifierContext; | ||
| private final ModuleId moduleId; | ||
| private final Document document; | ||
|
|
||
| FunctionCallModifier(ModifierData modifierData, SourceModifierContext modifierContext, ModuleId moduleId, | ||
| Document document) { | ||
| this.modifierData = modifierData; | ||
| this.modifierContext = modifierContext; | ||
| this.moduleId = moduleId; | ||
| this.document = document; | ||
| } | ||
|
|
||
| @Override | ||
| public FunctionCallExpressionNode transform(FunctionCallExpressionNode functionCallExpressionNode) { | ||
| SemanticModel semanticModel = modifierContext.compilation().getSemanticModel(moduleId); | ||
| Optional<TypeSymbol> typeSymbol = | ||
| semanticModel.expectedType(document, functionCallExpressionNode.lineRange().startLine()); | ||
| if (typeSymbol.isEmpty()) { | ||
| return functionCallExpressionNode; | ||
| } | ||
| getTypeSchema(typeSymbol.get(), analysisData.typeMapper, modifierData.typeSchemas); | ||
| return functionCallExpressionNode; | ||
| } | ||
| } | ||
|
|
||
| private static class FunctionModifier extends TreeModifier { | ||
|
|
||
| private final ModifierData modifierData; | ||
|
|
@@ -437,22 +476,21 @@ private void extractAndStoreSchemas(SemanticModel semanticModel, FunctionDefinit | |
| } | ||
| } | ||
|
|
||
| private void getTypeSchema(TypeSymbol memberType, TypeMapper typeMapper, Map<String, String> typeSchemas) { | ||
| private static void getTypeSchema(TypeSymbol memberType, TypeMapper typeMapper, Map<String, String> typeSchemas) { | ||
| switch (memberType) { | ||
| case TypeReferenceTypeSymbol typeReference -> | ||
| typeSchemas.put(typeReference.definition().getName().get(), | ||
| getJsonSchema(typeMapper.getSchema(typeReference))); | ||
|
|
||
| case ArrayTypeSymbol arrayType -> | ||
| getTypeSchema(arrayType.memberTypeDescriptor(), typeMapper, typeSchemas); | ||
|
|
||
| getTypeSchema(arrayType.memberTypeDescriptor(), typeMapper, typeSchemas); | ||
| case TupleTypeSymbol tupleType -> | ||
| tupleType.members().forEach(member -> | ||
| getTypeSchema(member.typeDescriptor(), typeMapper, typeSchemas)); | ||
|
|
||
| case RecordTypeSymbol recordType -> | ||
| recordType.fieldDescriptors().values().forEach(field -> | ||
| getTypeSchema(field.typeDescriptor(), typeMapper, typeSchemas)); | ||
| case UnionTypeSymbol unionTypeSymbol -> unionTypeSymbol.memberTypeDescriptors().forEach(member -> | ||
| getTypeSchema(member, typeMapper, typeSchemas)); | ||
| default -> { } | ||
| } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,11 +17,15 @@ | |
|
|
||
| import io.ballerina.runtime.api.Environment; | ||
| import io.ballerina.runtime.api.Module; | ||
| import io.ballerina.runtime.api.creators.TypeCreator; | ||
| import io.ballerina.runtime.api.creators.ValueCreator; | ||
| import io.ballerina.runtime.api.flags.SymbolFlags; | ||
| import io.ballerina.runtime.api.types.ArrayType; | ||
| import io.ballerina.runtime.api.types.Field; | ||
| import io.ballerina.runtime.api.types.JsonType; | ||
| import io.ballerina.runtime.api.types.PredefinedTypes; | ||
| import io.ballerina.runtime.api.types.RecordType; | ||
| import io.ballerina.runtime.api.types.TupleType; | ||
| import io.ballerina.runtime.api.types.Type; | ||
| import io.ballerina.runtime.api.types.TypeTags; | ||
| import io.ballerina.runtime.api.types.UnionType; | ||
|
|
@@ -30,17 +34,122 @@ | |
| import io.ballerina.runtime.api.values.BArray; | ||
| import io.ballerina.runtime.api.values.BMap; | ||
| import io.ballerina.runtime.api.values.BObject; | ||
| import io.ballerina.runtime.api.values.BString; | ||
| import io.ballerina.runtime.api.values.BTypedesc; | ||
|
|
||
| import java.util.Map; | ||
|
|
||
| import static io.ballerina.runtime.api.creators.ValueCreator.createMapValue; | ||
|
|
||
| /** | ||
| * Native implementation of natural programming functions. | ||
| * | ||
| * @since 0.3.0 | ||
| */ | ||
| public class Native { | ||
|
|
||
| static Boolean isSchemaGeneratedAtCompileTime; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How is this correct when this function can get called concurrently?! |
||
|
|
||
| public static Object callLlm(Environment env, BObject prompt, BMap context, BTypedesc targetType) { | ||
| isSchemaGeneratedAtCompileTime = true; | ||
| Object jsonSchema = generateJsonSchemaForType(targetType.getDescribingType()); | ||
| return env.getRuntime().callFunction( | ||
| new Module("ballerinax", "np", "0"), "callLlmGeneric", null, prompt, context, targetType); | ||
| new Module("ballerinax", "np", "0"), "callLlmGeneric", null, prompt, context, targetType, | ||
| isSchemaGeneratedAtCompileTime ? jsonSchema : null); | ||
| } | ||
|
|
||
| public static Object generateJsonSchemaForType(Type td) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could have added tests for this runtime schema generation. |
||
| Type type = TypeUtils.getReferredType(td); | ||
| if (isSimpleType(type)) { | ||
| return createSimpleTypeSchema(type); | ||
| } | ||
|
|
||
| return switch (type) { | ||
| case RecordType recordType -> generateJsonSchemaForRecordType(recordType); | ||
| case JsonType ignored -> generateJsonSchemaForJson(); | ||
| case ArrayType arrayType -> generateJsonSchemaForArrayType(arrayType); | ||
| case TupleType tupleType -> generateJsonSchemaForTupleType(tupleType); | ||
| case UnionType unionType -> generateJsonSchemaForUnionType(unionType); | ||
| default -> null; | ||
| }; | ||
| } | ||
|
|
||
| private static BMap<BString, Object> createSimpleTypeSchema(Type type) { | ||
| BMap<BString, Object> schemaMap = createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON)); | ||
| schemaMap.put(StringUtils.fromString("type"), StringUtils.fromString(getStringRepresentation(type))); | ||
| return schemaMap; | ||
| } | ||
|
|
||
| private static BMap<BString, Object> generateJsonSchemaForJson() { | ||
| BString[] bStringValues = new BString[6]; | ||
| bStringValues[0] = StringUtils.fromString("object"); | ||
| bStringValues[1] = StringUtils.fromString("array"); | ||
| bStringValues[2] = StringUtils.fromString("string"); | ||
| bStringValues[3] = StringUtils.fromString("number"); | ||
| bStringValues[4] = StringUtils.fromString("boolean"); | ||
| bStringValues[5] = StringUtils.fromString("null"); | ||
| BMap<BString, Object> schemaMap = createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON)); | ||
| schemaMap.put(StringUtils.fromString("type"), ValueCreator.createArrayValue(bStringValues)); | ||
| return schemaMap; | ||
| } | ||
|
|
||
| private static Object generateJsonSchemaForArrayType(ArrayType arrayType) { | ||
| BMap<BString, Object> schemaMap = createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON)); | ||
| Type elementType = TypeUtils.getReferredType(arrayType.getElementType()); | ||
| schemaMap.put(StringUtils.fromString("type"), StringUtils.fromString("array")); | ||
| schemaMap.put(StringUtils.fromString("items"), generateJsonSchemaForType(elementType)); | ||
| return schemaMap; | ||
| } | ||
|
|
||
| private static Object generateJsonSchemaForTupleType(TupleType tupleType) { | ||
| BMap<BString, Object> schemaMap = createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON)); | ||
| schemaMap.put(StringUtils.fromString("type"), StringUtils.fromString("array")); | ||
| BArray annotationArray = ValueCreator.createArrayValue(TypeCreator.createArrayType(PredefinedTypes.TYPE_JSON)); | ||
| int index = 0; | ||
| for (Type type : tupleType.getTupleTypes()) { | ||
| annotationArray.add(index++, generateJsonSchemaForType(type)); | ||
| } | ||
| schemaMap.put(StringUtils.fromString("items"), annotationArray); | ||
| return schemaMap; | ||
| } | ||
|
|
||
| private static boolean isSimpleType(Type type) { | ||
| return type.getBasicType().all() <= 0b100000; | ||
| } | ||
|
|
||
| private static String getStringRepresentation(Type type) { | ||
| return switch (type.getBasicType().all()) { | ||
| case 0b000000 -> "null"; | ||
| case 0b000010 -> "boolean"; | ||
| case 0b000100 -> "integer"; | ||
| case 0b001000, 0b010000 -> "number"; | ||
| case 0b100000 -> "string"; | ||
| default -> null; | ||
| }; | ||
| } | ||
|
|
||
| private static Object generateJsonSchemaForRecordType(RecordType recordType) { | ||
| for (Map.Entry<BString, Object> entry : recordType.getAnnotations().entrySet()) { | ||
| if ("ballerinax/np:0:Schema".equals(entry.getKey().getValue())) { | ||
| return entry.getValue(); | ||
| } | ||
| } | ||
| isSchemaGeneratedAtCompileTime = false; | ||
| return null; | ||
| } | ||
|
|
||
| private static Object generateJsonSchemaForUnionType(UnionType unionType) { | ||
| BMap<BString, Object> schemaMap = createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON)); | ||
| schemaMap.put(StringUtils.fromString("type"), StringUtils.fromString("object")); | ||
| BArray annotationArray = ValueCreator.createArrayValue(TypeCreator.createArrayType(PredefinedTypes.TYPE_JSON)); | ||
|
|
||
| int index = 0; | ||
| for (Type type : unionType.getMemberTypes()) { | ||
| annotationArray.add(index++, generateJsonSchemaForType(type)); | ||
| } | ||
|
|
||
| schemaMap.put(StringUtils.fromString("anyOf"), annotationArray); | ||
| return schemaMap; | ||
| } | ||
|
|
||
| // Simple, simple, SIMPLE implementation for now. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use an appropriate name instead of Foo.