Skip to content
Merged
2 changes: 1 addition & 1 deletion ballerina/Ballerina.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
distribution = "2201.12.0"
org = "ballerinax"
name = "ai.memory.mssql"
version = "1.1.0"
version = "1.2.0"
license = ["Apache-2.0"]
authors = ["Ballerina"]
keywords = ["ai", "agent", "memory"]
Expand Down
10 changes: 5 additions & 5 deletions ballerina/Dependencies.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ distribution-version = "2201.12.0"
[[package]]
org = "ballerina"
name = "ai"
version = "1.7.0"
version = "1.9.0"
dependencies = [
{org = "ballerina", name = "constraint"},
{org = "ballerina", name = "data.jsondata"},
Expand Down Expand Up @@ -257,7 +257,7 @@ dependencies = [
[[package]]
org = "ballerina"
name = "log"
version = "2.14.0"
version = "2.12.0"
dependencies = [
{org = "ballerina", name = "io"},
{org = "ballerina", name = "jballerina.java"},
Expand Down Expand Up @@ -338,7 +338,7 @@ modules = [
[[package]]
org = "ballerina"
name = "task"
version = "2.11.0"
version = "2.10.0"
dependencies = [
{org = "ballerina", name = "jballerina.java"},
{org = "ballerina", name = "time"},
Expand All @@ -362,7 +362,7 @@ modules = [
[[package]]
org = "ballerina"
name = "time"
version = "2.8.0"
version = "2.7.0"
dependencies = [
{org = "ballerina", name = "jballerina.java"}
]
Expand Down Expand Up @@ -410,7 +410,7 @@ dependencies = [
[[package]]
org = "ballerinax"
name = "ai.memory.mssql"
version = "1.1.0"
version = "1.2.0"
dependencies = [
{org = "ballerina", name = "ai"},
{org = "ballerina", name = "cache"},
Expand Down
148 changes: 115 additions & 33 deletions ballerina/store.bal
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,15 @@ public isolated class ShortTermMemoryStore {
}
}

# Adds a chat message to the memory store for a given key.
# Adds one or more chat messages to the memory store for a given key.
#
# + key - The key associated with the memory
# + message - The `ChatMessage` message to store
# + message - The `ChatMessage` message or messages to store
# + return - nil on success, or an `Error` if the operation fails
public isolated function put(string key, ai:ChatMessage message) returns Error? {
public isolated function put(string key, ai:ChatMessage|ai:ChatMessage[] message) returns Error? {
if message is ai:ChatMessage[] {
return self.putAll(key, message);
}
ChatMessageDatabaseMessage dbMessage = transformToDatabaseMessage(message);

if dbMessage is ChatSystemMessageDatabaseMessage {
Expand All @@ -209,18 +212,6 @@ public isolated class ShortTermMemoryStore {
}
} else {
do {
// Expected to be checked by the caller, but doing an additional check here, without locking.
ai:ChatInteractiveMessage[]|Error chatInteractiveMessages = self.getChatInteractiveMessages(key);
if chatInteractiveMessages is Error {
error? cause = chatInteractiveMessages.cause();
fail cause is error ? cause : chatInteractiveMessages;
}

if chatInteractiveMessages.length() >= self.maxMessagesPerKey {
return error(string `Cannot add more messages. Maximum limit of '${
self.maxMessagesPerKey}' reached for key: '${key}'`);
}

_ = check self.dbClient->execute(
replaceTableNamePlaceholder(`
INSERT INTO $_tableName_$ (MessageKey, MessageRole, MessageJson)
Expand All @@ -245,6 +236,108 @@ public isolated class ShortTermMemoryStore {
cacheEntry.interactiveMessages.push(immutableMessage);
}
}
return;
}

private isolated function putAll(string key, ai:ChatMessage[] messages) returns Error? {
if messages.length() == 0 {
return;
}

// Separate system and interactive messages
final ai:ChatSystemMessage[] systemMsgs = [];
final ai:ChatInteractiveMessage[] interactiveMsgs = [];

foreach ai:ChatMessage msg in messages {
if msg is ai:ChatSystemMessage {
systemMsgs.push(msg);
} else if msg is ai:ChatInteractiveMessage {
interactiveMsgs.push(msg);
}
}

if systemMsgs.length() > 0 {
// Only the last system message is used
ai:ChatSystemMessage lastSystem = systemMsgs[systemMsgs.length() - 1];
ChatMessageDatabaseMessage dbMessage = transformToDatabaseMessage(lastSystem);

sql:ExecutionResult|sql:Error upsertResult = self.dbClient->execute(
replaceTableNamePlaceholder(`
IF EXISTS (SELECT 1 FROM $_tableName_$ WHERE MessageKey = ${key} AND MessageRole = 'system')
UPDATE $_tableName_$ SET MessageJson = ${dbMessage.toJsonString()}
WHERE MessageKey = ${key} AND MessageRole = 'system'
ELSE
INSERT INTO $_tableName_$ (MessageKey, MessageRole, MessageJson)
VALUES (${key}, ${dbMessage.role}, ${dbMessage.toJsonString()})`,
self.tableName
)
);
if upsertResult is sql:Error {
return error("Failed to upsert system message: " + upsertResult.message(), upsertResult);
}
}

// Insert interactive messages in batch
if interactiveMsgs.length() > 0 {

// Fetch current interactive count
ai:ChatInteractiveMessage[] chatMsgs = check self.getChatInteractiveMessages(key);
int currentCount = chatMsgs.length();
int incoming = interactiveMsgs.length();

if currentCount + incoming > self.maxMessagesPerKey {
return error(string `Cannot add more messages.`
+ string ` Maximum limit '${self.maxMessagesPerKey}' exceeded for key '${key}'`);
}

sql:ParameterizedQuery[] insertQueries = from ai:ChatInteractiveMessage msg in interactiveMsgs
let ChatMessageDatabaseMessage dbMsg = transformToDatabaseMessage(msg)
select replaceTableNamePlaceholder(`
INSERT INTO $_tableName_$ (MessageKey, MessageRole, MessageJson)
VALUES (${key}, ${msg.role}, ${dbMsg.toJsonString()})`,
self.tableName
);

sql:ExecutionResult[]|sql:Error batchResult = self.dbClient->batchExecute(insertQueries);
if batchResult is sql:Error {
return error("Failed batch insert of interactive messages: " + batchResult.message(), batchResult);
}
}

final readonly & ai:ChatSystemMessage? chatSystemMessage;
if systemMsgs.length() > 0 {
ai:ChatSystemMessage lastSystem = systemMsgs[systemMsgs.length() - 1];
final readonly & ai:ChatMessage immutableMessage = mapToImmutableMessage(lastSystem);
if immutableMessage is ai:ChatSystemMessage {
chatSystemMessage = immutableMessage;
} else {
chatSystemMessage = ();
}
} else {
chatSystemMessage = ();
}

final ai:ChatMessage[] & readonly immutableMessages = interactiveMsgs.'map(msg => mapToImmutableMessage(msg))
.cloneReadOnly();

// Update cache
lock {
CachedMessages? cacheEntry = self.getCacheEntry(key);
if cacheEntry is () {
return;
}

if chatSystemMessage is ai:ChatSystemMessage {
cacheEntry.systemMessage = chatSystemMessage;
}

foreach ai:ChatMessage msg in immutableMessages {
if msg is ai:ChatInteractiveMessage {
cacheEntry.interactiveMessages.push(msg);
}
}
}
return;
}

# Removes the system chat message, if specified, for a given key.
Expand Down Expand Up @@ -401,24 +494,6 @@ public isolated class ShortTermMemoryStore {

private isolated function cacheFromDatabase(string key)
returns readonly & ([ai:ChatSystemMessage, ai:ChatInteractiveMessage...]|ai:ChatInteractiveMessage[])|Error {
int|sql:Error messageCount = self.dbClient->queryRow(
replaceTableNamePlaceholder(`
SELECT COUNT(*) as count
FROM $_tableName_$
WHERE MessageKey = ${key} AND MessageRole != 'system'`,
self.tableName
)
);

if messageCount is sql:Error {
return error("Failed to load message count from the database: " + messageCount.message(), messageCount);
}

if messageCount > self.maxMessagesPerKey {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need this still?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do't need to do that. It should be handled by the caller as discussed previously with @shafreenAnfar and @SasinduDilshara.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For updates, yes. But this gets called during a get too, right? I'm not sure we can expect users to do this check for gets also.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if updates are already handled by the user, that guarantees the memory will always have a message count below the configured size. I do not think we need to handle this in get.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't happen for the first get that happens if no update has happened, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline with @MaryamZi, there can be cases where a memory instance is created with a specific size configuration, for example 10, and updates are persisted accordingly, so the database contains 10 entries. Later, if the user code is changed to use a smaller size, the database may still hold more entries than the new limit, which could result in an error.

Since this is a very rare scenario, we are currently removing the size check to reduce database calls.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the cache is used, there'll be less database calls anyway, so one here would be okay IMO.

return error ExceedsSizeError(string `Cannot load messages from the database: Message count '${
messageCount}' exceeds maximum limit of '${self.maxMessagesPerKey}' for key: '${key}'`);
}

do {
stream<DatabaseRecord, sql:Error?> messages = self.dbClient->query(
replaceTableNamePlaceholder(`
Expand Down Expand Up @@ -493,6 +568,13 @@ public isolated class ShortTermMemoryStore {
return checkpanic cacheEntry.ensureType();
}
}

# Retrieves the maximum number of interactive messages that can be stored for each key.
#
# + return - The configured capacity of the message store per key
public isolated function getCapacity() returns int {
return self.maxMessagesPerKey;
}
}

isolated function replaceTableNamePlaceholder(sql:ParameterizedQuery query, string tableName) returns sql:ParameterizedQuery {
Expand Down
Loading
Loading