Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ballerina/Ballerina.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ icon="icon.png"
name = "ai.openai"
org = "ballerinax"
repository = "https://github.com/ballerina-platform/module-ballerinax-ai.openai"
version = "1.2.0"
version = "1.2.1"

[platform.java21]
graalvmCompatible = true

[[platform.java21.dependency]]
groupId = "io.ballerina.lib"
artifactId = "ai.openai-native"
version = "1.2.0"
path = "../native/build/libs/ai.openai-native-1.2.0.jar"
version = "1.2.1"
path = "../native/build/libs/ai.openai-native-1.2.1-SNAPSHOT.jar"
2 changes: 1 addition & 1 deletion ballerina/CompilerPlugin.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ id = "ai-compiler-plugin"
class = "io.ballerina.lib.ai.openai.AiOpenAICompilerPlugin"

[[dependency]]
path = "../compiler-plugin/build/libs/ai.openai-compiler-plugin-1.2.0.jar"
path = "../compiler-plugin/build/libs/ai.openai-compiler-plugin-1.2.1-SNAPSHOT.jar"

[[dependency]]
path = "../compiler-plugin/build/libs/ballerina-to-openapi-2.3.0.jar"
4 changes: 2 additions & 2 deletions ballerina/Dependencies.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ distribution-version = "2201.12.0"
[[package]]
org = "ballerina"
name = "ai"
version = "1.1.1"
version = "1.1.2"
dependencies = [
{org = "ballerina", name = "constraint"},
{org = "ballerina", name = "data.jsondata"},
Expand Down Expand Up @@ -406,7 +406,7 @@ dependencies = [
[[package]]
org = "ballerinax"
name = "ai.openai"
version = "1.2.0"
version = "1.2.1"
dependencies = [
{org = "ballerina", name = "ai"},
{org = "ballerina", name = "constraint"},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) 2025 WSO2 LLC. (http://www.wso2.org).
//
// WSO2 Inc. licenses this file to you under the Apache License,
// Version 2.0 (the "License"); you may not use this file except
// in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

import ballerina/test;

@test:Config
function testGenerateMethodUsageInFileWithoutAiImport() returns error? {
Review|error result = provider->generate(`Please rate this blog out of ${"10"}.
Title: ${blog2.title}
Content: ${blog2.content}`);
test:assertEquals(result, reviewRecord);
}
5 changes: 5 additions & 0 deletions ballerina/tests/test_values.bal
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ const sampleImageUrl = "https://example.com/image.jpg";

const review = "{\"rating\": 8, \"comment\": \"Talks about essential aspects of sports performance " +
"including warm-up, form, equipment, and nutrition.\"}";

const reviewRecord = {
rating: 8,
comment: "Talks about essential aspects of sports performance including warm-up, form, equipment, and nutrition."
};

final readonly & map<anydata>[] expectedContentPartsForRateBlog = [
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@

package io.ballerina.lib.ai.openai;

import io.ballerina.compiler.api.ModuleID;
import io.ballerina.compiler.api.SemanticModel;
import io.ballerina.compiler.api.Types;
import io.ballerina.compiler.api.symbols.ArrayTypeSymbol;
import io.ballerina.compiler.api.symbols.ClassSymbol;
import io.ballerina.compiler.api.symbols.ModuleSymbol;
import io.ballerina.compiler.api.symbols.RecordTypeSymbol;
import io.ballerina.compiler.api.symbols.Symbol;
import io.ballerina.compiler.api.symbols.TupleTypeSymbol;
Expand Down Expand Up @@ -81,6 +80,10 @@
class GenerateMethodModificationTask implements ModifierTask<SourceModifierContext> {
private static final String AI_MODULE_NAME = "ai";
private static final String BALLERINA_ORG_NAME = "ballerina";
private static final String OPENAI_MODEL_PROVIDER_NAME = "ModelProvider";
private static final String OPENAI_MODEL_PROVIDER_MODULE_NAME = "ai.openai";
private static final String OPENAI_MODEL_PROVIDER_MODULE_VERSION = "1";
private static final String OPENAI_MODEL_PROVIDER_MODULE_ORG = "ballerinax";
private final AiOpenAICodeModifier.AnalysisData analysisData;
private final ModifierData modifierData;

Expand All @@ -103,12 +106,17 @@ public void modify(SourceModifierContext modifierContext) {
Collection<DocumentId> documentIds = module.documentIds();
Collection<DocumentId> testDocumentIds = module.testDocumentIds();

Types types = semanticModel.types();
Optional<Symbol> openAiModelProviderSymbol =
types.getTypeByName(OPENAI_MODEL_PROVIDER_MODULE_ORG, OPENAI_MODEL_PROVIDER_MODULE_NAME,
OPENAI_MODEL_PROVIDER_MODULE_VERSION, OPENAI_MODEL_PROVIDER_NAME);

for (DocumentId documentId : documentIds) {
analyzeDocument(module, documentId, semanticModel);
analyzeDocument(module, documentId, semanticModel, openAiModelProviderSymbol);
}

for (DocumentId documentId : testDocumentIds) {
analyzeDocument(module, documentId, semanticModel);
analyzeDocument(module, documentId, semanticModel, openAiModelProviderSymbol);
}

for (DocumentId documentId : documentIds) {
Expand All @@ -123,14 +131,15 @@ public void modify(SourceModifierContext modifierContext) {
}
}

private void analyzeDocument(Module module, DocumentId documentId, SemanticModel semanticModel) {
private void analyzeDocument(Module module, DocumentId documentId, SemanticModel semanticModel,
Optional<Symbol> openAiModelProviderSymbol) {
Document document = module.document(documentId);
Node rootNode = document.syntaxTree().rootNode();
if (!(rootNode instanceof ModulePartNode modulePartNode)) {
return;
}

analyzeGenerateMethod(document, semanticModel, modulePartNode, this.analysisData);
analyzeGenerateMethod(semanticModel, modulePartNode, openAiModelProviderSymbol, this.analysisData);
}

private static TextDocument modifyDocument(Document document, ModifierData modifierData) {
Expand All @@ -157,9 +166,11 @@ private static ImportDeclarationNode createImportDeclarationForAIModule() {
return NodeParser.parseImportDeclaration(String.format("import %s/%s;", BALLERINA_ORG_NAME, AI_MODULE_NAME));
}

private void analyzeGenerateMethod(Document document, SemanticModel semanticModel,
ModulePartNode modulePartNode, AiOpenAICodeModifier.AnalysisData analysisData) {
new GenerateMethodJsonSchemaGenerator(semanticModel, document, analysisData).generate(modulePartNode);
private void analyzeGenerateMethod(SemanticModel semanticModel,
ModulePartNode modulePartNode, Optional<Symbol> openAiModelProviderSymbol,
AiOpenAICodeModifier.AnalysisData analysisData) {
new GenerateMethodJsonSchemaGenerator(semanticModel, openAiModelProviderSymbol, analysisData)
.generate(modulePartNode);
}

private static String getAiModuleImportPrefix(NodeList<ImportDeclarationNode> imports) {
Expand Down Expand Up @@ -197,24 +208,29 @@ private static String getAiModuleImportPrefix(NodeList<ImportDeclarationNode> im

private class GenerateMethodJsonSchemaGenerator extends NodeVisitor {
private static final String GENERATE_METHOD_NAME = "generate";
private static final String OPENAI_MODEL_PROVIDER_NAME = "ModelProvider";
private static final String OPENAI_MODEL_PROVIDER_MODULE_NAME = "ai.openai";
private static final String OPENAI_MODEL_PROVIDER_MODULE_VERSION = "1";
private static final String OPENAI_MODEL_PROVIDER_MODULE_ORG = "ballerinax";
private static final String STRING = "string";
private static final String BYTE = "byte";
private static final String NUMBER = "number";
private final SemanticModel semanticModel;
private final Document document;
private final TypeMapper typeMapper;
private final ClassSymbol openaiProviderSymbol;

public GenerateMethodJsonSchemaGenerator(SemanticModel semanticModel, Document document,
public GenerateMethodJsonSchemaGenerator(SemanticModel semanticModel,
Optional<Symbol> openAiModelProviderSymbolOpt,
AiOpenAICodeModifier.AnalysisData analyserData) {
this.semanticModel = semanticModel;
this.document = document;
this.typeMapper = analyserData.typeMapper;
this.openaiProviderSymbol = getOpenAIProviderSymbol(document.syntaxTree().rootNode()).orElse(null);
if (openAiModelProviderSymbolOpt.isEmpty()) {
this.openaiProviderSymbol = null;
return;
}

Symbol openAiModelProviderSymbol = openAiModelProviderSymbolOpt.get();
if (openAiModelProviderSymbol instanceof ClassSymbol openAiModelProviderClassSymbol) {
this.openaiProviderSymbol = openAiModelProviderClassSymbol;
} else {
this.openaiProviderSymbol = null;
}
}

void generate(ModulePartNode modulePartNode) {
Expand Down Expand Up @@ -268,37 +284,6 @@ private static void populateTypeSchema(TypeSymbol memberType, TypeMapper typeMap
}
}

private Optional<ClassSymbol> getOpenAIProviderSymbol(Node node) {
Optional<ModuleSymbol> openAiModuleSymbol = getOpenAIModuleSymbol(node);
if (openAiModuleSymbol.isEmpty()) {
return Optional.empty();
}

for (ClassSymbol classSymbol: openAiModuleSymbol.get().classes()) {
if (classSymbol.nameEquals(OPENAI_MODEL_PROVIDER_NAME)) {
return Optional.of(classSymbol);
}
}

return Optional.empty();
}

private Optional<ModuleSymbol> getOpenAIModuleSymbol(Node node) {
for (Symbol symbol : semanticModel.visibleSymbols(this.document, node.lineRange().startLine())) {
if (!(symbol instanceof ModuleSymbol moduleSymbol)) {
continue;
}

ModuleID id = moduleSymbol.id();
if (OPENAI_MODEL_PROVIDER_MODULE_ORG.equals(id.orgName())
&& OPENAI_MODEL_PROVIDER_MODULE_NAME.equals(id.moduleName())
&& id.version().startsWith(OPENAI_MODEL_PROVIDER_MODULE_VERSION)) {
return Optional.of(moduleSymbol);
}
}
return Optional.empty();
}

private static String getJsonSchema(Schema schema) {
modifySchema(schema);
OpenAPISchema2JsonSchema openAPISchema2JsonSchema = new OpenAPISchema2JsonSchema();
Expand Down
Loading