diff --git a/ballerina/Ballerina.toml b/ballerina/Ballerina.toml index 8e39ef5..b47dc35 100644 --- a/ballerina/Ballerina.toml +++ b/ballerina/Ballerina.toml @@ -2,7 +2,7 @@ distribution = "2201.12.0" org = "ballerinax" name = "ai.memory.mssql" -version = "1.0.0" +version = "1.0.1" license = ["Apache-2.0"] authors = ["Ballerina"] keywords = ["ai", "agent", "memory"] diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index eef2cc8..0244604 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -10,7 +10,7 @@ distribution-version = "2201.12.0" [[package]] org = "ballerina" name = "ai" -version = "1.6.0" +version = "1.7.0" dependencies = [ {org = "ballerina", name = "constraint"}, {org = "ballerina", name = "data.jsondata"}, @@ -26,6 +26,7 @@ dependencies = [ {org = "ballerina", name = "math.vector"}, {org = "ballerina", name = "mcp"}, {org = "ballerina", name = "mime"}, + {org = "ballerina", name = "observe"}, {org = "ballerina", name = "time"}, {org = "ballerina", name = "url"}, {org = "ballerina", name = "uuid"}, @@ -33,7 +34,8 @@ dependencies = [ ] modules = [ {org = "ballerina", packageName = "ai", moduleName = "ai"}, - {org = "ballerina", packageName = "ai", moduleName = "ai.intelligence"} + {org = "ballerina", packageName = "ai", moduleName = "ai.intelligence"}, + {org = "ballerina", packageName = "ai", moduleName = "ai.observe"} ] [[package]] @@ -268,7 +270,7 @@ version = "1.2.0" [[package]] org = "ballerina" name = "mcp" -version = "1.0.1" +version = "1.0.2" dependencies = [ {org = "ballerina", name = "http"}, {org = "ballerina", name = "jballerina.java"}, @@ -302,7 +304,7 @@ dependencies = [ [[package]] org = "ballerina" name = "observe" -version = "1.5.1" +version = "1.6.0" dependencies = [ {org = "ballerina", name = "jballerina.java"} ] @@ -357,7 +359,7 @@ modules = [ [[package]] org = "ballerina" name = "time" -version = "2.7.0" +version = "2.8.0" dependencies = [ {org = "ballerina", name = "jballerina.java"} ] @@ -365,7 +367,7 @@ dependencies = [ [[package]] org = "ballerina" name = "url" -version = "2.6.0" +version = "2.6.1" dependencies = [ {org = "ballerina", name = "jballerina.java"} ] @@ -405,7 +407,7 @@ dependencies = [ [[package]] org = "ballerinax" name = "ai.memory.mssql" -version = "1.0.0" +version = "1.0.1" dependencies = [ {org = "ballerina", name = "ai"}, {org = "ballerina", name = "cache"}, diff --git a/ballerina/store.bal b/ballerina/store.bal index dd3f0e2..898f7b0 100644 --- a/ballerina/store.bal +++ b/ballerina/store.bal @@ -55,7 +55,7 @@ public isolated class ShortTermMemoryStore { *ai:ShortTermMemoryStore; private final mssql:Client dbClient; - private final cache:Cache cache; + private final cache:Cache? cache; private final int maxMessagesPerKey; # Initializes the MS SQL-backed short-term memory store. @@ -66,7 +66,7 @@ public isolated class ShortTermMemoryStore { # + returns - An error if the initialization fails public isolated function init(mssql:Client|DatabaseConfiguration mssqlClient, int maxMessagesPerKey = 20, - cache:CacheConfig cacheConfig = {capacity: 20}) returns Error? { + cache:CacheConfig? cacheConfig = ()) returns Error? { if mssqlClient is mssql:Client { self.dbClient = mssqlClient; } else { @@ -77,7 +77,7 @@ public isolated class ShortTermMemoryStore { self.dbClient = initializedClient; } self.maxMessagesPerKey = maxMessagesPerKey; - self.cache = new (cacheConfig); + self.cache = cacheConfig is () ? () : new (cacheConfig); return self.initializeDatabase(); } @@ -396,8 +396,9 @@ public isolated class ShortTermMemoryStore { final ai:ChatInteractiveMessage[] & readonly immutableInteractiveMessages = interactiveMessages.cloneReadOnly(); lock { - if !self.cache.hasKey(key) { - check self.cache.put( + cache:Cache? cache = self.cache; + if cache !is () && !cache.hasKey(key) { + check cache.put( key, {systemMessage, interactiveMessages: [...immutableInteractiveMessages]}); } } @@ -413,8 +414,9 @@ public isolated class ShortTermMemoryStore { private isolated function removeCacheEntry(string key) { lock { - if self.cache.hasKey(key) { - cache:Error? err = self.cache.invalidate(key); + cache:Cache? cache = self.cache; + if cache !is () && cache.hasKey(key) { + cache:Error? err = cache.invalidate(key); if err is cache:Error { // Ignore, as this is for non-existent key } @@ -424,11 +426,12 @@ public isolated class ShortTermMemoryStore { private isolated function getCacheEntry(string key) returns CachedMessages? { lock { - if !self.cache.hasKey(key) { + cache:Cache? cache = self.cache; + if cache is () || !cache.hasKey(key) { return (); } - any|cache:Error cacheEntry = self.cache.get(key); + any|cache:Error cacheEntry = cache.get(key); if cacheEntry is cache:Error { return (); } diff --git a/ballerina/tests/memory_store_test.bal b/ballerina/tests/memory_store_test.bal index e03888a..616b455 100644 --- a/ballerina/tests/memory_store_test.bal +++ b/ballerina/tests/memory_store_test.bal @@ -15,12 +15,14 @@ // under the License. import ballerina/ai; +import ballerina/cache; import ballerina/sql; import ballerina/test; import ballerinax/mssql; const string K1 = "key1"; const string K2 = "key2"; +const string K3 = "key3"; const ai:ChatSystemMessage K1SM1 = {role: ai:SYSTEM, content: "You are a helpful assistant that is aware of the weather."}; @@ -549,3 +551,282 @@ isolated function assertContentEquals(ai:Prompt|string actual, ai:Prompt|string test:assertFail("Actual and expected content do not match"); } + +@test:Config { + before: dropTable +} +function testBasicStoreWithCache() returns error? { + mssql:Client cl = getClient(); + cache:CacheConfig cacheConfig = { + capacity: 10, + evictionFactor: 0.2 + }; + ShortTermMemoryStore store = check new (cl, cacheConfig = cacheConfig); + + check store.put(K1, K1SM1); + check store.put(K1, K1M1); + check store.put(K1, k1m2); + check store.put(K2, K2M1); + + // First retrieval - should load from database and cache + check assertAllMessages(store, K1, [K1SM1, K1M1, k1m2]); + check assertInteractiveMessages(store, K1, [K1M1, k1m2]); + + // Second retrieval - should use cache (verify by checking results still match) + check assertAllMessages(store, K1, [K1SM1, K1M1, k1m2]); + check assertInteractiveMessages(store, K1, [K1M1, k1m2]); + + check assertAllMessages(store, K2, [K2M1]); + check assertInteractiveMessages(store, K2, [K2M1]); +} + +@test:Config { + before: dropTable +} +function testCacheUpdateOnPut() returns error? { + mssql:Client cl = getClient(); + cache:CacheConfig cacheConfig = { + capacity: 10, + evictionFactor: 0.2 + }; + ShortTermMemoryStore store = check new (cl, cacheConfig = cacheConfig); + + check store.put(K1, K1SM1); + check store.put(K1, K1M1); + + // Load into cache + check assertAllMessages(store, K1, [K1SM1, K1M1]); + + // Add more messages - cache should be updated + check store.put(K1, k1m2); + check store.put(K1, K1M3); + + // Verify cache reflects the updates + check assertAllMessages(store, K1, [K1SM1, K1M1, k1m2, K1M3]); + check assertInteractiveMessages(store, K1, [K1M1, k1m2, K1M3]); +} + +@test:Config { + before: dropTable +} +function testCacheSystemMessageUpdate() returns error? { + mssql:Client cl = getClient(); + cache:CacheConfig cacheConfig = { + capacity: 10, + evictionFactor: 0.2 + }; + ShortTermMemoryStore store = check new (cl, cacheConfig = cacheConfig); + + check store.put(K1, K1SM1); + check store.put(K1, K1M1); + + // Load into cache + check assertSystemMessage(store, K1, K1SM1); + check assertAllMessages(store, K1, [K1SM1, K1M1]); + + // Update system message + final readonly & ai:ChatSystemMessage k1sm2 = { + role: ai:SYSTEM, + content: "You are a helpful assistant that is aware of sports." + }; + check store.put(K1, k1sm2); + + // Verify cache reflects the system message update + check assertSystemMessage(store, K1, k1sm2); + check assertAllMessages(store, K1, [k1sm2, K1M1]); +} + +@test:Config { + before: dropTable +} +function testCacheInvalidationOnRemoveAll() returns error? { + mssql:Client cl = getClient(); + cache:CacheConfig cacheConfig = { + capacity: 10, + evictionFactor: 0.2 + }; + ShortTermMemoryStore store = check new (cl, cacheConfig = cacheConfig); + + check store.put(K1, K1SM1); + check store.put(K1, K1M1); + check store.put(K1, k1m2); + + // Load into cache + check assertAllMessages(store, K1, [K1SM1, K1M1, k1m2]); + + // Remove all messages + check store.removeAll(K1); + + // Verify cache is invalidated and returns empty + check assertAllMessages(store, K1, []); + check assertSystemMessage(store, K1, ()); + check assertInteractiveMessages(store, K1, []); +} + +@test:Config { + before: dropTable +} +function testCacheInvalidationOnRemoveInteractiveMessages() returns error? { + mssql:Client cl = getClient(); + cache:CacheConfig cacheConfig = { + capacity: 10, + evictionFactor: 0.2 + }; + ShortTermMemoryStore store = check new (cl, cacheConfig = cacheConfig); + + check store.put(K1, K1SM1); + check store.put(K1, K1M1); + check store.put(K1, k1m2); + check store.put(K1, K1M3); + + // Load into cache + check assertAllMessages(store, K1, [K1SM1, K1M1, k1m2, K1M3]); + + // Remove all interactive messages + check store.removeChatInteractiveMessages(K1); + + // Verify cache reflects the removal + check assertAllMessages(store, K1, [K1SM1]); + check assertSystemMessage(store, K1, K1SM1); + check assertInteractiveMessages(store, K1, []); +} + +@test:Config { + before: dropTable +} +function testCacheInvalidationOnRemoveSubsetOfInteractiveMessages() returns error? { + mssql:Client cl = getClient(); + cache:CacheConfig cacheConfig = { + capacity: 10, + evictionFactor: 0.2 + }; + ShortTermMemoryStore store = check new (cl, cacheConfig = cacheConfig); + + check store.put(K1, K1SM1); + check store.put(K1, K1M1); + check store.put(K1, k1m2); + check store.put(K1, K1M3); + check store.put(K1, K1M4); + + // Load into cache + check assertAllMessages(store, K1, [K1SM1, K1M1, k1m2, K1M3, K1M4]); + + // Remove first 2 interactive messages + check store.removeChatInteractiveMessages(K1, 2); + + // Verify cache reflects the partial removal + check assertAllMessages(store, K1, [K1SM1, K1M3, K1M4]); + check assertInteractiveMessages(store, K1, [K1M3, K1M4]); +} + +@test:Config { + before: dropTable +} +function testCacheUpdateOnRemoveSystemMessage() returns error? { + mssql:Client cl = getClient(); + cache:CacheConfig cacheConfig = { + capacity: 10, + evictionFactor: 0.2 + }; + ShortTermMemoryStore store = check new (cl, cacheConfig = cacheConfig); + + check store.put(K1, K1SM1); + check store.put(K1, K1M1); + check store.put(K1, k1m2); + + // Load into cache + check assertAllMessages(store, K1, [K1SM1, K1M1, k1m2]); + check assertSystemMessage(store, K1, K1SM1); + + // Remove system message + check store.removeChatSystemMessage(K1); + + // Verify cache reflects the system message removal + check assertAllMessages(store, K1, [K1M1, k1m2]); + check assertSystemMessage(store, K1, ()); + check assertInteractiveMessages(store, K1, [K1M1, k1m2]); +} + +@test:Config { + before: dropTable +} +function testCacheWithMultipleKeys() returns error? { + mssql:Client cl = getClient(); + cache:CacheConfig cacheConfig = { + capacity: 10, + evictionFactor: 0.2 + }; + ShortTermMemoryStore store = check new (cl, cacheConfig = cacheConfig); + + // Add messages for K1 + check store.put(K1, K1SM1); + check store.put(K1, K1M1); + check store.put(K1, k1m2); + + // Add messages for K2 + check store.put(K2, K2M1); + + // Load both into cache + check assertAllMessages(store, K1, [K1SM1, K1M1, k1m2]); + check assertAllMessages(store, K2, [K2M1]); + + // Remove K1 + check store.removeAll(K1); + + // Verify K1 is cleared but K2 is still in cache + check assertAllMessages(store, K1, []); + check assertAllMessages(store, K2, [K2M1]); +} + +@test:Config { + before: dropTable +} +function testCacheWithSmallCapacity() returns error? { + mssql:Client cl = getClient(); + cache:CacheConfig cacheConfig = { + capacity: 2, + evictionFactor: 0.5 + }; + ShortTermMemoryStore store = check new (cl, cacheConfig = cacheConfig); + + check store.put(K1, K1M1); + check store.put(K2, K2M1); + check store.put(K3, K1M3); + + // Load K1 and K2 into cache + check assertAllMessages(store, K1, [K1M1]); + check assertAllMessages(store, K2, [K2M1]); + + // Load K3 - may evict older entries due to capacity + check assertAllMessages(store, K3, [K1M3]); + + // All keys should still be retrievable (from cache or database) + check assertAllMessages(store, K1, [K1M1]); + check assertAllMessages(store, K2, [K2M1]); + check assertAllMessages(store, K3, [K1M3]); +} + +@test:Config { + before: dropTable +} +function testSystemMessageRetrievalDoesNotPopulateCache() returns error? { + mssql:Client cl = getClient(); + cache:CacheConfig cacheConfig = { + capacity: 10, + evictionFactor: 0.2 + }; + ShortTermMemoryStore store = check new (cl, cacheConfig = cacheConfig); + + check store.put(K1, K1SM1); + check store.put(K1, K1M1); + check store.put(K1, k1m2); + + // Retrieve only system message - should NOT populate cache + check assertSystemMessage(store, K1, K1SM1); + + // Add more messages + check store.put(K1, K1M3); + + // Retrieve all messages - should load from database and include K1M3 + check assertAllMessages(store, K1, [K1SM1, K1M1, k1m2, K1M3]); +}