-
Notifications
You must be signed in to change notification settings - Fork 3
Add support to batch insert messages #7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
63ffd82
3bfc06f
9feabad
bad5809
bd9a4e7
dfcd998
0d8e952
f6ef669
5644237
306357b
286d07d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -183,12 +183,15 @@ public isolated class ShortTermMemoryStore { | |
| } | ||
| } | ||
|
|
||
| # Adds a chat message to the memory store for a given key. | ||
| # Adds one or more chat messages to the memory store for a given key. | ||
| # | ||
| # + key - The key associated with the memory | ||
| # + message - The `ChatMessage` message to store | ||
| # + message - The `ChatMessage` message or messages to store | ||
| # + return - nil on success, or an `Error` if the operation fails | ||
| public isolated function put(string key, ai:ChatMessage message) returns Error? { | ||
| public isolated function put(string key, ai:ChatMessage|ai:ChatMessage[] message) returns Error? { | ||
| if message is ai:ChatMessage[] { | ||
| return self.putAll(key, message); | ||
| } | ||
| ChatMessageDatabaseMessage dbMessage = transformToDatabaseMessage(message); | ||
|
|
||
| if dbMessage is ChatSystemMessageDatabaseMessage { | ||
|
|
@@ -209,18 +212,6 @@ public isolated class ShortTermMemoryStore { | |
| } | ||
| } else { | ||
| do { | ||
| // Expected to be checked by the caller, but doing an additional check here, without locking. | ||
| ai:ChatInteractiveMessage[]|Error chatInteractiveMessages = self.getChatInteractiveMessages(key); | ||
| if chatInteractiveMessages is Error { | ||
| error? cause = chatInteractiveMessages.cause(); | ||
| fail cause is error ? cause : chatInteractiveMessages; | ||
| } | ||
|
|
||
| if chatInteractiveMessages.length() >= self.maxMessagesPerKey { | ||
| return error(string `Cannot add more messages. Maximum limit of '${ | ||
| self.maxMessagesPerKey}' reached for key: '${key}'`); | ||
| } | ||
|
|
||
| _ = check self.dbClient->execute( | ||
| replaceTableNamePlaceholder(` | ||
| INSERT INTO $_tableName_$ (MessageKey, MessageRole, MessageJson) | ||
|
|
@@ -245,6 +236,108 @@ public isolated class ShortTermMemoryStore { | |
| cacheEntry.interactiveMessages.push(immutableMessage); | ||
| } | ||
| } | ||
| return; | ||
| } | ||
|
|
||
| # Persists a batch of chat messages for a given key and mirrors them into the cache entry, if one exists. | ||
| # | ||
| # When the batch contains system messages, only the last one is stored (it replaces any existing | ||
| # system message for the key). Interactive messages are inserted using a single batch execution. | ||
| # The capacity check is performed before any write, so a batch that is rejected for exceeding | ||
| # the limit leaves the stored system message untouched. | ||
| # | ||
| # + key - The key associated with the memory | ||
| # + messages - The chat messages to store; an empty array is a no-op | ||
| # + return - nil on success, or an `Error` if the limit would be exceeded or a database operation fails | ||
| private isolated function putAll(string key, ai:ChatMessage[] messages) returns Error? { | ||
| if messages.length() == 0 { | ||
| return; | ||
| } | ||
|
|
||
| // Separate system and interactive messages | ||
| final ai:ChatSystemMessage[] systemMsgs = []; | ||
| final ai:ChatInteractiveMessage[] interactiveMsgs = []; | ||
|
|
||
| foreach ai:ChatMessage msg in messages { | ||
| if msg is ai:ChatSystemMessage { | ||
| systemMsgs.push(msg); | ||
| } else if msg is ai:ChatInteractiveMessage { | ||
| interactiveMsgs.push(msg); | ||
| } | ||
| } | ||
|
|
||
| // Validate capacity up front: nothing must be written if the batch is going to be rejected, | ||
| // otherwise the system message would be replaced in the database while the call reports a failure. | ||
| if interactiveMsgs.length() > 0 { | ||
| ai:ChatInteractiveMessage[] chatMsgs = check self.getChatInteractiveMessages(key); | ||
| int currentCount = chatMsgs.length(); | ||
| int incoming = interactiveMsgs.length(); | ||
|
|
||
| if currentCount + incoming > self.maxMessagesPerKey { | ||
| return error(string `Cannot add more messages.` | ||
| + string ` Maximum limit '${self.maxMessagesPerKey}' exceeded for key '${key}'`); | ||
| } | ||
| } | ||
|
|
||
| if systemMsgs.length() > 0 { | ||
| // Only the last system message is used | ||
| ai:ChatSystemMessage lastSystem = systemMsgs[systemMsgs.length() - 1]; | ||
| ChatMessageDatabaseMessage dbMessage = transformToDatabaseMessage(lastSystem); | ||
|
|
||
| // Update the existing system row for the key, or insert one if none exists. | ||
| sql:ExecutionResult|sql:Error upsertResult = self.dbClient->execute( | ||
| replaceTableNamePlaceholder(` | ||
| IF EXISTS (SELECT 1 FROM $_tableName_$ WHERE MessageKey = ${key} AND MessageRole = 'system') | ||
| UPDATE $_tableName_$ SET MessageJson = ${dbMessage.toJsonString()} | ||
| WHERE MessageKey = ${key} AND MessageRole = 'system' | ||
| ELSE | ||
| INSERT INTO $_tableName_$ (MessageKey, MessageRole, MessageJson) | ||
| VALUES (${key}, ${dbMessage.role}, ${dbMessage.toJsonString()})`, | ||
| self.tableName | ||
| ) | ||
| ); | ||
| if upsertResult is sql:Error { | ||
| return error("Failed to upsert system message: " + upsertResult.message(), upsertResult); | ||
| } | ||
| } | ||
|
|
||
| // Insert interactive messages in batch | ||
| if interactiveMsgs.length() > 0 { | ||
| sql:ParameterizedQuery[] insertQueries = from ai:ChatInteractiveMessage msg in interactiveMsgs | ||
| let ChatMessageDatabaseMessage dbMsg = transformToDatabaseMessage(msg) | ||
| select replaceTableNamePlaceholder(` | ||
| INSERT INTO $_tableName_$ (MessageKey, MessageRole, MessageJson) | ||
| VALUES (${key}, ${msg.role}, ${dbMsg.toJsonString()})`, | ||
| self.tableName | ||
| ); | ||
|
|
||
| sql:ExecutionResult[]|sql:Error batchResult = self.dbClient->batchExecute(insertQueries); | ||
| if batchResult is sql:Error { | ||
| return error("Failed batch insert of interactive messages: " + batchResult.message(), batchResult); | ||
| } | ||
| } | ||
|
|
||
| // Prepare immutable copies outside the lock so that only readonly values are used inside it. | ||
| final readonly & ai:ChatSystemMessage? chatSystemMessage; | ||
| if systemMsgs.length() > 0 { | ||
| ai:ChatSystemMessage lastSystem = systemMsgs[systemMsgs.length() - 1]; | ||
| final readonly & ai:ChatMessage immutableMessage = mapToImmutableMessage(lastSystem); | ||
| if immutableMessage is ai:ChatSystemMessage { | ||
| chatSystemMessage = immutableMessage; | ||
| } else { | ||
| chatSystemMessage = (); | ||
| } | ||
| } else { | ||
| chatSystemMessage = (); | ||
| } | ||
|
|
||
| final ai:ChatMessage[] & readonly immutableMessages = interactiveMsgs.'map(msg => mapToImmutableMessage(msg)) | ||
| .cloneReadOnly(); | ||
|
|
||
| // Update cache | ||
| lock { | ||
| CachedMessages? cacheEntry = self.getCacheEntry(key); | ||
| if cacheEntry is () { | ||
| // No cache entry for this key: there is nothing to update here. | ||
| return; | ||
| } | ||
|
|
||
| if chatSystemMessage is ai:ChatSystemMessage { | ||
| cacheEntry.systemMessage = chatSystemMessage; | ||
| } | ||
|
|
||
| foreach ai:ChatMessage msg in immutableMessages { | ||
| if msg is ai:ChatInteractiveMessage { | ||
| cacheEntry.interactiveMessages.push(msg); | ||
| } | ||
| } | ||
| } | ||
| return; | ||
| } | ||
|
|
||
| # Removes the system chat message, if specified, for a given key. | ||
|
|
@@ -401,24 +494,6 @@ public isolated class ShortTermMemoryStore { | |
|
|
||
| private isolated function cacheFromDatabase(string key) | ||
| returns readonly & ([ai:ChatSystemMessage, ai:ChatInteractiveMessage...]|ai:ChatInteractiveMessage[])|Error { | ||
| int|sql:Error messageCount = self.dbClient->queryRow( | ||
| replaceTableNamePlaceholder(` | ||
| SELECT COUNT(*) as count | ||
| FROM $_tableName_$ | ||
| WHERE MessageKey = ${key} AND MessageRole != 'system'`, | ||
| self.tableName | ||
| ) | ||
| ); | ||
|
|
||
| if messageCount is sql:Error { | ||
| return error("Failed to load message count from the database: " + messageCount.message(), messageCount); | ||
| } | ||
|
|
||
| if messageCount > self.maxMessagesPerKey { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't we need this still?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need to do that. It should be handled by the caller as discussed previously with @shafreenAnfar and @SasinduDilshara.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For updates, yes. But this gets called during a get too, right? I'm not sure we can expect users to do this check for gets also.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If updates are already handled by the user, that guarantees the memory will always have a message count below the configured size. I do not think we need to handle this in get.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Won't happen for the first get that happens if no update has happened, right?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As discussed offline with @MaryamZi, there can be cases where a memory instance is created with a specific size configuration, for example 10, and updates are persisted accordingly, so the database contains 10 entries. Later, if the user code is changed to use a smaller size, the database may still hold more entries than the new limit, which could result in an error. Since this is a very rare scenario, we are currently removing the size check to reduce database calls.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When the cache is used, there'll be fewer database calls anyway, so one here would be okay IMO. |
||
| return error ExceedsSizeError(string `Cannot load messages from the database: Message count '${ | ||
| messageCount}' exceeds maximum limit of '${self.maxMessagesPerKey}' for key: '${key}'`); | ||
| } | ||
|
|
||
| do { | ||
| stream<DatabaseRecord, sql:Error?> messages = self.dbClient->query( | ||
| replaceTableNamePlaceholder(` | ||
|
|
@@ -493,6 +568,13 @@ public isolated class ShortTermMemoryStore { | |
| return checkpanic cacheEntry.ensureType(); | ||
| } | ||
| } | ||
|
|
||
| # Returns the per-key limit on stored interactive messages that this store was configured with. | ||
| # | ||
| # + return - The maximum number of interactive messages allowed for each key | ||
| public isolated function getCapacity() returns int { | ||
| int capacity = self.maxMessagesPerKey; | ||
| return capacity; | ||
| } | ||
| } | ||
|
|
||
| isolated function replaceTableNamePlaceholder(sql:ParameterizedQuery query, string tableName) returns sql:ParameterizedQuery { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.