@@ -43,11 +43,13 @@ def __init__(self,
                  client: OpenAIClient,
                  max_capacity=2000,
                  embedding_model_name: str = "all-MiniLM-L6-v2",
-                 embedding_model_kwargs: Optional[dict] = None):
+                 embedding_model_kwargs: Optional[dict] = None,
+                 llm_model: str = "gpt-4o-mini"):
         self.user_id = user_id
         self.client = client
         self.max_capacity = max_capacity
         self.storage = storage_provider
+        self.llm_model = llm_model

         # Load sessions and other data from the shared storage provider's in-memory metadata
         self.sessions: dict = self.storage.get_mid_term_sessions()
@@ -100,7 +102,7 @@ def add_session(self, summary, details):
             **self.embedding_model_kwargs
         )
         summary_vec = normalize_vector(summary_vec).tolist()
-        summary_keywords = list(extract_keywords_from_multi_summary(summary, client=self.client))
+        summary_keywords = list(extract_keywords_from_multi_summary(summary, client=self.client, model=self.llm_model))

         processed_details = []
         for page_data in details:
@@ -132,7 +134,7 @@ def add_session(self, summary, details):
             else:
                 print(f"MidTermMemory: Computing new keywords for page {page_id}")
                 full_text = f"User: {page_data.get('user_input','')} Assistant: {page_data.get('agent_response','')}"
-                page_keywords = list(extract_keywords_from_multi_summary(full_text, client=self.client))
+                page_keywords = list(extract_keywords_from_multi_summary(full_text, client=self.client, model=self.llm_model))

             processed_page = {
                 **page_data,  # Carry over existing fields like user_input, agent_response, timestamp
@@ -249,7 +251,7 @@ def insert_pages_into_session(self, summary_for_new_pages, keywords_for_new_page

             if "page_keywords" not in page_data or not page_data["page_keywords"]:
                 full_text = f"User: {page_data.get('user_input','')} Assistant: {page_data.get('agent_response','')}"
-                page_data["page_keywords"] = list(extract_keywords_from_multi_summary(full_text, client=self.client))
+                page_data["page_keywords"] = list(extract_keywords_from_multi_summary(full_text, client=self.client, model=self.llm_model))

             processed_new_pages.append({**page_data, "page_id": page_id})

@@ -285,7 +287,7 @@ def search_sessions(self, query_text, segment_similarity_threshold=0.1, page_sim
             **self.embedding_model_kwargs
         )
         query_vec = normalize_vector(query_vec)
-        query_keywords = set(extract_keywords_from_multi_summary(query_text, client=self.client))
+        query_keywords = set(extract_keywords_from_multi_summary(query_text, client=self.client, model=self.llm_model))

         # Search sessions using ChromaDB
         similar_sessions = self.storage.search_mid_term_sessions(query_vec.tolist(), top_k=top_k_sessions)
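Taken together, these hunks thread a configurable `llm_model` through every `extract_keywords_from_multi_summary` call instead of relying on that helper's hard-coded default; because the new parameter defaults to `"gpt-4o-mini"`, existing call sites keep their previous behavior. A minimal usage sketch of the new constructor parameter follows; the import paths, the `StorageProvider` name, and every constructor detail not visible in the diff are assumptions for illustration, not part of this commit:

```python
# Hypothetical wiring -- everything outside the MidTermMemory(...) call is assumed.
from myproject.openai_client import OpenAIClient   # assumed import path
from myproject.storage import StorageProvider      # assumed class name and path
from myproject.mid_term import MidTermMemory       # assumed import path

client = OpenAIClient(api_key="YOUR_API_KEY")      # assumed constructor signature
storage = StorageProvider("./memory_db")           # assumed constructor signature

memory = MidTermMemory(
    user_id="user-123",
    storage_provider=storage,
    client=client,
    max_capacity=2000,
    embedding_model_name="all-MiniLM-L6-v2",
    llm_model="gpt-4o-mini",  # new in this commit: forwarded as model= to every
                              # extract_keywords_from_multi_summary call
)
```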