-
Notifications
You must be signed in to change notification settings - Fork 3
Add support to batch insert messages #7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
63ffd82
3bfc06f
9feabad
bad5809
bd9a4e7
dfcd998
0d8e952
f6ef669
5644237
306357b
286d07d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -183,44 +183,23 @@ public isolated class ShortTermMemoryStore { | |
| } | ||
| } | ||
|
|
||
| # Adds a chat message to the memory store for a given key. | ||
| # Adds one or more chat messages to the memory store for a given key. | ||
| # | ||
| # + key - The key associated with the memory | ||
| # + message - The `ChatMessage` message to store | ||
| # + message - The `ChatMessage` message or messages to store | ||
| # + return - nil on success, or an `Error` if the operation fails | ||
| public isolated function put(string key, ai:ChatMessage message) returns Error? { | ||
| public isolated function put(string key, ai:ChatMessage|ai:ChatMessage[] message) returns Error? { | ||
| if message is ai:ChatMessage[] { | ||
| return self.putAll(key, message); | ||
| } | ||
| ChatMessageDatabaseMessage dbMessage = transformToDatabaseMessage(message); | ||
|
|
||
| if dbMessage is ChatSystemMessageDatabaseMessage { | ||
| // Upsert system message for the key | ||
| sql:ExecutionResult|sql:Error upsertResult = self.dbClient->execute( | ||
| replaceTableNamePlaceholder(` | ||
| IF EXISTS (SELECT 1 FROM $_tableName_$ WHERE MessageKey = ${key} AND MessageRole = 'system') | ||
| UPDATE $_tableName_$ SET MessageJson = ${dbMessage.toJsonString()} | ||
| WHERE MessageKey = ${key} AND MessageRole = 'system' | ||
| ELSE | ||
| INSERT INTO $_tableName_$ (MessageKey, MessageRole, MessageJson) | ||
| VALUES (${key}, ${dbMessage.role}, ${dbMessage.toJsonString()})`, | ||
| self.tableName | ||
| ) | ||
| ); | ||
| sql:ExecutionResult|sql:Error upsertResult = self.updateSystemMessage(key, dbMessage); | ||
| if upsertResult is sql:Error { | ||
| return error("Failed to upsert system message: " + upsertResult.message(), upsertResult); | ||
| } | ||
| } else { | ||
| do { | ||
| // Expected to be checked by the caller, but doing an additional check here, without locking. | ||
| ai:ChatInteractiveMessage[]|Error chatInteractiveMessages = self.getChatInteractiveMessages(key); | ||
| if chatInteractiveMessages is Error { | ||
| error? cause = chatInteractiveMessages.cause(); | ||
| fail cause is error ? cause : chatInteractiveMessages; | ||
| } | ||
|
|
||
| if chatInteractiveMessages.length() >= self.maxMessagesPerKey { | ||
| return error(string `Cannot add more messages. Maximum limit of '${ | ||
| self.maxMessagesPerKey}' reached for key: '${key}'`); | ||
| } | ||
|
|
||
| _ = check self.dbClient->execute( | ||
| replaceTableNamePlaceholder(` | ||
| INSERT INTO $_tableName_$ (MessageKey, MessageRole, MessageJson) | ||
|
|
@@ -247,6 +226,80 @@ public isolated class ShortTermMemoryStore { | |
| } | ||
| } | ||
|
|
||
| # Stores a batch of chat messages for a key. | ||
| # | ||
| # Only the last system message in the batch is persisted (as an upsert); all interactive messages | ||
| # are inserted with a single batch call. The per-key capacity is validated before anything is | ||
| # written, so a rejected batch leaves neither the database nor the cache modified. | ||
| # | ||
| # + key - The key associated with the memory | ||
| # + messages - The chat messages to store; an empty array is a no-op | ||
| # + return - nil on success, or an `Error` if validation or a database operation fails | ||
| private isolated function putAll(string key, ai:ChatMessage[] messages) returns Error? { | ||
| if messages.length() == 0 { | ||
| return; | ||
| } | ||
|
|
||
| final var [newSystemMessages, newInteractiveMessages] = partitionMessagesByType(messages); | ||
| final readonly & ai:ChatSystemMessage? finalChatSystemMessage = getLatestSystemMessage(newSystemMessages); | ||
|
|
||
| // Validate capacity first: the system message must not be written when the batch is going to be rejected. | ||
| if newInteractiveMessages.length() > 0 { | ||
| ai:ChatInteractiveMessage[] oldInteractiveMessages = check self.getChatInteractiveMessages(key); | ||
| int currentCount = oldInteractiveMessages.length(); | ||
| int incoming = newInteractiveMessages.length(); | ||
|
|
||
| if currentCount + incoming > self.maxMessagesPerKey { | ||
| return error(string `Cannot add more messages.` | ||
| + string ` Maximum limit '${self.maxMessagesPerKey}' exceeded for key '${key}'`); | ||
| } | ||
| } | ||
|
|
||
| if finalChatSystemMessage is ai:ChatSystemMessage { | ||
| ChatMessageDatabaseMessage dbMessage = transformToDatabaseMessage(finalChatSystemMessage); | ||
| sql:ExecutionResult|sql:Error upsertResult = self.updateSystemMessage(key, dbMessage); | ||
| if upsertResult is sql:Error { | ||
| return error("Failed to upsert system message: " + upsertResult.message(), upsertResult); | ||
| } | ||
| } | ||
|
|
||
| // Insert interactive messages in batch | ||
| if newInteractiveMessages.length() > 0 { | ||
| sql:ParameterizedQuery[] insertQueries = from ai:ChatInteractiveMessage msg in newInteractiveMessages | ||
| let ChatMessageDatabaseMessage dbMsg = transformToDatabaseMessage(msg) | ||
| select replaceTableNamePlaceholder(` | ||
| INSERT INTO $_tableName_$ (MessageKey, MessageRole, MessageJson) | ||
| VALUES (${key}, ${msg.role}, ${dbMsg.toJsonString()})`, | ||
| self.tableName | ||
| ); | ||
| sql:ExecutionResult[]|sql:Error batchResult = self.dbClient->batchExecute(insertQueries); | ||
| if batchResult is sql:Error { | ||
| // The system message (if any) is already persisted at this point; mirror it into the cache | ||
| // so the cached system message does not drift from the database. | ||
| // NOTE(review): a failed batch may have inserted some rows — confirm driver behaviour. | ||
| self.updateCache(key, finalChatSystemMessage, []); | ||
| return error("Failed batch insert of interactive messages: " + batchResult.message(), batchResult); | ||
| } | ||
| } | ||
|
|
||
| final ai:ChatInteractiveMessage[] & readonly immutableInteractiveMessages = from ai:ChatInteractiveMessage message | ||
| in newInteractiveMessages | ||
| select <readonly & ai:ChatInteractiveMessage>mapToImmutableMessage(message); | ||
| self.updateCache(key, finalChatSystemMessage, immutableInteractiveMessages); | ||
| } | ||
|
|
||
| # Applies newly persisted messages to the cache entry of a key, under the class lock. | ||
| # Does nothing when the key has no cache entry. | ||
| # | ||
| # + key - The key associated with the memory | ||
| # + systemMessage - The system message to set on the entry; nil leaves the cached one untouched | ||
| # + interactiveMessages - The interactive messages to append to the entry | ||
| private isolated function updateCache(string key, readonly & ai:ChatSystemMessage? systemMessage, | ||
| readonly & ai:ChatInteractiveMessage[] interactiveMessages) { | ||
| lock { | ||
| CachedMessages? entry = self.getCacheEntry(key); | ||
| if entry is CachedMessages { | ||
| if systemMessage is ai:ChatSystemMessage { | ||
| entry.systemMessage = systemMessage; | ||
| } | ||
| entry.interactiveMessages.push(...interactiveMessages); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| # Upserts the system message row of a key: updates the existing row whose role is 'system', | ||
| # or inserts a new row when none exists. | ||
| # | ||
| # + key - The key associated with the memory | ||
| # + systemMessage - The database representation of the system message | ||
| # + return - The execution result, or an `sql:Error` if the statement fails | ||
| private isolated function updateSystemMessage(string key, ChatMessageDatabaseMessage systemMessage) | ||
| returns sql:ExecutionResult|sql:Error { | ||
| sql:ParameterizedQuery upsertQuery = replaceTableNamePlaceholder(` | ||
| IF EXISTS (SELECT 1 FROM $_tableName_$ WHERE MessageKey = ${key} AND MessageRole = 'system') | ||
| UPDATE $_tableName_$ SET MessageJson = ${systemMessage.toJsonString()} | ||
| WHERE MessageKey = ${key} AND MessageRole = 'system' | ||
| ELSE | ||
| INSERT INTO $_tableName_$ (MessageKey, MessageRole, MessageJson) | ||
| VALUES (${key}, ${systemMessage.role}, ${systemMessage.toJsonString()})`, | ||
| self.tableName | ||
| ); | ||
| sql:ExecutionResult|sql:Error outcome = self.dbClient->execute(upsertQuery); | ||
| return outcome; | ||
| } | ||
|
|
||
| # Removes the system chat message, if specified, for a given key. | ||
| # | ||
| # + key - The key associated with the memory | ||
|
|
@@ -401,24 +454,6 @@ public isolated class ShortTermMemoryStore { | |
|
|
||
| private isolated function cacheFromDatabase(string key) | ||
| returns readonly & ([ai:ChatSystemMessage, ai:ChatInteractiveMessage...]|ai:ChatInteractiveMessage[])|Error { | ||
| int|sql:Error messageCount = self.dbClient->queryRow( | ||
| replaceTableNamePlaceholder(` | ||
| SELECT COUNT(*) as count | ||
| FROM $_tableName_$ | ||
| WHERE MessageKey = ${key} AND MessageRole != 'system'`, | ||
| self.tableName | ||
| ) | ||
| ); | ||
|
|
||
| if messageCount is sql:Error { | ||
| return error("Failed to load message count from the database: " + messageCount.message(), messageCount); | ||
| } | ||
|
|
||
| if messageCount > self.maxMessagesPerKey { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't we need this still?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We don't need to do that. It should be handled by the caller, as discussed previously with @shafreenAnfar and @SasinduDilshara.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For updates, yes. But this gets called during a get too, right? I'm not sure we can expect users to do this check for gets also.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If updates are already handled by the user, that guarantees the memory will always have a message count at or below the configured size. I do not think we need to handle this in get.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Won't happen for the first get that happens if no update has happened, right?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As discussed offline with @MaryamZi, there can be cases where a memory instance is created with a specific size configuration, for example 10, and updates are persisted accordingly, so the database contains 10 entries. Later, if the user code is changed to use a smaller size, the database may still hold more entries than the new limit, which could result in an error. Since this is a very rare scenario, we are currently removing the size check to reduce database calls.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When the cache is used, there'll be fewer database calls anyway, so one here would be okay IMO. |
||
| return error ExceedsSizeError(string `Cannot load messages from the database: Message count '${ | ||
| messageCount}' exceeds maximum limit of '${self.maxMessagesPerKey}' for key: '${key}'`); | ||
| } | ||
|
|
||
| do { | ||
| stream<DatabaseRecord, sql:Error?> messages = self.dbClient->query( | ||
| replaceTableNamePlaceholder(` | ||
|
|
@@ -493,6 +528,13 @@ public isolated class ShortTermMemoryStore { | |
| return checkpanic cacheEntry.ensureType(); | ||
| } | ||
| } | ||
|
|
||
| # Returns the per-key limit on the number of interactive messages held by this store. | ||
| # | ||
| # + return - The configured maximum number of interactive messages per key | ||
| public isolated function getCapacity() returns int { | ||
| int capacity = self.maxMessagesPerKey; | ||
| return capacity; | ||
| } | ||
| } | ||
|
|
||
| isolated function replaceTableNamePlaceholder(sql:ParameterizedQuery query, string tableName) returns sql:ParameterizedQuery { | ||
|
|
@@ -501,3 +543,30 @@ isolated function replaceTableNamePlaceholder(sql:ParameterizedQuery query, stri | |
| query.strings = strings; | ||
| return query; | ||
| } | ||
|
|
||
| # Splits chat messages into system messages and interactive messages, preserving input order | ||
| # within each group. | ||
| # | ||
| # + messages - The chat messages to split | ||
| # + return - A tuple of `[system messages, interactive messages]` | ||
| isolated function partitionMessagesByType(ai:ChatMessage[] messages) | ||
| returns [ai:ChatSystemMessage[], ai:ChatInteractiveMessage[]] { | ||
| ai:ChatSystemMessage[] systemBucket = []; | ||
| ai:ChatInteractiveMessage[] interactiveBucket = []; | ||
| foreach int idx in 0 ..< messages.length() { | ||
| ai:ChatMessage current = messages[idx]; | ||
| if current is ai:ChatSystemMessage { | ||
| systemBucket.push(current); | ||
| continue; | ||
| } | ||
| if current is ai:ChatInteractiveMessage { | ||
| interactiveBucket.push(current); | ||
| } | ||
| } | ||
| return [systemBucket, interactiveBucket]; | ||
| } | ||
|
|
||
| # Picks the last system message of the given array and returns it as an immutable value. | ||
| # | ||
| # + systemMessages - The system messages, in the order they were supplied | ||
| # + return - The immutable form of the last element, or nil when the array is empty or the | ||
| # immutable value is not a `ChatSystemMessage` | ||
| isolated function getLatestSystemMessage(ai:ChatSystemMessage[] systemMessages) | ||
| returns readonly & ai:ChatSystemMessage? { | ||
| int count = systemMessages.length(); | ||
| if count == 0 { | ||
| return (); | ||
| } | ||
| readonly & ai:ChatMessage frozen = mapToImmutableMessage(systemMessages[count - 1]); | ||
| return frozen is ai:ChatSystemMessage ? frozen : (); | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.