Skip to content

Commit d68dabc

Browse files
Terminus (judge only) Slicing Environment (#594)
Refactoring the equivalency llm judge resource server into another judge-based resource server. Main changes include removing regex logic and cleaning up related configs to that. Train data for this environment is still TBD, but a working version: Data source: Sliced terminus prompts from different sources train_jsonl_fpath: `/lustre/fsw/portfolios/llmservice/users/kbhardwaj/dev/my-envs/terminus-sliced/char/nano3-ga-traindata-char-tokenlen-32768.jsonl` validation_jsonl_fpath: `/lustre/fsw/portfolios/llmservice/users/kbhardwaj/dev/my-envs/terminus-sliced/char/nano3-ga-valdata-char-tokenlen-16384.jsonl` example train config: `/lustre/fsw/portfolios/llmservice/users/kbhardwaj/dev/nemo-rl-internal-yifu/training_configs/grpo_nanov3-nickel-capybara-4-nodes-judge-roff-512-49k-seq-reasoning-off-char-data-64x16-temp1-iter-1600.yaml` Example of env validation: base model: early sft checkpoint of nano v3 (`nano-v3-sft-64gbs-nickel-capybara-5e-5-constant-wd-0-load-bal-1e-4-lcx3-pretool-base-temp1-iter-0013600-hf`) Step 50 -> 21.25% on Terminal Bench Core https://wandb.ai/nvidia/terminus-sliced/runs/rs7c40hi Next steps: Will expand this PR with configurable verification options including string matching, string similarity and openapi-based output schema validation. --------- Signed-off-by: Khushi Bhardwaj <kbhardwaj@nvidia.com>
1 parent a5cd5eb commit d68dabc

File tree

10 files changed

+778
-0
lines changed

10 files changed

+778
-0
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Description
2+
3+
Data links: ?
4+
5+
# Licensing information
6+
Code: ?
7+
Data: ?
8+
9+
# Dependencies
10+
- nemo_gym: Apache 2.0
11+
?
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import asyncio
16+
from contextlib import nullcontext
17+
from typing import Any, Optional
18+
19+
from fastapi import FastAPI
20+
from pydantic import BaseModel, ConfigDict
21+
22+
from nemo_gym.base_resources_server import (
23+
BaseResourcesServerConfig,
24+
BaseRunRequest,
25+
BaseVerifyRequest,
26+
BaseVerifyResponse,
27+
SimpleResourcesServer,
28+
)
29+
from nemo_gym.config_types import ModelServerRef
30+
from nemo_gym.openai_utils import (
31+
NeMoGymEasyInputMessage,
32+
NeMoGymResponse,
33+
NeMoGymResponseCreateParamsNonStreaming,
34+
)
35+
36+
37+
class TerminusJudgeResourcesServerConfig(BaseResourcesServerConfig):
    """Configuration for the Terminus LLM-judge resources server."""

    # Server name used for registration/routing.
    name: str = "terminus_judge"
    # Reference to the model server that acts as the judge.
    judge_model_server: ModelServerRef
    # Base create-params POSTed to the judge's /v1/responses endpoint;
    # the prompt messages are filled in per evaluation.
    judge_responses_create_params: NeMoGymResponseCreateParamsNonStreaming
    # Max concurrent judge calls; None disables the semaphore entirely.
    judge_endpoint_max_concurrency: Optional[int] = 64
    # Optional system message prepended before the rendered user prompt.
    judge_system_message: Optional[str] = None
    # Template file rendered with {expected_answer} and {generated_answer}.
    judge_prompt_template_fpath: str = "prompt_templates/terminus_judge.txt"
    # Verdict labels searched for in the judge's output text.
    judge_equal_label: str = "[[A=B]]"
    judge_not_equal_label: str = "[[A!=B]]"
    # When True and the first pass says "equal", run a second pass with the
    # two answers swapped (positional-bias check).
    check_twice_swap: bool = False
    # Reward assigned when the swapped second pass disagrees; can be negative.
    reward_if_swap_fails: float = 0.0
48+
49+
50+
class TerminusJudgeRunRequest(BaseRunRequest):
    """Run/verify request payload."""

    # Extra dataset columns are accepted and echoed back by verify().
    model_config = ConfigDict(extra="allow")

    # Optional unique identifier for the sample.
    uuid: Optional[str | int] = None
    # Ground-truth answer; may alternatively arrive in metadata["expected_answer"].
    expected_answer: Optional[str] = None
    # Optional per-sample options; semantics defined by the dataset — TODO confirm.
    options: Optional[list[dict[str, str]]] = None
    # Free-form metadata; "expected_answer" key is used as a fallback.
    metadata: Optional[dict[str, Any]] = None
59+
60+
61+
class TerminusJudgeVerifyRequest(TerminusJudgeRunRequest, BaseVerifyRequest):
    """Verify request: run-request fields combined with the base verify payload."""

    pass
63+
64+
65+
class JudgeEvaluation(BaseModel):
    """Record of one judge call: the exact request, raw response, and verdict."""

    # Create-params actually sent (including the rendered prompt messages).
    responses_create_params: NeMoGymResponseCreateParamsNonStreaming
    # Raw judge model response.
    response: NeMoGymResponse
    # Verdict label found in the judge output, or None when none was found.
    verdict_label: Optional[str] = None
69+
70+
71+
class TerminusJudgeVerifyResponse(BaseVerifyResponse):
    """Verify response enriched with the resolved expected answer and judge records."""

    # The expected answer the judge compared against.
    expected_answer: str
    # One entry per judge pass (two entries when the swap check ran).
    judge_evaluations: list[JudgeEvaluation]
74+
75+
76+
def _extract_last_assistant_text(body: BaseVerifyRequest) -> str:
77+
"""Extract the last assistant message text from the response.
78+
79+
Returns an empty string when no assistant text is available.
80+
"""
81+
for o in reversed(body.response.output):
82+
if getattr(o, "type", None) == "message" and getattr(o, "role", None) == "assistant":
83+
content = getattr(o, "content", None)
84+
if isinstance(content, list):
85+
texts: list[str] = []
86+
for c in content:
87+
t = getattr(c, "text", None)
88+
if isinstance(t, str):
89+
texts.append(t)
90+
return "\n".join(texts).strip()
91+
elif isinstance(content, str):
92+
return content.strip()
93+
break
94+
return ""
95+
96+
97+
def _extract_expected_answer(req: TerminusJudgeRunRequest) -> Optional[str]:
98+
"""Extract expected answer from request."""
99+
if req.expected_answer:
100+
return str(req.expected_answer)
101+
md = req.metadata or {}
102+
exp = md.get("expected_answer")
103+
return str(exp) if exp is not None else None
104+
105+
106+
class TerminusJudgeResourcesServer(SimpleResourcesServer):
    """Resources server that scores generated answers with an LLM judge.

    ``verify`` extracts the expected answer and the last assistant message
    from the request, asks the judge model whether they are equivalent, and
    (optionally) re-asks with the two answers swapped before assigning a
    reward.
    """

    config: TerminusJudgeResourcesServerConfig

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Optional cap on in-flight judge calls; None means unlimited.
        if self.config.judge_endpoint_max_concurrency is not None:
            self._judge_endpoint_max_concurrency = asyncio.Semaphore(value=self.config.judge_endpoint_max_concurrency)
        else:
            self._judge_endpoint_max_concurrency = None

        # Load the judge prompt template once at startup; it is .format()-ed
        # with expected_answer/generated_answer for every evaluation.
        with open(self.config.judge_prompt_template_fpath, "r") as f:
            self._judge_prompt_template = f.read().strip()

    def setup_webserver(self) -> FastAPI:
        # No routes beyond the base server's; kept as an explicit extension point.
        app = super().setup_webserver()

        return app

    async def verify(self, body: TerminusJudgeVerifyRequest) -> TerminusJudgeVerifyResponse:
        """Judge the rollout in ``body`` and return it with a reward attached.

        Raises:
            ValueError: when no expected answer is provided or no assistant
                text can be extracted from the response.
        """
        expected = _extract_expected_answer(body)
        if not expected:
            raise ValueError("Expected answer is required but was not provided")

        generated = _extract_last_assistant_text(body)
        if not generated:
            raise ValueError("No assistant response found/extracted to verify")
        # Run first judge evaluation
        first_equal, first_eval = await self._generate_judge_evaluation(
            expected_answer=expected, generated_answer=generated
        )

        evaluations = [first_eval]

        # Positional-bias check: only an "equal" verdict that also survives a
        # second pass with the answers swapped earns the full reward.
        if first_equal and self.config.check_twice_swap:
            second_equal, second_eval = await self._generate_judge_evaluation(
                expected_answer=generated, generated_answer=expected
            )
            evaluations.append(second_eval)
            reward = 1.0 if second_equal else self.config.reward_if_swap_fails
        else:
            reward = 1.0 if first_equal else 0.0

        # Echo the request payload back, replacing the (possibly absent)
        # request-level expected_answer with the resolved one.
        payload = body.model_dump()
        payload.pop("expected_answer", None)

        return TerminusJudgeVerifyResponse(
            **payload,
            reward=reward,
            expected_answer=expected,
            judge_evaluations=evaluations,
        )

    async def _generate_judge_evaluation(
        self, *, expected_answer: str, generated_answer: str
    ) -> tuple[bool, JudgeEvaluation]:
        """Run a single judge evaluation.

        Returns:
            ``(is_equal, evaluation_record)`` where ``is_equal`` is True only
            when the judge output contains the equal label before any
            not-equal label.

        Raises:
            RuntimeError: when the judge model server call times out.
        """
        cfg = self.config
        equal_label = cfg.judge_equal_label
        not_equal_label = cfg.judge_not_equal_label

        # Deep copy so the per-call prompt never mutates the shared config.
        responses_create_params = cfg.judge_responses_create_params.model_copy(deep=True)

        user_prompt = self._judge_prompt_template.format(
            expected_answer=expected_answer, generated_answer=generated_answer
        )

        msgs: list[NeMoGymEasyInputMessage] = []
        if cfg.judge_system_message:
            msgs.append(NeMoGymEasyInputMessage(role="system", content=cfg.judge_system_message))
        msgs.append(NeMoGymEasyInputMessage(role="user", content=user_prompt))
        responses_create_params.input = msgs

        # Explicit None check: the original `semaphore or nullcontext()` relied
        # on an asyncio.Semaphore always being truthy, which is fragile.
        semaphore = self._judge_endpoint_max_concurrency
        ctx = nullcontext() if semaphore is None else semaphore
        async with ctx:
            try:
                response = await self.server_client.post(
                    server_name=cfg.judge_model_server.name,
                    url_path="/v1/responses",
                    json=responses_create_params,
                )

                judge_response = NeMoGymResponse.model_validate(await response.json())

            except asyncio.TimeoutError:
                print(
                    "DEBUG: TerminusJudgeResourcesServer: Judge model server timeout",
                    flush=True,
                )
                raise RuntimeError("Judge model server timeout")
            except Exception as e:
                print(
                    f"DEBUG: TerminusJudgeResourcesServer: judge model server HTTP POST error: {type(e).__name__} {e}",
                    flush=True,
                )
                # Bare raise re-raises with the original traceback intact;
                # `raise e` would append a redundant frame.
                raise

        eval_record = JudgeEvaluation(
            responses_create_params=responses_create_params,
            response=judge_response,
            verdict_label=None,
        )

        verdict_label = None
        is_equal = False

        # Pull the text of the last output if it is a message; tolerate any
        # unexpected response shape by treating it as empty text.
        try:
            last_output = judge_response.output[-1]
            if getattr(last_output, "type", None) != "message":
                text = ""
            else:
                last_content = last_output.content[-1]
                text = getattr(last_content, "text", "")
        except Exception:
            text = ""

        # The earliest-occurring label wins when both labels appear.
        if text:
            eq_pos = text.find(equal_label)
            neq_pos = text.find(not_equal_label)

            if eq_pos >= 0 and (neq_pos < 0 or eq_pos < neq_pos):
                verdict_label = equal_label
                is_equal = True
            elif neq_pos >= 0:
                verdict_label = not_equal_label

        eval_record.verdict_label = verdict_label
        return is_equal, eval_record
237+
238+
239+
if __name__ == "__main__":
    # Launch the FastAPI webserver for this resources server.
    TerminusJudgeResourcesServer.run_webserver()
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
terminus_judge_resources_server:
  resources_servers:
    terminus_judge:
      entrypoint: app.py
      # Model server used as the LLM judge (here: the policy model itself).
      judge_model_server:
        type: responses_api_models
        name: policy_model
      # Base create-params for the judge call; input is filled per evaluation.
      judge_responses_create_params:
        input: []
      # NOTE(review): the server code defaults to prompt_templates/terminus_judge.txt;
      # confirm terminus_prompt.txt is the intended template here.
      judge_prompt_template_fpath: prompt_templates/terminus_prompt.txt
      # null disables the judge-call concurrency cap.
      judge_endpoint_max_concurrency: null
      judge_system_message: null
      # Verdict labels the judge is expected to emit in its output.
      judge_equal_label: "[[A=B]]"
      judge_not_equal_label: "[[A!=B]]"

      # Swap check: Run second judge pass with swapped expected/generated to detect positional bias
      check_twice_swap: true
      # Reward when the second (swap) pass fails; default 0.0, can be -1.0
      reward_if_swap_fails: 0.0

  domain: agent
  verified: false
  description: single-step terminal based task
  value: Improve on terminal-style tasks

terminus_judge_simple_agent:
  responses_api_agents:
    simple_agent:
      entrypoint: app.py
      resources_server:
        type: resources_servers
        name: terminus_judge_resources_server
      model_server:
        type: responses_api_models
        name: policy_model
  datasets:
    - name: train
      type: train
      jsonl_fpath: resources_servers/terminus_judge/data/train.jsonl
      num_repeats: 1
      gitlab_identifier:
        dataset_name: terminus_judge
        version: 0.0.1
        artifact_fpath: train.jsonl
      license: Apache 2.0
    - name: validation
      type: validation
      jsonl_fpath: resources_servers/terminus_judge/data/validation.jsonl
      num_repeats: 1
      gitlab_identifier:
        dataset_name: terminus_judge
        version: 0.0.1
        artifact_fpath: validation.jsonl
      license: Apache 2.0
    # Example split has no gitlab artifact; it ships with the repo.
    - name: example
      type: example
      jsonl_fpath: resources_servers/terminus_judge/data/example.jsonl
      num_repeats: 1
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Generated/prepared dataset artifacts are kept out of version control.
*train.jsonl
*validation.jsonl
*train_prepare.jsonl
*validation_prepare.jsonl
*example_prepare.jsonl

0 commit comments

Comments
 (0)