Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions ballerina/main.bal
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,11 @@ isolated function getPromptWithExpectedResponseSchema(string prompt, map<json> e
Schema:
${expectedResponseSchema.toJsonString()}`;

isolated function callLlmGeneric(Prompt prompt, Context context, typedesc<json> targetType) returns json|error {
isolated function callLlmGeneric(Prompt prompt, Context context, typedesc<json> targetType,
map<json>? jsonSchema) returns json|error {
Model model = context.model;
json resp =
check model->call(buildPromptString(prompt), generateJsonSchemaForTypedescAsJson(targetType));
json resp =
check model->call(buildPromptString(prompt), jsonSchema ?: generateJsonSchemaForTypedescAsJson(targetType));
return parseResponseAsType(resp, targetType);
}

Expand Down
45 changes: 44 additions & 1 deletion native/src/main/java/io/ballerina/lib/np/Native.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@

import io.ballerina.runtime.api.Environment;
import io.ballerina.runtime.api.Module;
import io.ballerina.runtime.api.creators.TypeCreator;
import io.ballerina.runtime.api.creators.ValueCreator;
import io.ballerina.runtime.api.flags.SymbolFlags;
import io.ballerina.runtime.api.types.AnnotatableType;
import io.ballerina.runtime.api.types.ArrayType;
import io.ballerina.runtime.api.types.Field;
import io.ballerina.runtime.api.types.PredefinedTypes;
import io.ballerina.runtime.api.types.RecordType;
import io.ballerina.runtime.api.types.Type;
import io.ballerina.runtime.api.types.TypeTags;
Expand All @@ -30,8 +33,14 @@
import io.ballerina.runtime.api.values.BArray;
import io.ballerina.runtime.api.values.BMap;
import io.ballerina.runtime.api.values.BObject;
import io.ballerina.runtime.api.values.BString;
import io.ballerina.runtime.api.values.BTypedesc;

import java.util.List;
import java.util.Map;

import static io.ballerina.runtime.api.creators.ValueCreator.createMapValue;

/**
* Native implementation of natural programming functions.
*
Expand All @@ -40,9 +49,43 @@
public class Native {
public static Object callLlm(Environment env, BObject prompt, BMap context, BTypedesc targetType) {
return env.getRuntime().callFunction(
new Module("ballerinax", "np", "0"), "callLlmGeneric", null, prompt, context, targetType);
new Module("ballerinax", "np", "0"), "callLlmGeneric", null, prompt, context, targetType,
generateJsonSchemaForUnionType(targetType));
}

public static Object generateJsonSchemaForUnionType(BTypedesc td) {
Type type = td.getDescribingType();

if (type instanceof UnionType bUnionType) {
List<Type> memberTypes = bUnionType.getMemberTypes();

for (Type bType : memberTypes) {
bType = TypeUtils.getReferredType(bType);

if (bType instanceof AnnotatableType annotatableType) {
BMap<BString, Object> schemaMap =
createMapValue(TypeCreator.createMapType(PredefinedTypes.TYPE_JSON));
schemaMap.put(StringUtils.fromString("type"), StringUtils.fromString("object"));
BMap<BString, Object> annotations = annotatableType.getAnnotations();
BArray annotationArray =
ValueCreator.createArrayValue(TypeCreator.createArrayType(PredefinedTypes.TYPE_JSON));
int index = 0;
for (Map.Entry<BString, Object> entry : annotations.entrySet()) {
if (entry.getKey().getValue().equals("ballerinax/np:0:Schema")) {
annotationArray.add(index++, entry.getValue());
}
}

schemaMap.put(StringUtils.fromString("anyOf"), annotationArray);
return schemaMap;
}
}
}

return null;
}


// Simple, simple, SIMPLE implementation for now.
public static void populateFieldInfo(BTypedesc typedesc, BArray names, BArray required,
BArray types, BArray nilable) {
Expand Down
Loading