Commit b559bb7

training

Signed-off-by: cmunley1 <cmunley@nvidia.com>
1 parent 28d273c commit b559bb7

File tree: 4 files changed (+92, -73 lines)
resources_servers/verifiers/requirements.txt

Lines changed: 2 additions & 0 deletions

@@ -7,3 +7,5 @@ wordle
 aime2025
 ifeval
 alphabet-sort
+i3-math
+acereason-math

resources_servers/verifiers/schemas.py

Lines changed: 1 addition & 1 deletion

@@ -91,7 +91,7 @@ class VerifiersNeMoGymResponse(NeMoGymResponse):
     env_id: str
     group_id: str
     contains_transitions: Literal[True] = True
-    output: list[list[NeMoGymResponseOutputItem]]
+    output: list[dict[str, Any]]
     reward: float
     metrics: dict[str, Any] = Field(default_factory=dict)
     parallel_tool_calls: bool = False
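
For context, the loosened output type matches the agent-side change later in this diff: trajectory messages are now serialized with pydantic's model_dump() before being attached to the response, so output carries plain dicts instead of nested lists of typed items. A minimal sketch, using a hypothetical Message model in place of the real NeMo Gym message classes:

    # Hypothetical model; the real classes are NeMoGymEasyInputMessage and friends.
    from typing import Any

    from pydantic import BaseModel

    class Message(BaseModel):
        role: str
        content: str

    output: list[dict[str, Any]] = [Message(role="user", content="hi").model_dump()]
    # -> [{'role': 'user', 'content': 'hi'}]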

responses_api_agents/verifiers_agent/app.py

Lines changed: 87 additions & 72 deletions

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import sys
+import traceback
 from typing import Any
 
 import aiohttp
@@ -40,7 +42,7 @@
 
 
 class _VLLMChatCompletions(AsyncCompletions):
-    """Wraps vllm_model and injects token IDs as attributes for verifiers."""
+    """Adapt the vllm_model response format to what verifiers expects."""
     def __init__(self, base_url: str) -> None:
         self._base_url = base_url.rstrip("/")
 
@@ -53,19 +55,33 @@ async def create(self, *args: Any, **kwargs: Any) -> ChatCompletion:
             if key in kwargs and kwargs[key] is not None:
                 request_body[key] = kwargs[key]
 
-        async with aiohttp.ClientSession() as session:
-            async with session.post(f"{self._base_url}/chat/completions", json=request_body) as resp:
-                resp.raise_for_status()
-                response_dict = await resp.json()
+        url = f"{self._base_url}/chat/completions"
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(url, json=request_body) as resp:
+                    if resp.status != 200:
+                        error_text = await resp.text()
+                        logger.error(f"[verifiers_agent] Request to {url} failed with status {resp.status}: {error_text[:500]}")
+                    resp.raise_for_status()
+                    response_dict = await resp.json()
+        except Exception as e:
+            logger.error(f"[verifiers_agent] Exception calling {url}: {type(e).__name__}: {e}")
+            raise
 
-        # Extract token IDs from vllm_model
         choice_dict = response_dict["choices"][0]
         message_dict = choice_dict.get("message", {})
+
+
         prompt_token_ids = message_dict.pop("prompt_token_ids", [])
         generation_token_ids = message_dict.pop("generation_token_ids", [])
         generation_log_probs = message_dict.pop("generation_log_probs", [])
 
-        # Reconstruct logprobs.content for verifiers
+        if not generation_token_ids:
+            logger.warning(f"[verifiers_agent] No generation_token_ids in response! Full message keys were: {list(choice_dict.get('message', {}).keys())}")
+
+        if generation_token_ids and isinstance(generation_token_ids[0], str):
+            generation_token_ids = [int(tid) for tid in generation_token_ids]
+
         if generation_token_ids and generation_log_probs:
             choice_dict["logprobs"] = {
                 "content": [
@@ -116,6 +132,7 @@ class VerifiersAgentConfig(BaseResponsesAPIAgentConfig):
 
     max_tokens: int = Field(default=512, description="Max tokens for generation")
     temperature: float = Field(default=1.0, description="Sampling temperature")
+    top_p: float = Field(default=1.0, description="Top-p sampling")
 
 
 class VerifiersAgentRunRequest(BaseRunRequest):
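
The new top_p field feeds the sampling_args dict passed to vf_env.run_group later in this diff; the default of 1.0 preserves prior behavior, since top_p=1.0 disables nucleus filtering. A usage sketch with illustrative values:

    # Illustrative values; keys match the sampling_args built in responses() below.
    sampling_args = {
        "max_tokens": 512,
        "temperature": 1.0,
        "top_p": 0.95,  # keep the smallest token set whose cumulative mass >= 0.95
    }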
@@ -129,7 +146,7 @@ class VerifiersAgentRunRequest(BaseRunRequest):
     answer: str = Field(default="", description="Expected answer")
     task: str = Field(default="default", description="Task type")
     example_id: int | str = Field(default=0, description="Example ID")
-    info: dict = Field(default_factory=dict, description="Extra info for scoring (e.g., ifeval constraints)")
+    info: dict = Field(default_factory=dict, description="Extra info for scoring")
 
 
 _ENVS_CACHE: dict[str, vf.Environment] = {}
@@ -162,7 +179,6 @@ async def _ensure_env_loaded(self, vf_env_id: str) -> tuple[vf.Environment, str,
             ds = getattr(vf_env, attr, None)
             if ds is not None:
                 dataset = ds
-                logger.info(f"Found dataset in vf_env.{attr}")
                 break
         if dataset is None:
             raise ValueError(f"Environment {vf_env_id} does not have a dataset")
@@ -203,7 +219,6 @@ def _get_openai_client(self) -> VLLMOpenAIClient:
         model_server_url = model_server_url.rstrip("/") + "/v1"
 
         _OPENAI_CLIENT_CACHE[cache_key] = VLLMOpenAIClient(base_url=model_server_url)
-        logger.info(f"Created VLLMOpenAIClient pointing to: {model_server_url}")
 
         return _OPENAI_CLIENT_CACHE[cache_key]
 
@@ -220,97 +235,97 @@ def _convert_trajectory_to_output(self, state: dict) -> list:
         trajectory = state.get("trajectory", [])
 
         for step in trajectory:
-            step_output = []
-
             for msg in step.get("prompt", []):
                 if isinstance(msg, dict):
                     role = msg.get("role", "user")
                     content = msg.get("content", "")
-                    step_output.append(NeMoGymEasyInputMessage(role=role, content=content))
+                    output.append(NeMoGymEasyInputMessage(role=role, content=content).model_dump())
 
             tokens = step.get("tokens")
             for msg in step.get("completion", []):
                 if isinstance(msg, dict):
                     content = msg.get("content", "")
                     if tokens:
-                        step_output.append(NeMoGymResponseOutputMessageForTraining(
+                        output.append(NeMoGymResponseOutputMessageForTraining(
                             id=f"msg_{id(msg)}",
                             content=[NeMoGymResponseOutputText(text=content, annotations=[])],
                             prompt_token_ids=tokens.get("prompt_ids", []),
                             generation_token_ids=tokens.get("completion_ids", []),
                             generation_log_probs=tokens.get("completion_logprobs", []),
-                        ))
+                        ).model_dump())
                     else:
-                        step_output.append(NeMoGymResponseOutputMessage(
+                        output.append(NeMoGymResponseOutputMessage(
                             id=f"msg_{id(msg)}",
                             content=[NeMoGymResponseOutputText(text=content, annotations=[])],
-                        ))
-
-            output.append(step_output)
+                        ).model_dump())
 
         return output
 
     async def responses(self, req: VerifiersAgentRunRequest) -> VerifiersNeMoGymResponse:
-        vf_env_id = req.vf_env_id or self.config.vf_env_id
-        vf_env, env_id, _ = await self._ensure_env_loaded(vf_env_id)
-
-        task_idx = req.task_idx
-
-        prompt_messages = []
-        for item in req.responses_create_params.input or []:
-            if hasattr(item, 'role') and hasattr(item, 'content'):
-                prompt_messages.append({"role": item.role, "content": item.content})
-            elif isinstance(item, dict):
-                prompt_messages.append({"role": item.get("role", "user"), "content": item.get("content", "")})
-
-        rollout_input = vf.RolloutInput(
-            prompt=prompt_messages,
-            answer=req.answer,
-            task=req.task,
-            info=req.info,
-            example_id=req.example_id,
-        )
-
-        client = self._get_openai_client()
+        try:
+            vf_env_id = req.vf_env_id or self.config.vf_env_id
+            vf_env, env_id, _ = await self._ensure_env_loaded(vf_env_id)
+
+            task_idx = req.task_idx
+
+            prompt_messages = []
+            for item in req.responses_create_params.input or []:
+                if hasattr(item, 'role') and hasattr(item, 'content'):
+                    prompt_messages.append({"role": item.role, "content": item.content})
+                elif isinstance(item, dict):
+                    prompt_messages.append({"role": item.get("role", "user"), "content": item.get("content", "")})
+
+            rollout_input = vf.RolloutInput(
+                prompt=prompt_messages,
+                answer=req.answer,
+                task=req.task,
+                info=req.info,
+                example_id=req.example_id,
+            )
 
-        gen_sem = await maybe_semaphore(self.config.max_concurrent_generation)
-        score_sem = await maybe_semaphore(self.config.max_concurrent_scoring)
+            client = self._get_openai_client()
 
-        sampling_args = {
-            "max_tokens": self.config.max_tokens,
-            "temperature": self.config.temperature,
-        }
+            gen_sem = await maybe_semaphore(self.config.max_concurrent_generation)
+            score_sem = await maybe_semaphore(self.config.max_concurrent_scoring)
 
-        states = await vf_env.run_group(
-            group_inputs=[rollout_input],
-            client=client,
-            model=self.config.model_name,
-            gen_sampling_args=sampling_args,
-            gen_sem=gen_sem,
-            score_sem=score_sem,
-        )
+            sampling_args = {
+                "max_tokens": self.config.max_tokens,
+                "temperature": self.config.temperature,
+                "top_p": self.config.top_p,
+            }
+            states = await vf_env.run_group(
+                group_inputs=[rollout_input],
+                client=client,
+                model=self.config.model_name,
+                gen_sampling_args=sampling_args,
+                gen_sem=gen_sem,
+                score_sem=score_sem,
+            )
 
-        state = states[0]
-        reward = state.get("reward", 0.0) or 0.0
-        metrics = state.get("metrics", {}) or {}
-
-        output = self._convert_trajectory_to_output(state)
-
-        return VerifiersNeMoGymResponse(
-            id=f"verifiers-{env_id}-{task_idx}",
-            created_at=0,
-            model=self.config.model_name,
-            object="response",
-            output=output,
-            env_id=env_id,
-            group_id=str(task_idx),
-            reward=reward,
-            metrics=metrics,
-        )
+            state = states[0]
+            reward = state.get("reward", 0.0) or 0.0
+            metrics = state.get("metrics", {}) or {}
+
+            output = self._convert_trajectory_to_output(state)
+
+            return VerifiersNeMoGymResponse(
+                id=f"verifiers-{env_id}-{task_idx}",
+                created_at=0,
+                model=self.config.model_name,
+                object="response",
+                output=output,
+                env_id=env_id,
+                group_id=str(task_idx),
+                reward=reward,
+                metrics=metrics,
+            )
+        except Exception as e:
+            logger.error(f"[verifiers_agent] EXCEPTION in responses(): {type(e).__name__}: {e}")
+            logger.error(f"[verifiers_agent] Traceback:\n{traceback.format_exc()}")
+            raise
 
     async def run(self, body: VerifiersAgentRunRequest) -> VerifiersAgentVerifyResponse:
         response = await self.responses(body)
-
         return VerifiersAgentVerifyResponse(
             responses_create_params=body.responses_create_params,
             response=response,
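
A minimal sketch (hypothetical data) of the flattening _convert_trajectory_to_output now performs: per-step message lists are no longer wrapped in an inner list, so every step's messages land in one flat list of plain dicts, matching the new list[dict[str, Any]] schema:

    # Hypothetical trajectory; dict(msg) stands in for Model(...).model_dump().
    trajectory = [
        {"prompt": [{"role": "user", "content": "2+2?"}],
         "completion": [{"role": "assistant", "content": "4"}]},
    ]

    output: list[dict] = []
    for step in trajectory:
        for msg in step["prompt"] + step["completion"]:
            output.append(dict(msg))

    # output == [{'role': 'user', 'content': '2+2?'},
    #            {'role': 'assistant', 'content': '4'}]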

responses_api_agents/verifiers_agent/requirements.txt

Lines changed: 2 additions & 0 deletions

@@ -7,3 +7,5 @@ wordle
 aime2025
 ifeval
 alphabet-sort
+i3-math
+acereason-math
