Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions ballerina/tests/test.bal
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,19 @@

import ballerina/ai;
import ballerina/test;
import ballerina/time;
import ballerina/uuid;

final VectorStore mockVectorStore = check new (
serviceUrl = "http://localhost:8080/v1",
serviceUrl = "http://localhost:8080",
config = {
collectionName: "Chunk"
},
apiKey = "mock-token"
);

string id = uuid:createRandomUuid();
time:Utc createdAt = time:utcNow();

@test:Config {}
function testAddingValuesToVectorStore() returns error? {
Expand All @@ -35,8 +37,11 @@ function testAddingValuesToVectorStore() returns error? {
id,
embedding: [1.0, 2.0, 3.0],
chunk: {
'type: "text",
content: "This is a test chunk"
'type: "text",
content: "This is a test chunk",
metadata: {
createdAt
}
}
}
];
Expand All @@ -58,7 +63,7 @@ function testDeleteMultipleValuesFromVectorStore() returns error? {
id: index,
embedding: [1.0, 2.0, 3.0],
chunk: {
'type: "text",
'type: "text",
content: "This is a test chunk"
}
}
Expand All @@ -68,21 +73,24 @@ function testDeleteMultipleValuesFromVectorStore() returns error? {
test:assertTrue(result !is error);
}

@test:Config {}
@test:Config {
dependsOn: [testAddingValuesToVectorStore]
}
function testQueryValuesFromVectorStore() returns error? {
ai:VectorStoreQuery query = {
filters: {
filters: [
{
'key: "createdAt",
operator: ai:EQUAL,
'key: "content",
value: "This is a test chunk"
value: createdAt
}
]
}
};
ai:VectorMatch[]|ai:Error result = mockVectorStore.query(query);
test:assertTrue(result !is error);
ai:VectorMatch[] result = check mockVectorStore.query(query);
test:assertTrue(result.length() > 0);
test:assertEquals(result[0].chunk.metadata?.createdAt, createdAt);
}

@test:Config {}
Expand Down
16 changes: 12 additions & 4 deletions ballerina/utils.bal
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,36 @@
// under the License.

import ballerina/ai;
import ballerina/time;

# Converts metadata filters to the Weaviate-compatible filter format
#
# + filters - The metadata filters containing filter conditions and logical operators
# + metadataFields - Output list; the key of every filter visited (including nested filters) is appended to it
# + return - A map representing the converted filter structure (empty when there are no filters), or an error if conversion fails
isolated function convertWeaviateFilters(ai:MetadataFilters filters) returns map<anydata>|ai:Error {
isolated function convertWeaviateFilters(ai:MetadataFilters filters, string[] metadataFields) returns map<anydata>|ai:Error {
(ai:MetadataFilters|ai:MetadataFilter)[]? rawFilters = filters.filters;
if rawFilters == () || rawFilters.length() == 0 {
return {};
}
map<anydata>[] filterList = [];
foreach (ai:MetadataFilters|ai:MetadataFilter) filter in rawFilters {
if filter is ai:MetadataFilter {
metadataFields.push(filter.key);
map<anydata> filterMap = {};
string weaviateOp = check mapWeaviateOperator(filter.operator);
filterMap["path"] = [filter.key];
filterMap["operator"] = weaviateOp;
filterMap["valueText"] = filter.value;
json value = filter.value;
if value is time:Utc {
filterMap["valueDate"] = string `"${time:utcToString(value)}"`;
} else {
filterMap["valueText"] = value;
}
filterList.push(filterMap);
continue;
}
map<anydata> nestedFilter = check convertWeaviateFilters(filter);
map<anydata> nestedFilter = check convertWeaviateFilters(filter, metadataFields);
if nestedFilter.length() > 0 {
filterList.push(nestedFilter);
}
Expand Down Expand Up @@ -138,7 +146,7 @@ isolated function mapToGraphQLObjectString(map<anydata> filter) returns string {
}
result += "[" + resultArr + "]";
} else if value is string {
result += 'key == "operator" ? value : string `"${value}"`;
result += 'key == "operator" ? value : string `${value}`;
} else {
result += value.toString();
}
Expand Down
41 changes: 32 additions & 9 deletions ballerina/vector_store.bal
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import ballerina/ai;
import ballerina/http;
import ballerinax/weaviate;
import ballerina/time;

# Weaviate Vector Store implementation with support for Dense, Sparse, and Hybrid vector search modes.
#
Expand Down Expand Up @@ -47,7 +48,7 @@ public isolated class VectorStore {
token: apiKey
};
do {
self.weaviateClient = check new (check httpConfig.cloneWithType(), serviceUrl);
self.weaviateClient = check new (check httpConfig.cloneWithType(), string `${serviceUrl}/v1` );
} on fail error err {
return error("Failed to initialize weaviate vector store", err);
}
Expand All @@ -70,11 +71,21 @@ public isolated class VectorStore {
weaviate:Object[] objects = [];
foreach ai:VectorEntry entry in entries.cloneReadOnly() {
ai:Embedding embedding = entry.embedding;
weaviate:PropertySchema properties = entry.chunk.metadata !is () ?
check entry.chunk.metadata.cloneWithType() : {};
weaviate:PropertySchema properties = {};
properties[self.chunkFieldName] = entry.chunk.content;
properties["type"] = entry.chunk.'type;

ai:Metadata? metadata = entry.chunk.metadata;
if metadata !is () {
foreach string item in metadata.keys() {
anydata metadataValue = metadata.get(item);
if metadataValue is time:Utc {
string utcToString = time:utcToString(metadataValue);
properties[item] = utcToString;
} else {
properties[item] = metadataValue;
}
}
}
if embedding is ai:Vector {
objects.push({
'class: self.config.collectionName,
Expand Down Expand Up @@ -129,13 +140,18 @@ public isolated class VectorStore {
return error("Invalid value for topK. The value cannot be 0 or less than -1.");
}
string filterSection = "";
string[] metadataFields = [];
if query.hasKey("filters") && query.filters is ai:MetadataFilters {
ai:MetadataFilters? filters = query.cloneReadOnly().filters;
if filters !is () {
map<anydata> weaviateFilter = check convertWeaviateFilters(filters);
map<anydata> weaviateFilter = check convertWeaviateFilters(filters, metadataFields);
filterSection = "where: " + mapToGraphQLObjectString(weaviateFilter);
}
}
string metadataFieldsString = "";
foreach string fieldName in metadataFields {
metadataFieldsString += fieldName + "\n ";
}
string gqlQuery = string `{
Get {
${self.config.collectionName}(
Expand All @@ -148,6 +164,7 @@ public isolated class VectorStore {
}
) {
content
${metadataFieldsString}
_additional {
certainty
id
Expand Down Expand Up @@ -175,13 +192,19 @@ public isolated class VectorStore {
QueryResult[] value = check data.cloneWithType();
ai:VectorMatch[] matches = [];
foreach weaviate:JsonObject element in value {
ai:Metadata metadata = {};
foreach string fieldName in metadataFields {
time:Utc|error metadataValue = time:utcFromString(element.get(fieldName).toString());
metadata[fieldName] = metadataValue is error ? element.get(fieldName).toString() : metadataValue;
}
ai:TextChunk chunk = {
content: element.content.toString(),
metadata: check metadata.cloneWithType()
};
matches.push({
id: element._additional.id,
embedding: element._additional.vector,
chunk: {
'type: element.'type is () ? "" : check element.'type.cloneWithType(),
content: element.content
},
chunk,
similarityScore: element._additional.certainty !is () ?
check element._additional.certainty.cloneWithType() : 0.0
});
Expand Down
Loading