Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ballerina/Ballerina.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ license = ["Apache-2.0"]
name = "ai.weaviate"
org = "ballerinax"
repository = "https://github.com/ballerina-platform/module-ballerinax-ai.weaviate"
version = "1.0.0"
version = "1.0.1"

[platform.java21]
graalvmCompatible = true
15 changes: 12 additions & 3 deletions ballerina/Dependencies.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ distribution-version = "2201.12.0"
[[package]]
org = "ballerina"
name = "ai"
version = "1.5.0"
version = "1.5.2"
dependencies = [
{org = "ballerina", name = "constraint"},
{org = "ballerina", name = "data.jsondata"},
Expand All @@ -21,6 +21,7 @@ dependencies = [
{org = "ballerina", name = "jballerina.java"},
{org = "ballerina", name = "lang.array"},
{org = "ballerina", name = "lang.regexp"},
{org = "ballerina", name = "lang.runtime"},
{org = "ballerina", name = "log"},
{org = "ballerina", name = "math.vector"},
{org = "ballerina", name = "mcp"},
Expand Down Expand Up @@ -107,7 +108,7 @@ dependencies = [
[[package]]
org = "ballerina"
name = "http"
version = "2.14.4"
version = "2.14.5"
dependencies = [
{org = "ballerina", name = "auth"},
{org = "ballerina", name = "cache"},
Expand Down Expand Up @@ -145,6 +146,9 @@ dependencies = [
{org = "ballerina", name = "jballerina.java"},
{org = "ballerina", name = "lang.value"}
]
modules = [
{org = "ballerina", packageName = "io", moduleName = "io"}
]

[[package]]
org = "ballerina"
Expand Down Expand Up @@ -355,6 +359,9 @@ version = "2.7.0"
dependencies = [
{org = "ballerina", name = "jballerina.java"}
]
modules = [
{org = "ballerina", packageName = "time", moduleName = "time"}
]

[[package]]
org = "ballerina"
Expand Down Expand Up @@ -423,11 +430,13 @@ modules = [
[[package]]
org = "ballerinax"
name = "ai.weaviate"
version = "1.0.0"
version = "1.0.1"
dependencies = [
{org = "ballerina", name = "ai"},
{org = "ballerina", name = "http"},
{org = "ballerina", name = "io"},
{org = "ballerina", name = "test"},
{org = "ballerina", name = "time"},
{org = "ballerina", name = "uuid"},
{org = "ballerinai", name = "transaction"},
{org = "ballerinax", name = "weaviate"}
Expand Down
9 changes: 6 additions & 3 deletions ballerina/tests/test.bal
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ final VectorStore mockVectorStore = check new (
apiKey = "mock-token"
);

string fileName = "this is a test file.pdf";

@test:BeforeSuite
function beforeSuite() returns error? {
http:Client httpClient = check new ("http://localhost:8080");
Expand Down Expand Up @@ -57,7 +59,8 @@ function testAddingValuesToVectorStore() returns error? {
'type: "text",
content: "This is a test chunk",
metadata: {
createdAt
createdAt,
fileName
}
}
}
Expand Down Expand Up @@ -98,9 +101,9 @@ function testQueryValuesFromVectorStore() returns error? {
filters: {
filters: [
{
'key: "createdAt",
'key: "fileName",
operator: ai:EQUAL,
value: createdAt
value: fileName
},
{
'key: "content",
Expand Down
43 changes: 43 additions & 0 deletions ballerina/utils.bal
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import ballerina/ai;
import ballerina/time;
import ballerinax/weaviate;

# Converts metadata filters to Weaviate compatible filter format
#
Expand Down Expand Up @@ -156,3 +157,45 @@ isolated function mapToGraphQLObjectString(map<anydata> filter) returns string {
result += "}";
return result;
}

# Retrieves the property names of a Weaviate collection through a GraphQL introspection query
#
# + className - Name of the collection (GraphQL type) to introspect
# + weaviateClient - Weaviate client used to send the GraphQL request
# + return - Names of the fields declared on the type, excluding `_additional`,
# or an `ai:Error` if the request fails or the response cannot be interpreted
isolated function getCollectionProperties(string className,
        weaviate:Client weaviateClient) returns string[]|ai:Error {
    // NOTE(review): nothing inside this lock touches isolated module-level state; it only
    // uses the parameters and locals. Confirm whether the lock is actually required, since
    // it is held for the duration of the network call.
    lock {
        string introspectionQuery = string `{
            __type(name: "${className}") {
                name
                fields {
                    name
                    type {
                        name
                        kind
                    }
                }
            }
        }`;
        weaviate:GraphQLResponse result = check weaviateClient->/graphql.post({
            query: introspectionQuery
        });
        weaviate:GraphQLError[]? errorResult = result?.errors;
        if errorResult !is () {
            return error("Failed to get collection properties: " + errorResult.toJsonString());
        }
        record {|weaviate:JsonObject...;|}? response = result.data;
        if response is () {
            return error("No data returned from GraphQL introspection query");
        }
        // Member access yields nil for an absent key, whereas `get` would panic and
        // bypass the `on fail` clause below.
        anydata typeData = response["__type"];
        if typeData is () {
            return error("No type information returned from GraphQL introspection query");
        }
        map<json> typeMap = check typeData.cloneWithType();
        // A missing or null `fields` entry becomes a conversion error handled by `on fail`.
        json fieldsData = typeMap["fields"];
        map<json>[] fieldsArray = check fieldsData.cloneWithType();
        string[] propertyNames = [];
        foreach map<json> fieldItem in fieldsArray {
            json nameData = fieldItem["name"];
            // `_additional` is filtered out so that only the collection's own fields are returned.
            if nameData is string && nameData != "_additional" {
                propertyNames.push(nameData);
            }
        }
        return propertyNames;
    } on fail error err {
        return error("Failed to get collection properties", err);
    }
}
25 changes: 17 additions & 8 deletions ballerina/vector_store.bal
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,10 @@ public isolated class VectorStore {
filterSection = "where: " + mapToGraphQLObjectString(weaviateFilter);
}
}
string metadataFieldsString = "";
foreach string fieldName in metadataFields {
metadataFieldsString += fieldName + "\n ";
string[] allProperties = check getCollectionProperties(self.config.collectionName, self.weaviateClient);
string allFieldsString = "";
foreach string fieldName in allProperties {
allFieldsString += fieldName + "\n";
}
string gqlQuery = string `{
Get {
Expand All @@ -163,8 +164,7 @@ public isolated class VectorStore {
}` : string ``
}
) {
content
${metadataFieldsString}
${allFieldsString}
_additional {
certainty
id
Expand Down Expand Up @@ -193,9 +193,18 @@ public isolated class VectorStore {
ai:VectorMatch[] matches = [];
foreach weaviate:JsonObject element in value {
ai:Metadata metadata = {};
foreach string fieldName in metadataFields {
time:Utc|error metadataValue = time:utcFromString(element.get(fieldName).toString());
metadata[fieldName] = metadataValue is error ? element.get(fieldName).toString() : metadataValue;
foreach string fieldName in allProperties {
if element.hasKey(fieldName) {
anydata fieldValue = element.get(fieldName);
if fieldValue is string {
time:Utc|error metadataValue = time:utcFromString(fieldValue);
metadata[fieldName] = metadataValue is error ? fieldValue : metadataValue;
} else if fieldValue is int|float|decimal|boolean {
metadata[fieldName] = fieldValue;
} else {
metadata[fieldName] = fieldValue.toString();
}
}
}
ai:TextChunk chunk = {
content: element.content.toString(),
Expand Down
Loading