Skip to content
Merged
2 changes: 1 addition & 1 deletion ballerina/Ballerina.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
distribution = "2201.12.0"
org = "ballerinax"
name = "ai.memory.mssql"
version = "1.1.0"
version = "1.2.0"
license = ["Apache-2.0"]
authors = ["Ballerina"]
keywords = ["ai", "agent", "memory"]
Expand Down
10 changes: 5 additions & 5 deletions ballerina/Dependencies.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ distribution-version = "2201.12.0"
[[package]]
org = "ballerina"
name = "ai"
version = "1.7.0"
version = "1.9.0"
dependencies = [
{org = "ballerina", name = "constraint"},
{org = "ballerina", name = "data.jsondata"},
Expand Down Expand Up @@ -257,7 +257,7 @@ dependencies = [
[[package]]
org = "ballerina"
name = "log"
version = "2.14.0"
version = "2.12.0"
dependencies = [
{org = "ballerina", name = "io"},
{org = "ballerina", name = "jballerina.java"},
Expand Down Expand Up @@ -338,7 +338,7 @@ modules = [
[[package]]
org = "ballerina"
name = "task"
version = "2.11.0"
version = "2.10.0"
dependencies = [
{org = "ballerina", name = "jballerina.java"},
{org = "ballerina", name = "time"},
Expand All @@ -362,7 +362,7 @@ modules = [
[[package]]
org = "ballerina"
name = "time"
version = "2.8.0"
version = "2.7.0"
dependencies = [
{org = "ballerina", name = "jballerina.java"}
]
Expand Down Expand Up @@ -410,7 +410,7 @@ dependencies = [
[[package]]
org = "ballerinax"
name = "ai.memory.mssql"
version = "1.1.0"
version = "1.2.0"
dependencies = [
{org = "ballerina", name = "ai"},
{org = "ballerina", name = "cache"},
Expand Down
118 changes: 115 additions & 3 deletions ballerina/store.bal
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,15 @@ public isolated class ShortTermMemoryStore {
}
}

# Adds a chat message to the memory store for a given key.
# Adds one or more chat messages to the memory store for a given key.
#
# + key - The key associated with the memory
# + message - The `ChatMessage` message to store
# + message - The `ChatMessage` message or messages to store
# + return - nil on success, or an `Error` if the operation fails
public isolated function put(string key, ai:ChatMessage message) returns Error? {
public isolated function put(string key, ai:ChatMessage|ai:ChatMessage[] message) returns Error? {
if message is ai:ChatMessage[] {
return self.putAll(key, message);
}
ChatMessageDatabaseMessage dbMessage = transformToDatabaseMessage(message);

if dbMessage is ChatSystemMessageDatabaseMessage {
Expand Down Expand Up @@ -245,6 +248,108 @@ public isolated class ShortTermMemoryStore {
cacheEntry.interactiveMessages.push(immutableMessage);
}
}
return;
}

private isolated function putAll(string key, ai:ChatMessage[] messages) returns Error? {
if messages.length() == 0 {
return;
}

// Separate system and interactive messages
final ai:ChatSystemMessage[] systemMsgs = [];
final ai:ChatInteractiveMessage[] interactiveMsgs = [];

foreach ai:ChatMessage msg in messages {
if msg is ai:ChatSystemMessage {
systemMsgs.push(msg);
} else if msg is ai:ChatInteractiveMessage {
interactiveMsgs.push(msg);
}
}

if systemMsgs.length() > 0 {
// Only the last system message is used
ai:ChatSystemMessage lastSystem = systemMsgs[systemMsgs.length() - 1];
ChatMessageDatabaseMessage dbMessage = transformToDatabaseMessage(lastSystem);

sql:ExecutionResult|sql:Error upsertResult = self.dbClient->execute(
replaceTableNamePlaceholder(`
IF EXISTS (SELECT 1 FROM $_tableName_$ WHERE MessageKey = ${key} AND MessageRole = 'system')
UPDATE $_tableName_$ SET MessageJson = ${dbMessage.toJsonString()}
WHERE MessageKey = ${key} AND MessageRole = 'system'
ELSE
INSERT INTO $_tableName_$ (MessageKey, MessageRole, MessageJson)
VALUES (${key}, ${dbMessage.role}, ${dbMessage.toJsonString()})`,
self.tableName
)
);
if upsertResult is sql:Error {
return error("Failed to upsert system message: " + upsertResult.message(), upsertResult);
}
}

// Insert interactive messages in batch
if interactiveMsgs.length() > 0 {

// Fetch current interactive count
ai:ChatInteractiveMessage[] chatMsgs = check self.getChatInteractiveMessages(key);
int currentCount = chatMsgs.length();
int incoming = interactiveMsgs.length();

if currentCount + incoming > self.maxMessagesPerKey {
return error(string `Cannot add more messages.`
+ string ` Maximum limit '${self.maxMessagesPerKey}' exceeded for key '${key}'`);
}

sql:ParameterizedQuery[] insertQueries = from ai:ChatInteractiveMessage msg in interactiveMsgs
let ChatMessageDatabaseMessage dbMsg = transformToDatabaseMessage(msg)
select replaceTableNamePlaceholder(`
INSERT INTO $_tableName_$ (MessageKey, MessageRole, MessageJson)
VALUES (${key}, ${msg.role}, ${dbMsg.toJsonString()})`,
self.tableName
);

sql:ExecutionResult[]|sql:Error batchResult = self.dbClient->batchExecute(insertQueries);
if batchResult is sql:Error {
return error("Failed batch insert of interactive messages: " + batchResult.message(), batchResult);
}
}

final readonly & ai:ChatSystemMessage? chatSystemMessage;
if systemMsgs.length() > 0 {
ai:ChatSystemMessage lastSystem = systemMsgs[systemMsgs.length() - 1];
final readonly & ai:ChatMessage immutableMessage = mapToImmutableMessage(lastSystem);
if immutableMessage is ai:ChatSystemMessage {
chatSystemMessage = immutableMessage;
} else {
chatSystemMessage = ();
}
} else {
chatSystemMessage = ();
}

final ai:ChatMessage[] & readonly immutableMessages = interactiveMsgs.'map(msg => mapToImmutableMessage(msg))
.cloneReadOnly();

// Update cache
lock {
CachedMessages? cacheEntry = self.getCacheEntry(key);
if cacheEntry is () {
return;
}

if chatSystemMessage is ai:ChatSystemMessage {
cacheEntry.systemMessage = chatSystemMessage;
}

foreach ai:ChatMessage msg in immutableMessages {
if msg is ai:ChatInteractiveMessage {
cacheEntry.interactiveMessages.push(msg);
}
}
}
return;
}

# Removes the system chat message, if specified, for a given key.
Expand Down Expand Up @@ -493,6 +598,13 @@ public isolated class ShortTermMemoryStore {
return checkpanic cacheEntry.ensureType();
}
}

# Retrieves the maximum number of interactive messages that can be stored for each key.
#
# + return - The configured capacity of the message store per key
public isolated function getCapacity() returns int {
return self.maxMessagesPerKey;
}
}

isolated function replaceTableNamePlaceholder(sql:ParameterizedQuery query, string tableName) returns sql:ParameterizedQuery {
Expand Down
Loading
Loading