diff --git a/ballerina/tests/test.bal b/ballerina/tests/test.bal index 1051dbf..989ecc1 100644 --- a/ballerina/tests/test.bal +++ b/ballerina/tests/test.bal @@ -16,17 +16,36 @@ import ballerina/ai; import ballerina/test; +import ballerina/time; import ballerina/uuid; +import ballerina/http; final VectorStore mockVectorStore = check new ( - serviceUrl = "http://localhost:8080/v1", + serviceUrl = "http://localhost:8080", config = { - collectionName: "Chunk" + collectionName: "Test" }, apiKey = "mock-token" ); +@test:BeforeSuite +function beforeSuite() returns error? { + http:Client httpClient = check new ("http://localhost:8080"); + http:Response _ = check httpClient->post(path = "/v1/schema", headers = { + "Content-Type": "application/json" + }, message = { + "class": "Test", + "properties": [ + { "name": "content", "dataType": ["text"] }, + { "name": "type", "dataType": ["string"] }, + { "name": "createdAt", "dataType": ["date"] } + ] + }); +} + + string id = uuid:createRandomUuid(); +time:Utc createdAt = time:utcNow(); @test:Config {} function testAddingValuesToVectorStore() returns error? { @@ -35,8 +54,11 @@ function testAddingValuesToVectorStore() returns error? { id, embedding: [1.0, 2.0, 3.0], chunk: { - 'type: "text", - content: "This is a test chunk" + 'type: "text", + content: "This is a test chunk", + metadata: { + createdAt + } } } ]; @@ -58,7 +80,7 @@ function testDeleteMultipleValuesFromVectorStore() returns error? { id: index, embedding: [1.0, 2.0, 3.0], chunk: { - 'type: "text", + 'type: "text", content: "This is a test chunk" } } @@ -68,21 +90,24 @@ function testDeleteMultipleValuesFromVectorStore() returns error? { test:assertTrue(result !is error); } -@test:Config {} +@test:Config { + dependsOn: [testAddingValuesToVectorStore] +} function testQueryValuesFromVectorStore() returns error? { ai:VectorStoreQuery query = { filters: { filters: [ { + 'key: "createdAt", operator: ai:EQUAL, - 'key: "content", - value: "This is a test chunk" + value: createdAt } ] } }; - ai:VectorMatch[]|ai:Error result = mockVectorStore.query(query); - test:assertTrue(result !is error); + ai:VectorMatch[] result = check mockVectorStore.query(query); + test:assertTrue(result.length() > 0); + test:assertEquals(result[0].chunk.metadata?.createdAt, createdAt); } @test:Config {} diff --git a/ballerina/utils.bal b/ballerina/utils.bal index dd3f2c9..a8cbf9c 100644 --- a/ballerina/utils.bal +++ b/ballerina/utils.bal @@ -15,12 +15,14 @@ // under the License. import ballerina/ai; +import ballerina/time; # Converts metadata filters to Weaviate compatible filter format # # + filters - The metadata filters containing filter conditions and logical operators +# + metadataFields - The fields of the metadata to be filtered # + return - A map representing the converted filter structure or an error if conversion fails -isolated function convertWeaviateFilters(ai:MetadataFilters filters) returns map|ai:Error { +isolated function convertWeaviateFilters(ai:MetadataFilters filters, string[] metadataFields) returns map|ai:Error { (ai:MetadataFilters|ai:MetadataFilter)[]? rawFilters = filters.filters; if rawFilters == () || rawFilters.length() == 0 { return {}; @@ -28,15 +30,21 @@ isolated function convertWeaviateFilters(ai:MetadataFilters filters) returns map map[] filterList = []; foreach (ai:MetadataFilters|ai:MetadataFilter) filter in rawFilters { if filter is ai:MetadataFilter { + metadataFields.push(filter.key); map filterMap = {}; string weaviateOp = check mapWeaviateOperator(filter.operator); filterMap["path"] = [filter.key]; filterMap["operator"] = weaviateOp; - filterMap["valueText"] = filter.value; + json value = filter.value; + if value is time:Utc { + filterMap["valueDate"] = string `"${time:utcToString(value)}"`; + } else { + filterMap["valueText"] = value; + } filterList.push(filterMap); continue; } - map nestedFilter = check convertWeaviateFilters(filter); + map nestedFilter = check convertWeaviateFilters(filter, metadataFields); if nestedFilter.length() > 0 { filterList.push(nestedFilter); } @@ -138,7 +146,7 @@ isolated function mapToGraphQLObjectString(map filter) returns string { } result += "[" + resultArr + "]"; } else if value is string { - result += 'key == "operator" ? value : string `"${value}"`; + result += 'key == "operator" ? value : string `${value}`; } else { result += value.toString(); } diff --git a/ballerina/vector_store.bal b/ballerina/vector_store.bal index a93fade..339badc 100644 --- a/ballerina/vector_store.bal +++ b/ballerina/vector_store.bal @@ -17,6 +17,7 @@ import ballerina/ai; import ballerina/http; import ballerinax/weaviate; +import ballerina/time; # Weaviate Vector Store implementation with support for Dense, Sparse, and Hybrid vector search modes. # @@ -47,7 +48,7 @@ public isolated class VectorStore { token: apiKey }; do { - self.weaviateClient = check new (check httpConfig.cloneWithType(), serviceUrl); + self.weaviateClient = check new (check httpConfig.cloneWithType(), string `${serviceUrl}/v1` ); } on fail error err { return error("Failed to initialize weaviate vector store", err); } @@ -70,11 +71,21 @@ public isolated class VectorStore { weaviate:Object[] objects = []; foreach ai:VectorEntry entry in entries.cloneReadOnly() { ai:Embedding embedding = entry.embedding; - weaviate:PropertySchema properties = entry.chunk.metadata !is () ? - check entry.chunk.metadata.cloneWithType() : {}; + weaviate:PropertySchema properties = {}; properties[self.chunkFieldName] = entry.chunk.content; properties["type"] = entry.chunk.'type; - + ai:Metadata? metadata = entry.chunk.metadata; + if metadata !is () { + foreach string item in metadata.keys() { + anydata metadataValue = metadata.get(item); + if metadataValue is time:Utc { + string utcToString = time:utcToString(metadataValue); + properties[item] = utcToString; + } else { + properties[item] = metadataValue; + } + } + } if embedding is ai:Vector { objects.push({ 'class: self.config.collectionName, @@ -129,13 +140,18 @@ public isolated class VectorStore { return error("Invalid value for topK. The value cannot be 0 or less than -1."); } string filterSection = ""; + string[] metadataFields = []; if query.hasKey("filters") && query.filters is ai:MetadataFilters { ai:MetadataFilters? filters = query.cloneReadOnly().filters; if filters !is () { - map weaviateFilter = check convertWeaviateFilters(filters); + map weaviateFilter = check convertWeaviateFilters(filters, metadataFields); filterSection = "where: " + mapToGraphQLObjectString(weaviateFilter); } } + string metadataFieldsString = ""; + foreach string fieldName in metadataFields { + metadataFieldsString += fieldName + "\n "; + } string gqlQuery = string `{ Get { ${self.config.collectionName}( @@ -148,6 +164,7 @@ public isolated class VectorStore { } ) { content + ${metadataFieldsString} _additional { certainty id @@ -175,13 +192,19 @@ public isolated class VectorStore { QueryResult[] value = check data.cloneWithType(); ai:VectorMatch[] matches = []; foreach weaviate:JsonObject element in value { + ai:Metadata metadata = {}; + foreach string fieldName in metadataFields { + time:Utc|error metadataValue = time:utcFromString(element.get(fieldName).toString()); + metadata[fieldName] = metadataValue is error ? element.get(fieldName).toString() : metadataValue; + } + ai:TextChunk chunk = { + content: element.content.toString(), + metadata: check metadata.cloneWithType() + }; matches.push({ id: element._additional.id, embedding: element._additional.vector, - chunk: { - 'type: element.'type is () ? "" : check element.'type.cloneWithType(), - content: element.content - }, + chunk, similarityScore: element._additional.certainty !is () ? check element._additional.certainty.cloneWithType() : 0.0 }); diff --git a/resources/server/compose.yml b/resources/server/compose.yml index 815eab3..02614db 100644 --- a/resources/server/compose.yml +++ b/resources/server/compose.yml @@ -16,7 +16,7 @@ services: restart: on-failure:0 environment: QUERY_DEFAULTS_LIMIT: 25 - AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'false' + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' AUTHENTICATION_APIKEY_ENABLED: 'true' AUTHENTICATION_APIKEY_ALLOWED_KEYS: 'mock-token' AUTHENTICATION_APIKEY_USERS: 'test-user'