Skip to content

Commit 8b2e035

Browse files
authored
Merge pull request #290 from ohpauleez/main
Improve CePO capability
2 parents 5211f50 + 36e493c commit 8b2e035

File tree

6 files changed

+267
-236
lines changed

6 files changed

+267
-236
lines changed

optillm/cepo/cepo.py

Lines changed: 78 additions & 54 deletions
Large diffs are not rendered by default.

optillm/cepo/configs/cepo_config.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ planning_max_tokens_step3: 4096
1616
planning_max_tokens_step4: 4096
1717
use_plan_diversity: False
1818
rating_model: null
19+
use_reasoning: True
1920
use_reasoning_fallback: False
2021
num_of_retries: 0
21-
print_output: False
22+
print_output: False

optillm/cepo/configs/cepo_config_gptoss.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ planning_max_tokens_step3: 40960
1616
planning_max_tokens_step4: 40960
1717
use_plan_diversity: False
1818
rating_model: null
19+
use_reasoning: True
1920
use_reasoning_fallback: True
2021
num_of_retries: 2
21-
print_output: true
22+
print_output: true

optillm/cepo/configs/cepo_config_qwen3.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ planning_max_tokens_step3: 20481
1616
planning_max_tokens_step4: 20482
1717
use_plan_diversity: False
1818
rating_model: null
19+
use_reasoning: True
1920
use_reasoning_fallback: False
2021
num_of_retries: 0
21-
print_output: False
22+
print_output: False

optillm/conversation_logger.py

Lines changed: 53 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Dict, Any, Optional, List
88
from dataclasses import dataclass, field
99
import time
10+
import copy
1011

1112
logger = logging.getLogger(__name__)
1213

@@ -30,163 +31,163 @@ class ConversationEntry:
3031
class ConversationLogger:
3132
"""
3233
Logger for OptiLLM conversations including all provider interactions and metadata.
33-
34+
3435
Logs are saved in JSONL format (one JSON object per line) with daily rotation.
3536
Each entry contains the full conversation including all intermediate provider calls.
3637
"""
37-
38+
3839
def __init__(self, log_dir: Path, enabled: bool = False):
3940
self.enabled = enabled
4041
self.log_dir = log_dir
4142
self.active_entries: Dict[str, ConversationEntry] = {}
4243
self._lock = threading.Lock()
43-
44+
4445
if self.enabled:
4546
self.log_dir.mkdir(parents=True, exist_ok=True)
4647
logger.info(f"Conversation logging enabled. Logs will be saved to: {self.log_dir}")
4748
else:
4849
logger.debug("Conversation logging disabled")
49-
50+
5051
def _get_log_file_path(self, timestamp: datetime = None) -> Path:
5152
"""Get the log file path for a given timestamp (defaults to now)"""
5253
if timestamp is None:
5354
timestamp = datetime.now(timezone.utc)
5455
date_str = timestamp.strftime("%Y-%m-%d")
5556
return self.log_dir / f"conversations_{date_str}.jsonl"
56-
57+
5758
def _generate_request_id(self) -> str:
5859
"""Generate a unique request ID"""
5960
return f"req_{uuid.uuid4().hex[:8]}"
60-
61-
def start_conversation(self,
62-
client_request: Dict[str, Any],
63-
approach: str,
61+
62+
def start_conversation(self,
63+
client_request: Dict[str, Any],
64+
approach: str,
6465
model: str) -> str:
6566
"""
6667
Start logging a new conversation.
67-
68+
6869
Args:
6970
client_request: The original request from the client
7071
approach: The optimization approach being used
7172
model: The model name
72-
73+
7374
Returns:
7475
str: Unique request ID for this conversation
7576
"""
7677
if not self.enabled:
7778
return ""
78-
79+
7980
request_id = self._generate_request_id()
8081
timestamp = datetime.now(timezone.utc).isoformat()
81-
82+
8283
entry = ConversationEntry(
8384
request_id=request_id,
8485
timestamp=timestamp,
8586
approach=approach,
8687
model=model,
8788
client_request=client_request.copy()
8889
)
89-
90+
9091
with self._lock:
9192
self.active_entries[request_id] = entry
92-
93+
9394
logger.debug(f"Started conversation logging for request {request_id}")
9495
return request_id
95-
96-
def log_provider_call(self,
97-
request_id: str,
98-
provider_request: Dict[str, Any],
96+
97+
def log_provider_call(self,
98+
request_id: str,
99+
provider_request: Dict[str, Any],
99100
provider_response: Dict[str, Any]) -> None:
100101
"""
101102
Log a provider API call and response.
102-
103+
103104
Args:
104105
request_id: The request ID for this conversation
105106
provider_request: The request sent to the provider
106107
provider_response: The response received from the provider
107108
"""
108109
if not self.enabled or not request_id:
109110
return
110-
111+
111112
with self._lock:
112113
entry = self.active_entries.get(request_id)
113114
if not entry:
114115
logger.warning(f"No active conversation found for request {request_id}")
115116
return
116-
117+
117118
call_data = {
118119
"call_number": len(entry.provider_calls) + 1,
119120
"timestamp": datetime.now(timezone.utc).isoformat(),
120-
"request": provider_request.copy(),
121-
"response": provider_response.copy()
121+
"request": provider_request and provider_request.copy() or None,
122+
"response": provider_response and copy.copy(provider_response) or None # Responses are usually strs or dicts
122123
}
123-
124+
124125
entry.provider_calls.append(call_data)
125-
126+
126127
logger.debug(f"Logged provider call #{len(entry.provider_calls)} for request {request_id}")
127-
128-
def log_final_response(self,
129-
request_id: str,
128+
129+
def log_final_response(self,
130+
request_id: str,
130131
final_response: Dict[str, Any]) -> None:
131132
"""
132133
Log the final response sent back to the client.
133-
134+
134135
Args:
135136
request_id: The request ID for this conversation
136137
final_response: The final response sent to the client
137138
"""
138139
if not self.enabled or not request_id:
139140
return
140-
141+
141142
with self._lock:
142143
entry = self.active_entries.get(request_id)
143144
if not entry:
144145
logger.warning(f"No active conversation found for request {request_id}")
145146
return
146-
147+
147148
entry.final_response = final_response.copy()
148149
entry.final_response["timestamp"] = datetime.now(timezone.utc).isoformat()
149-
150+
150151
def log_error(self, request_id: str, error: str) -> None:
151152
"""
152153
Log an error for this conversation.
153-
154+
154155
Args:
155-
request_id: The request ID for this conversation
156+
request_id: The request ID for this conversation
156157
error: Error message or description
157158
"""
158159
if not self.enabled or not request_id:
159160
return
160-
161+
161162
with self._lock:
162163
entry = self.active_entries.get(request_id)
163164
if not entry:
164165
logger.warning(f"No active conversation found for request {request_id}")
165166
return
166-
167+
167168
entry.error = error
168-
169+
169170
logger.debug(f"Logged error for request {request_id}: {error}")
170-
171+
171172
def finalize_conversation(self, request_id: str) -> None:
172173
"""
173174
Finalize and save the conversation to disk.
174-
175+
175176
Args:
176177
request_id: The request ID for this conversation
177178
"""
178179
if not self.enabled or not request_id:
179180
return
180-
181+
181182
with self._lock:
182183
entry = self.active_entries.pop(request_id, None)
183184
if not entry:
184185
logger.warning(f"No active conversation found for request {request_id}")
185186
return
186-
187+
187188
# Calculate total duration
188189
entry.total_duration_ms = int((time.time() - entry.start_time) * 1000)
189-
190+
190191
# Convert to dict for JSON serialization
191192
log_entry = {
192193
"timestamp": entry.timestamp,
@@ -199,12 +200,12 @@ def finalize_conversation(self, request_id: str) -> None:
199200
"total_duration_ms": entry.total_duration_ms,
200201
"error": entry.error
201202
}
202-
203+
203204
# Write to log file
204205
self._write_log_entry(log_entry)
205-
206+
206207
logger.debug(f"Finalized conversation for request {request_id}")
207-
208+
208209
def _write_log_entry(self, log_entry: Dict[str, Any]) -> None:
209210
"""Write a log entry to the appropriate JSONL file"""
210211
try:
@@ -215,18 +216,18 @@ def _write_log_entry(self, log_entry: Dict[str, Any]) -> None:
215216
logger.debug(f"Wrote log entry to {log_file_path}")
216217
except Exception as e:
217218
logger.error(f"Failed to write log entry: {e}")
218-
219+
219220
def get_stats(self) -> Dict[str, Any]:
220221
"""Get statistics about conversation logging"""
221222
with self._lock:
222223
active_count = len(self.active_entries)
223-
224+
224225
stats = {
225226
"enabled": self.enabled,
226227
"log_dir": str(self.log_dir),
227228
"active_conversations": active_count
228229
}
229-
230+
230231
if self.enabled:
231232
# Count total log files and approximate total entries
232233
log_files = list(self.log_dir.glob("conversations_*.jsonl"))
@@ -237,12 +238,12 @@ def get_stats(self) -> Dict[str, Any]:
237238
total_entries += sum(1 for line in f if line.strip())
238239
except Exception:
239240
pass
240-
241+
241242
stats.update({
242243
"log_files_count": len(log_files),
243244
"total_entries_approximate": total_entries
244245
})
245-
246+
246247
return stats
247248

248249

@@ -262,4 +263,4 @@ def log_provider_call(request_id: str, provider_request: Dict[str, Any], provide
262263
def log_error(request_id: str, error_message: str) -> None:
263264
"""Log an error using the global logger instance"""
264265
if _global_logger and _global_logger.enabled:
265-
_global_logger.log_error(request_id, error_message)
266+
_global_logger.log_error(request_id, error_message)

0 commit comments

Comments
 (0)