Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions ballerina/main.bal
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,11 @@ isolated function getPromptWithExpectedResponseSchema(string prompt, map<json> e
Schema:
${expectedResponseSchema.toJsonString()}`;

isolated function callLlmGeneric(Prompt prompt, Context context, typedesc<json> targetType) returns json|error {
isolated function callLlmGeneric(Prompt prompt, Context context, typedesc<json> targetType,
map<json>? jsonSchema) returns json|error {
Model model = context.model;
json resp =
check model->call(buildPromptString(prompt), generateJsonSchemaForTypedescAsJson(targetType));
json resp =
check model->call(buildPromptString(prompt), jsonSchema ?: generateJsonSchemaForTypedescAsJson(targetType));
return parseResponseAsType(resp, targetType);
}

Expand Down
5 changes: 2 additions & 3 deletions ballerina/tests/tests.bal
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,10 @@ function testJsonConversionError() {
test:assertTrue((<error> rating).message().includes(ERROR_MESSAGE));
}

type Foo record{| string name; |};
@test:Config
function testJsonConversionError2() {
record{|
string name;
|}[]|error rating = callLlm(`Tell me name and the age of the top 10 world class cricketers`);
Foo[]|error rating = callLlm(`Tell me name and the age of the top 10 world class cricketers`);
test:assertTrue(rating is error);
test:assertTrue((<error> rating).message().includes(ERROR_MESSAGE));
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ public class PromptAsCodeCodeModificationTask implements ModifierTask<SourceModi
NodeFactory.createSimpleNameReferenceNode(NodeFactory.createIdentifierToken(CONTEXT_VAR));

private final ModifierData modifierData;
private final CodeModifier.AnalysisData analysisData;
private static CodeModifier.AnalysisData analysisData;

PromptAsCodeCodeModificationTask(CodeModifier.AnalysisData analysisData) {
this.modifierData = new ModifierData();
Expand All @@ -136,12 +136,20 @@ public void modify(SourceModifierContext modifierContext) {

for (DocumentId documentId: module.documentIds()) {
Document document = module.document(documentId);
modifierContext.modifySourceFile(modifyDocument(document, modifierData), documentId);
modifierContext.modifySourceFile(modifyDocument(document, modifierData, modifierContext, moduleId),
documentId);
}

for (DocumentId documentId: module.testDocumentIds()) {
Document document = module.document(documentId);
modifierContext.modifyTestSourceFile(modifyDocument(document, modifierData), documentId);
processImportDeclarations(document, modifierData);
processExternalFunctions(document, module, modifierData, modifierContext);
}

for (DocumentId documentId: module.testDocumentIds()) {
Document document = module.document(documentId);
modifierContext.modifyTestSourceFile(modifyDocument(document, modifierData, modifierContext, moduleId),
documentId);
}
}
}
Expand Down Expand Up @@ -171,12 +179,16 @@ private static void processImportDeclarations(Document document, ModifierData mo
modulePartNode.apply(importDeclarationModifier);
}

private static TextDocument modifyDocument(Document document, ModifierData modifierData) {
private static TextDocument modifyDocument(Document document, ModifierData modifierData,
SourceModifierContext modifierContext, ModuleId moduleId) {
ModulePartNode modulePartNode = document.syntaxTree().rootNode();
FunctionModifier functionModifier = new FunctionModifier(modifierData);
FunctionCallModifier functionCallModifier =
new FunctionCallModifier(modifierData, modifierContext, moduleId, document);
TypeDefinitionModifier typeDefinitionModifier = new TypeDefinitionModifier(modifierData.typeSchemas,
modifierData);

modulePartNode.apply(functionCallModifier);
ModulePartNode modifiedRoot = (ModulePartNode) modulePartNode.apply(functionModifier);
modifiedRoot = modifiedRoot.modify(modifiedRoot.imports(), modifiedRoot.members(), modifiedRoot.eofToken());

Expand Down Expand Up @@ -214,6 +226,33 @@ public ImportDeclarationNode transform(ImportDeclarationNode importDeclarationNo
}
}

private static class FunctionCallModifier extends TreeModifier {
private final ModifierData modifierData;
private final SourceModifierContext modifierContext;
private final ModuleId moduleId;
private final Document document;

FunctionCallModifier(ModifierData modifierData, SourceModifierContext modifierContext, ModuleId moduleId,
Document document) {
this.modifierData = modifierData;
this.modifierContext = modifierContext;
this.moduleId = moduleId;
this.document = document;
}

@Override
public FunctionCallExpressionNode transform(FunctionCallExpressionNode functionCallExpressionNode) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why didn't we do this only for np:callLlm calls?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean to special case only for np:callLlm, right? There's no issue with the current implementation, right?it's to analyze all function calls.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't we unnecessarily generating JSON schema with function calls that are not relevant to NP?

E.g.,

import ballerinax/np; // used elsewhere in the file, but no reference to IntArray

type IntArray int[];

function getArr() returns IntArray => [1, 2];

IntArray x = getArr();

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SemanticModel semanticModel = modifierContext.compilation().getSemanticModel(moduleId);
Optional<TypeSymbol> typeSymbol =
semanticModel.expectedType(document, functionCallExpressionNode.lineRange().startLine());
if (typeSymbol.isEmpty()) {
return functionCallExpressionNode;
}
getTypeSchema(typeSymbol.get(), analysisData.typeMapper, modifierData.typeSchemas);
return functionCallExpressionNode;
}
}

private static class FunctionModifier extends TreeModifier {

private final ModifierData modifierData;
Expand Down Expand Up @@ -437,22 +476,21 @@ private void extractAndStoreSchemas(SemanticModel semanticModel, FunctionDefinit
}
}

private void getTypeSchema(TypeSymbol memberType, TypeMapper typeMapper, Map<String, String> typeSchemas) {
private static void getTypeSchema(TypeSymbol memberType, TypeMapper typeMapper, Map<String, String> typeSchemas) {
switch (memberType) {
case TypeReferenceTypeSymbol typeReference ->
typeSchemas.put(typeReference.definition().getName().get(),
getJsonSchema(typeMapper.getSchema(typeReference)));

case ArrayTypeSymbol arrayType ->
getTypeSchema(arrayType.memberTypeDescriptor(), typeMapper, typeSchemas);

getTypeSchema(arrayType.memberTypeDescriptor(), typeMapper, typeSchemas);
case TupleTypeSymbol tupleType ->
tupleType.members().forEach(member ->
getTypeSchema(member.typeDescriptor(), typeMapper, typeSchemas));

case RecordTypeSymbol recordType ->
recordType.fieldDescriptors().values().forEach(field ->
getTypeSchema(field.typeDescriptor(), typeMapper, typeSchemas));
case UnionTypeSymbol unionTypeSymbol -> unionTypeSymbol.memberTypeDescriptors().forEach(member ->
getTypeSchema(member, typeMapper, typeSchemas));
default -> { }
}
}
Expand Down
111 changes: 110 additions & 1 deletion native/src/main/java/io/ballerina/lib/np/Native.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@

import io.ballerina.runtime.api.Environment;
import io.ballerina.runtime.api.Module;
import io.ballerina.runtime.api.creators.TypeCreator;
import io.ballerina.runtime.api.creators.ValueCreator;
import io.ballerina.runtime.api.flags.SymbolFlags;
import io.ballerina.runtime.api.types.ArrayType;
import io.ballerina.runtime.api.types.Field;
import io.ballerina.runtime.api.types.JsonType;
import io.ballerina.runtime.api.types.PredefinedTypes;
import io.ballerina.runtime.api.types.RecordType;
import io.ballerina.runtime.api.types.TupleType;
import io.ballerina.runtime.api.types.Type;
import io.ballerina.runtime.api.types.TypeTags;
import io.ballerina.runtime.api.types.UnionType;
Expand All @@ -30,17 +34,122 @@
import io.ballerina.runtime.api.values.BArray;
import io.ballerina.runtime.api.values.BMap;
import io.ballerina.runtime.api.values.BObject;
import io.ballerina.runtime.api.values.BString;
import io.ballerina.runtime.api.values.BTypedesc;

import java.util.Map;

import static io.ballerina.runtime.api.creators.ValueCreator.createMapValue;

/**
* Native implementation of natural programming functions.
*
* @since 0.3.0
*/
public class Native {

static Boolean isSchemaGeneratedAtCompileTime;

public static Object callLlm(Environment env, BObject prompt, BMap context, BTypedesc targetType) {
isSchemaGeneratedAtCompileTime = true;
Object jsonSchema = generateJsonSchemaForType(targetType.getDescribingType());
return env.getRuntime().callFunction(
new Module("ballerinax", "np", "0"), "callLlmGeneric", null, prompt, context, targetType);
new Module("ballerinax", "np", "0"), "callLlmGeneric", null, prompt, context, targetType,
isSchemaGeneratedAtCompileTime ? jsonSchema : null);
}

public static Object generateJsonSchemaForType(Type td) {
Type type = TypeUtils.getReferredType(td);
if (isSimpleType(type)) {
return createSimpleTypeSchema(type);
}

return switch (type) {
case RecordType recordType -> generateJsonSchemaForRecordType(recordType);
case JsonType ignored -> generateJsonSchemaForJson();
case ArrayType arrayType -> generateJsonSchemaForArrayType(arrayType);
case TupleType tupleType -> generateJsonSchemaForTupleType(tupleType);
case UnionType unionType -> generateJsonSchemaForUnionType(unionType);
default -> null;
};
}

private static BMap<BString, Object> createSimpleTypeSchema(Type type) {
BMap<BString, Object> schemaMap = createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON));
schemaMap.put(StringUtils.fromString("type"), StringUtils.fromString(getStringRepresentation(type)));
return schemaMap;
}

private static BMap<BString, Object> generateJsonSchemaForJson() {
BString[] bStringValues = new BString[6];
bStringValues[0] = StringUtils.fromString("object");
bStringValues[1] = StringUtils.fromString("array");
bStringValues[2] = StringUtils.fromString("string");
bStringValues[3] = StringUtils.fromString("number");
bStringValues[4] = StringUtils.fromString("boolean");
bStringValues[5] = StringUtils.fromString("null");
BMap<BString, Object> schemaMap = createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON));
schemaMap.put(StringUtils.fromString("type"), ValueCreator.createArrayValue(bStringValues));
return schemaMap;
}

private static Object generateJsonSchemaForArrayType(ArrayType arrayType) {
BMap<BString, Object> schemaMap = createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON));
Type elementType = TypeUtils.getReferredType(arrayType.getElementType());
schemaMap.put(StringUtils.fromString("type"), StringUtils.fromString("array"));
schemaMap.put(StringUtils.fromString("items"), generateJsonSchemaForType(elementType));
return schemaMap;
}

private static Object generateJsonSchemaForTupleType(TupleType tupleType) {
BMap<BString, Object> schemaMap = createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON));
schemaMap.put(StringUtils.fromString("type"), StringUtils.fromString("array"));
BArray annotationArray = ValueCreator.createArrayValue(TypeCreator.createArrayType(PredefinedTypes.TYPE_JSON));
int index = 0;
for (Type type : tupleType.getTupleTypes()) {
annotationArray.add(index++, generateJsonSchemaForType(type));
}
schemaMap.put(StringUtils.fromString("items"), annotationArray);
return schemaMap;
}

private static boolean isSimpleType(Type type) {
return type.getBasicType().all() <= 0b100000;
}

private static String getStringRepresentation(Type type) {
return switch (type.getBasicType().all()) {
case 0b000000 -> "null";
case 0b000010 -> "boolean";
case 0b000100 -> "integer";
case 0b001000, 0b010000 -> "number";
case 0b100000 -> "string";
default -> null;
};
}

private static Object generateJsonSchemaForRecordType(RecordType recordType) {
for (Map.Entry<BString, Object> entry : recordType.getAnnotations().entrySet()) {
if ("ballerinax/np:0:Schema".equals(entry.getKey().getValue())) {
return entry.getValue();
}
}
isSchemaGeneratedAtCompileTime = false;
return null;
}

private static Object generateJsonSchemaForUnionType(UnionType unionType) {
BMap<BString, Object> schemaMap = createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON));
schemaMap.put(StringUtils.fromString("type"), StringUtils.fromString("object"));
BArray annotationArray = ValueCreator.createArrayValue(TypeCreator.createArrayType(PredefinedTypes.TYPE_JSON));

int index = 0;
for (Type type : unionType.getMemberTypes()) {
annotationArray.add(index++, generateJsonSchemaForType(type));
}

schemaMap.put(StringUtils.fromString("anyOf"), annotationArray);
return schemaMap;
}

// Simple, simple, SIMPLE implementation for now.
Expand Down
Loading