# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import sys
+import traceback
from typing import Any

import aiohttp


class _VLLMChatCompletions(AsyncCompletions):
-    """Wraps vllm_model and injects token IDs as attributes for verifiers."""
+    """Adapt the vllm_model response format to the format verifiers expects."""
    def __init__(self, base_url: str) -> None:
        self._base_url = base_url.rstrip("/")

@@ -53,19 +55,33 @@ async def create(self, *args: Any, **kwargs: Any) -> ChatCompletion:
            if key in kwargs and kwargs[key] is not None:
                request_body[key] = kwargs[key]

-        async with aiohttp.ClientSession() as session:
-            async with session.post(f"{self._base_url}/chat/completions", json=request_body) as resp:
-                resp.raise_for_status()
-                response_dict = await resp.json()
+        url = f"{self._base_url}/chat/completions"
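+        # Log non-200 responses (with a truncated body) and any client-side exception before re-raising.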
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(url, json=request_body) as resp:
+                    if resp.status != 200:
+                        error_text = await resp.text()
+                        logger.error(f"[verifiers_agent] Request to {url} failed with status {resp.status}: {error_text[:500]}")
+                        resp.raise_for_status()
+                    response_dict = await resp.json()
+        except Exception as e:
+            logger.error(f"[verifiers_agent] Exception calling {url}: {type(e).__name__}: {e}")
+            raise

-        # Extract token IDs from vllm_model
        choice_dict = response_dict["choices"][0]
        message_dict = choice_dict.get("message", {})
+
        prompt_token_ids = message_dict.pop("prompt_token_ids", [])
        generation_token_ids = message_dict.pop("generation_token_ids", [])
        generation_log_probs = message_dict.pop("generation_log_probs", [])

-        # Reconstruct logprobs.content for verifiers
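+        # Warn if the response carried no generation token IDs; the logprobs reconstruction below depends on them.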
+        if not generation_token_ids:
+            logger.warning(f"[verifiers_agent] No generation_token_ids in response! Full message keys were: {list(choice_dict.get('message', {}).keys())}")
+
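+        # Token IDs may arrive as strings depending on serialization; coerce them back to ints.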
+        if generation_token_ids and isinstance(generation_token_ids[0], str):
+            generation_token_ids = [int(tid) for tid in generation_token_ids]
+
        if generation_token_ids and generation_log_probs:
            choice_dict["logprobs"] = {
                "content": [
@@ -116,6 +132,7 @@ class VerifiersAgentConfig(BaseResponsesAPIAgentConfig):

    max_tokens: int = Field(default=512, description="Max tokens for generation")
    temperature: float = Field(default=1.0, description="Sampling temperature")
+    top_p: float = Field(default=1.0, description="Top-p sampling")


class VerifiersAgentRunRequest(BaseRunRequest):
@@ -129,7 +146,7 @@ class VerifiersAgentRunRequest(BaseRunRequest):
    answer: str = Field(default="", description="Expected answer")
    task: str = Field(default="default", description="Task type")
    example_id: int | str = Field(default=0, description="Example ID")
-    info: dict = Field(default_factory=dict, description="Extra info for scoring (e.g., ifeval constraints)")
+    info: dict = Field(default_factory=dict, description="Extra info for scoring")


_ENVS_CACHE: dict[str, vf.Environment] = {}
@@ -162,7 +179,6 @@ async def _ensure_env_loaded(self, vf_env_id: str) -> tuple[vf.Environment, str,
            ds = getattr(vf_env, attr, None)
            if ds is not None:
                dataset = ds
-                logger.info(f"Found dataset in vf_env.{attr}")
                break
        if dataset is None:
            raise ValueError(f"Environment {vf_env_id} does not have a dataset")
@@ -203,7 +219,6 @@ def _get_openai_client(self) -> VLLMOpenAIClient:
            model_server_url = model_server_url.rstrip("/") + "/v1"

            _OPENAI_CLIENT_CACHE[cache_key] = VLLMOpenAIClient(base_url=model_server_url)
-            logger.info(f"Created VLLMOpenAIClient pointing to: {model_server_url}")

        return _OPENAI_CLIENT_CACHE[cache_key]

@@ -220,97 +235,97 @@ def _convert_trajectory_to_output(self, state: dict) -> list:
        trajectory = state.get("trajectory", [])

        for step in trajectory:
-            step_output = []
-
            for msg in step.get("prompt", []):
                if isinstance(msg, dict):
                    role = msg.get("role", "user")
                    content = msg.get("content", "")
-                    step_output.append(NeMoGymEasyInputMessage(role=role, content=content))
+                    output.append(NeMoGymEasyInputMessage(role=role, content=content).model_dump())

            tokens = step.get("tokens")
            for msg in step.get("completion", []):
                if isinstance(msg, dict):
                    content = msg.get("content", "")
                    if tokens:
-                        step_output.append(NeMoGymResponseOutputMessageForTraining(
+                        output.append(NeMoGymResponseOutputMessageForTraining(
                            id=f"msg_{id(msg)}",
                            content=[NeMoGymResponseOutputText(text=content, annotations=[])],
                            prompt_token_ids=tokens.get("prompt_ids", []),
                            generation_token_ids=tokens.get("completion_ids", []),
                            generation_log_probs=tokens.get("completion_logprobs", []),
-                        ))
+                        ).model_dump())
                    else:
-                        step_output.append(NeMoGymResponseOutputMessage(
+                        output.append(NeMoGymResponseOutputMessage(
                            id=f"msg_{id(msg)}",
                            content=[NeMoGymResponseOutputText(text=content, annotations=[])],
-                        ))
-
-            output.append(step_output)
+                        ).model_dump())

        return output

    async def responses(self, req: VerifiersAgentRunRequest) -> VerifiersNeMoGymResponse:
-        vf_env_id = req.vf_env_id or self.config.vf_env_id
-        vf_env, env_id, _ = await self._ensure_env_loaded(vf_env_id)
-
-        task_idx = req.task_idx
-
-        prompt_messages = []
-        for item in req.responses_create_params.input or []:
-            if hasattr(item, 'role') and hasattr(item, 'content'):
-                prompt_messages.append({"role": item.role, "content": item.content})
-            elif isinstance(item, dict):
-                prompt_messages.append({"role": item.get("role", "user"), "content": item.get("content", "")})
-
-        rollout_input = vf.RolloutInput(
-            prompt=prompt_messages,
-            answer=req.answer,
-            task=req.task,
-            info=req.info,
-            example_id=req.example_id,
-        )
-
-        client = self._get_openai_client()
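+        # The full rollout is wrapped in try/except so any failure is logged with a traceback before re-raising.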
+        try:
+            vf_env_id = req.vf_env_id or self.config.vf_env_id
+            vf_env, env_id, _ = await self._ensure_env_loaded(vf_env_id)
+
+            task_idx = req.task_idx
+
+            prompt_messages = []
+            for item in req.responses_create_params.input or []:
+                if hasattr(item, 'role') and hasattr(item, 'content'):
+                    prompt_messages.append({"role": item.role, "content": item.content})
+                elif isinstance(item, dict):
+                    prompt_messages.append({"role": item.get("role", "user"), "content": item.get("content", "")})
+
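+            # Package the prompt, expected answer, task, and extra info into a verifiers RolloutInput.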
+            rollout_input = vf.RolloutInput(
+                prompt=prompt_messages,
+                answer=req.answer,
+                task=req.task,
+                info=req.info,
+                example_id=req.example_id,
+            )

-        gen_sem = await maybe_semaphore(self.config.max_concurrent_generation)
-        score_sem = await maybe_semaphore(self.config.max_concurrent_scoring)
+            client = self._get_openai_client()

-        sampling_args = {
-            "max_tokens": self.config.max_tokens,
-            "temperature": self.config.temperature,
-        }
+            gen_sem = await maybe_semaphore(self.config.max_concurrent_generation)
+            score_sem = await maybe_semaphore(self.config.max_concurrent_scoring)

-        states = await vf_env.run_group(
-            group_inputs=[rollout_input],
-            client=client,
-            model=self.config.model_name,
-            gen_sampling_args=sampling_args,
-            gen_sem=gen_sem,
-            score_sem=score_sem,
-        )
+            sampling_args = {
+                "max_tokens": self.config.max_tokens,
+                "temperature": self.config.temperature,
+                "top_p": self.config.top_p,
+            }
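+            # Run a single-rollout group through the environment; generation and scoring are gated by the semaphores.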
+            states = await vf_env.run_group(
+                group_inputs=[rollout_input],
+                client=client,
+                model=self.config.model_name,
+                gen_sampling_args=sampling_args,
+                gen_sem=gen_sem,
+                score_sem=score_sem,
+            )

-        state = states[0]
-        reward = state.get("reward", 0.0) or 0.0
-        metrics = state.get("metrics", {}) or {}
-
-        output = self._convert_trajectory_to_output(state)
-
-        return VerifiersNeMoGymResponse(
-            id=f"verifiers-{env_id}-{task_idx}",
-            created_at=0,
-            model=self.config.model_name,
-            object="response",
-            output=output,
-            env_id=env_id,
-            group_id=str(task_idx),
-            reward=reward,
-            metrics=metrics,
-        )
+            state = states[0]
+            reward = state.get("reward", 0.0) or 0.0
+            metrics = state.get("metrics", {}) or {}
+
+            output = self._convert_trajectory_to_output(state)
+
+            return VerifiersNeMoGymResponse(
+                id=f"verifiers-{env_id}-{task_idx}",
+                created_at=0,
+                model=self.config.model_name,
+                object="response",
+                output=output,
+                env_id=env_id,
+                group_id=str(task_idx),
+                reward=reward,
+                metrics=metrics,
+            )
+        except Exception as e:
+            logger.error(f"[verifiers_agent] EXCEPTION in responses(): {type(e).__name__}: {e}")
+            logger.error(f"[verifiers_agent] Traceback:\n{traceback.format_exc()}")
+            raise

    async def run(self, body: VerifiersAgentRunRequest) -> VerifiersAgentVerifyResponse:
        response = await self.responses(body)
-
        return VerifiersAgentVerifyResponse(
            responses_create_params=body.responses_create_params,
            response=response,