Skip to content
Merged
2 changes: 1 addition & 1 deletion ballerina/Ballerina.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
distribution = "2201.12.0"
org = "ballerinax"
name = "ai.memory.mssql"
version = "1.1.0"
version = "1.2.0"
license = ["Apache-2.0"]
authors = ["Ballerina"]
keywords = ["ai", "agent", "memory"]
Expand Down
8 changes: 4 additions & 4 deletions ballerina/Dependencies.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ distribution-version = "2201.12.0"
[[package]]
org = "ballerina"
name = "ai"
version = "1.7.0"
version = "1.9.0"
dependencies = [
{org = "ballerina", name = "constraint"},
{org = "ballerina", name = "data.jsondata"},
Expand Down Expand Up @@ -75,7 +75,7 @@ dependencies = [
[[package]]
org = "ballerina"
name = "crypto"
version = "2.9.2"
version = "2.10.1"
dependencies = [
{org = "ballerina", name = "jballerina.java"},
{org = "ballerina", name = "time"}
Expand Down Expand Up @@ -410,7 +410,7 @@ dependencies = [
[[package]]
org = "ballerinax"
name = "ai.memory.mssql"
version = "1.1.0"
version = "1.2.0"
dependencies = [
{org = "ballerina", name = "ai"},
{org = "ballerina", name = "cache"},
Expand Down Expand Up @@ -438,7 +438,7 @@ dependencies = [
[[package]]
org = "ballerinax"
name = "mssql"
version = "1.16.1"
version = "1.16.2"
dependencies = [
{org = "ballerina", name = "crypto"},
{org = "ballerina", name = "jballerina.java"},
Expand Down
161 changes: 115 additions & 46 deletions ballerina/store.bal
Original file line number Diff line number Diff line change
Expand Up @@ -183,44 +183,23 @@ public isolated class ShortTermMemoryStore {
}
}

# Adds a chat message to the memory store for a given key.
# Adds one or more chat messages to the memory store for a given key.
#
# + key - The key associated with the memory
# + message - The `ChatMessage` message to store
# + message - The `ChatMessage` message or messages to store
# + return - nil on success, or an `Error` if the operation fails
public isolated function put(string key, ai:ChatMessage message) returns Error? {
public isolated function put(string key, ai:ChatMessage|ai:ChatMessage[] message) returns Error? {
if message is ai:ChatMessage[] {
return self.putAll(key, message);
}
ChatMessageDatabaseMessage dbMessage = transformToDatabaseMessage(message);

if dbMessage is ChatSystemMessageDatabaseMessage {
// Upsert system message for the key
sql:ExecutionResult|sql:Error upsertResult = self.dbClient->execute(
replaceTableNamePlaceholder(`
IF EXISTS (SELECT 1 FROM $_tableName_$ WHERE MessageKey = ${key} AND MessageRole = 'system')
UPDATE $_tableName_$ SET MessageJson = ${dbMessage.toJsonString()}
WHERE MessageKey = ${key} AND MessageRole = 'system'
ELSE
INSERT INTO $_tableName_$ (MessageKey, MessageRole, MessageJson)
VALUES (${key}, ${dbMessage.role}, ${dbMessage.toJsonString()})`,
self.tableName
)
);
sql:ExecutionResult|sql:Error upsertResult = self.updateSystemMessage(key, dbMessage);
if upsertResult is sql:Error {
return error("Failed to upsert system message: " + upsertResult.message(), upsertResult);
}
} else {
do {
// Expected to be checked by the caller, but doing an additional check here, without locking.
ai:ChatInteractiveMessage[]|Error chatInteractiveMessages = self.getChatInteractiveMessages(key);
if chatInteractiveMessages is Error {
error? cause = chatInteractiveMessages.cause();
fail cause is error ? cause : chatInteractiveMessages;
}

if chatInteractiveMessages.length() >= self.maxMessagesPerKey {
return error(string `Cannot add more messages. Maximum limit of '${
self.maxMessagesPerKey}' reached for key: '${key}'`);
}

_ = check self.dbClient->execute(
replaceTableNamePlaceholder(`
INSERT INTO $_tableName_$ (MessageKey, MessageRole, MessageJson)
Expand All @@ -247,6 +226,80 @@ public isolated class ShortTermMemoryStore {
}
}

private isolated function putAll(string key, ai:ChatMessage[] messages) returns Error? {
if messages.length() == 0 {
return;
}

final var [newSystemMessages, newInteractiveMessages] = partitionMessagesByType(messages);
final readonly & ai:ChatSystemMessage? finalChatSystemMessage = getLatestSystemMessage(newSystemMessages);
if finalChatSystemMessage is ai:ChatSystemMessage {
ChatMessageDatabaseMessage dbMessage = transformToDatabaseMessage(finalChatSystemMessage);
sql:ExecutionResult|sql:Error upsertResult = self.updateSystemMessage(key, dbMessage);
if upsertResult is sql:Error {
return error("Failed to upsert system message: " + upsertResult.message(), upsertResult);
}
}

// Insert interactive messages in batch
if newInteractiveMessages.length() > 0 {
ai:ChatInteractiveMessage[] oldInteractiveMesssages = check self.getChatInteractiveMessages(key);
int currentCount = oldInteractiveMesssages.length();
int incoming = newInteractiveMessages.length();

if currentCount + incoming > self.maxMessagesPerKey {
return error(string `Cannot add more messages.`
+ string ` Maximum limit '${self.maxMessagesPerKey}' exceeded for key '${key}'`);
}
sql:ParameterizedQuery[] insertQueries = from ai:ChatInteractiveMessage msg in newInteractiveMessages
let ChatMessageDatabaseMessage dbMsg = transformToDatabaseMessage(msg)
select replaceTableNamePlaceholder(`
INSERT INTO $_tableName_$ (MessageKey, MessageRole, MessageJson)
VALUES (${key}, ${msg.role}, ${dbMsg.toJsonString()})`,
self.tableName
);
sql:ExecutionResult[]|sql:Error batchResult = self.dbClient->batchExecute(insertQueries);
if batchResult is sql:Error {
return error("Failed batch insert of interactive messages: " + batchResult.message(), batchResult);
}
}

final ai:ChatInteractiveMessage[] & readonly immutableInteractiveMessages = from ai:ChatInteractiveMessage message
in newInteractiveMessages
select <readonly & ai:ChatInteractiveMessage>mapToImmutableMessage(message);
self.updateCache(key, finalChatSystemMessage, immutableInteractiveMessages);
}

private isolated function updateCache(string key, readonly & ai:ChatSystemMessage? systemMessage,
readonly & ai:ChatInteractiveMessage[] interactiveMessages) {
lock {
CachedMessages? cacheEntry = self.getCacheEntry(key);
if cacheEntry is () {
return;
}
if systemMessage is ai:ChatSystemMessage {
cacheEntry.systemMessage = systemMessage;
}
cacheEntry.interactiveMessages.push(...interactiveMessages);
}
return;
}

private isolated function updateSystemMessage(string key, ChatMessageDatabaseMessage systemMessage)
returns sql:ExecutionResult|sql:Error {
return self.dbClient->execute(
replaceTableNamePlaceholder(`
IF EXISTS (SELECT 1 FROM $_tableName_$ WHERE MessageKey = ${key} AND MessageRole = 'system')
UPDATE $_tableName_$ SET MessageJson = ${systemMessage.toJsonString()}
WHERE MessageKey = ${key} AND MessageRole = 'system'
ELSE
INSERT INTO $_tableName_$ (MessageKey, MessageRole, MessageJson)
VALUES (${key}, ${systemMessage.role}, ${systemMessage.toJsonString()})`,
self.tableName
)
);
}

# Removes the system chat message, if specified, for a given key.
#
# + key - The key associated with the memory
Expand Down Expand Up @@ -401,24 +454,6 @@ public isolated class ShortTermMemoryStore {

private isolated function cacheFromDatabase(string key)
returns readonly & ([ai:ChatSystemMessage, ai:ChatInteractiveMessage...]|ai:ChatInteractiveMessage[])|Error {
int|sql:Error messageCount = self.dbClient->queryRow(
replaceTableNamePlaceholder(`
SELECT COUNT(*) as count
FROM $_tableName_$
WHERE MessageKey = ${key} AND MessageRole != 'system'`,
self.tableName
)
);

if messageCount is sql:Error {
return error("Failed to load message count from the database: " + messageCount.message(), messageCount);
}

if messageCount > self.maxMessagesPerKey {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need this still?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do't need to do that. It should be handled by the caller as discussed previously with @shafreenAnfar and @SasinduDilshara.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For updates, yes. But this gets called during a get too, right? I'm not sure we can expect users to do this check for gets also.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if updates are already handled by the user, that guarantees the memory will always have a message count below the configured size. I do not think we need to handle this in get.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't happen for the first get that happens if no update has happened, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline with @MaryamZi, there can be cases where a memory instance is created with a specific size configuration, for example 10, and updates are persisted accordingly, so the database contains 10 entries. Later, if the user code is changed to use a smaller size, the database may still hold more entries than the new limit, which could result in an error.

Since this is a very rare scenario, we are currently removing the size check to reduce database calls.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the cache is used, there'll be less database calls anyway, so one here would be okay IMO.

return error ExceedsSizeError(string `Cannot load messages from the database: Message count '${
messageCount}' exceeds maximum limit of '${self.maxMessagesPerKey}' for key: '${key}'`);
}

do {
stream<DatabaseRecord, sql:Error?> messages = self.dbClient->query(
replaceTableNamePlaceholder(`
Expand Down Expand Up @@ -493,6 +528,13 @@ public isolated class ShortTermMemoryStore {
return checkpanic cacheEntry.ensureType();
}
}

# Retrieves the maximum number of interactive messages that can be stored for each key.
#
# + return - The configured capacity of the message store per key
public isolated function getCapacity() returns int {
return self.maxMessagesPerKey;
}
}

isolated function replaceTableNamePlaceholder(sql:ParameterizedQuery query, string tableName) returns sql:ParameterizedQuery {
Expand All @@ -501,3 +543,30 @@ isolated function replaceTableNamePlaceholder(sql:ParameterizedQuery query, stri
query.strings = strings;
return query;
}

isolated function partitionMessagesByType(ai:ChatMessage[] messages)
returns [ai:ChatSystemMessage[], ai:ChatInteractiveMessage[]] {
ai:ChatSystemMessage[] systemMsgs = [];
ai:ChatInteractiveMessage[] interactiveMsgs = [];
foreach ai:ChatMessage msg in messages {
if msg is ai:ChatSystemMessage {
systemMsgs.push(msg);
} else if msg is ai:ChatInteractiveMessage {
interactiveMsgs.push(msg);
}
}
return [systemMsgs, interactiveMsgs];
}

isolated function getLatestSystemMessage(ai:ChatSystemMessage[] systemMessages)
returns readonly & ai:ChatSystemMessage? {
if systemMessages.length() == 0 {
return;
}
ai:ChatSystemMessage lastSystemMessage = systemMessages[systemMessages.length() - 1];
readonly & ai:ChatMessage immutableMessage = mapToImmutableMessage(lastSystemMessage);
if immutableMessage is ai:ChatSystemMessage {
return immutableMessage;
}
return;
}
Loading
Loading