Skip to content

Commit e51ad14

Browse files
authored
arc-agi resource server (#105)
Signed-off-by: cmunley1 <cmunley@nvidia.com> Signed-off-by: Christian Munley <cmunley@nvidia.com>
1 parent 810b728 commit e51ad14

File tree

12 files changed

+490
-0
lines changed

12 files changed

+490
-0
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# ARC-AGI resources server
2+
3+
Launch a local vLLM server:
4+
```bash
5+
vllm serve Qwen/Qwen3-30B-A3B \
6+
--dtype auto \
7+
--tensor-parallel-size 8 \
8+
--gpu-memory-utilization 0.9 \
9+
--enable-auto-tool-choice --tool-call-parser hermes \
10+
--host 0.0.0.0 \
11+
--port 10240
12+
```
13+
14+
Start ARC-AGI environment:
15+
```bash
16+
ng_run "+config_paths=[resources_servers/arc_agi/configs/arc_agi.yaml,responses_api_models/vllm_model/configs/vllm_model.yaml]"
17+
```
18+
19+
Or start the ARC-AGI-2 environment:
20+
```bash
21+
ng_run "+config_paths=[resources_servers/arc_agi/configs/arc_agi_2.yaml,responses_api_models/vllm_model/configs/vllm_model.yaml]"
22+
```
23+
24+
25+
Collect rollouts:
26+
27+
ARC-AGI-1 example rollouts:
28+
```bash
29+
ng_collect_rollouts +agent_name=arc_agi_simple_agent +input_jsonl_fpath=resources_servers/arc_agi/data/example_1.jsonl +output_jsonl_fpath=resources_servers/arc_agi/data/example_1_rollouts.jsonl +limit=5 +num_repeats=null +num_samples_in_parallel=null
30+
```
31+
32+
ARC-AGI-2 example rollouts:
33+
```bash
34+
ng_collect_rollouts +agent_name=arc_agi_2_simple_agent +input_jsonl_fpath=resources_servers/arc_agi/data/example_2.jsonl +output_jsonl_fpath=resources_servers/arc_agi/data/example_2_rollouts.jsonl +limit=5 +num_repeats=null +num_samples_in_parallel=null
35+
```
36+
37+
ARC-AGI-1 train set rollouts (400 problems):
38+
```bash
39+
ng_collect_rollouts +agent_name=arc_agi_simple_agent +input_jsonl_fpath=resources_servers/arc_agi/data/arc_agi_1_training.jsonl +output_jsonl_fpath=resources_servers/arc_agi/data/arc_agi_1_training_rollouts.jsonl +limit=null +num_repeats=null +num_samples_in_parallel=null
40+
```
41+
42+
ARC-AGI-1 eval set rollouts (400 problems):
43+
```bash
44+
ng_collect_rollouts +agent_name=arc_agi_simple_agent +input_jsonl_fpath=resources_servers/arc_agi/data/arc_agi_1_evaluation.jsonl +output_jsonl_fpath=resources_servers/arc_agi/data/arc_agi_1_evaluation_rollouts.jsonl +limit=null +num_repeats=null +num_samples_in_parallel=null
45+
```
46+
47+
ARC-AGI-2 train set rollouts (1000 problems):
48+
```bash
49+
ng_collect_rollouts +agent_name=arc_agi_2_simple_agent +input_jsonl_fpath=resources_servers/arc_agi/data/arc_agi_2_training.jsonl +output_jsonl_fpath=resources_servers/arc_agi/data/arc_agi_2_training_rollouts.jsonl +limit=null +num_repeats=null +num_samples_in_parallel=null
50+
```
51+
52+
ARC-AGI-2 eval set rollouts (120 problems):
53+
```bash
54+
ng_collect_rollouts +agent_name=arc_agi_2_simple_agent +input_jsonl_fpath=resources_servers/arc_agi/data/arc_agi_2_evaluation.jsonl +output_jsonl_fpath=resources_servers/arc_agi/data/arc_agi_2_evaluation_rollouts.jsonl +limit=null +num_repeats=null +num_samples_in_parallel=null
55+
```
56+
57+
Run tests:
58+
```bash
59+
ng_test +entrypoint=resources_servers/arc_agi
60+
```

resources_servers/arc_agi/app.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
import re
17+
from typing import List, Optional
18+
19+
from fastapi import FastAPI
20+
21+
from nemo_gym.base_resources_server import (
22+
BaseResourcesServerConfig,
23+
BaseRunRequest,
24+
BaseVerifyRequest,
25+
BaseVerifyResponse,
26+
SimpleResourcesServer,
27+
)
28+
29+
30+
class ARCAGIResourcesServerConfig(BaseResourcesServerConfig):
    # Configuration for the ARC-AGI resources server.
    # Adds no fields of its own; everything comes from BaseResourcesServerConfig.
    pass
32+
33+
34+
class ARCAGIRunRequest(BaseRunRequest):
    # Run request extended with the per-task ARC fields. All fields are optional.

    # Demonstration examples for the task, one dict per example.
    # NOTE(review): presumably each dict holds "input"/"output" grids - confirm against the dataset.
    train: List[dict] = []
    # Grid the model is asked to transform.
    test_input: List[List[int]] = []
    # Reference grid that verification compares the prediction against.
    expected_output: List[List[int]] = []
    # Optional task identifier.
    task_id: Optional[str] = None
39+
40+
41+
class ARCAGIVerifyRequest(ARCAGIRunRequest, BaseVerifyRequest):
    # Verification request: the ARC task fields from ARCAGIRunRequest combined
    # with whatever BaseVerifyRequest carries. Declares no fields of its own.
    pass
43+
44+
45+
class ARCAGIVerifyResponse(BaseVerifyResponse):
    # Verification result extended with the ARC-specific grading details.

    # Reference grid the prediction was compared against (required).
    expected_output: List[List[int]]
    # Grid parsed from the model's answer; None when no grid could be parsed.
    predicted_output: Optional[List[List[int]]] = None
    # True when a grid was successfully parsed from the model's answer.
    extraction_successful: bool = False
49+
50+
51+
def _extract_assistant_text(body: BaseVerifyRequest) -> str:
    """Return all assistant message text in the response as one stripped string.

    Only output items whose ``type`` is "message" and whose ``role`` is
    "assistant" are considered. String content is taken as-is; list content
    contributes the ``text`` attribute of each part when that attribute is a
    string. The collected pieces are joined with newlines.
    """
    collected = []
    for item in body.response.output:
        is_assistant_message = (
            getattr(item, "type", None) == "message" and getattr(item, "role", None) == "assistant"
        )
        if not is_assistant_message:
            continue

        payload = getattr(item, "content", None)
        if isinstance(payload, str):
            collected.append(payload)
        elif isinstance(payload, list):
            # Parts without a string ``text`` attribute are skipped.
            part_texts = (getattr(part, "text", None) for part in payload)
            collected.extend(piece for piece in part_texts if isinstance(piece, str))

    return "\n".join(collected).strip()
64+
65+
66+
def _parse_grid(text: str) -> Optional[List[List[int]]]:
    """Extract the first well-formed integer grid from the model's answer.

    Expected format: \\boxed{[[1,2,3],[4,5,6]]}. When the text contains no
    boxed grid, bare ``[[...]]`` literals anywhere in the text are considered
    instead.

    Args:
        text: Assistant text to search.

    Returns:
        The first candidate that decodes to a non-empty list of integer rows
        whose first row is non-empty, or None when no candidate qualifies.
    """
    boxed_pattern = r"\\boxed\{(\[\s*\[[\d\s,\[\]]+\]\s*\])\}"
    candidates = re.findall(boxed_pattern, text, re.DOTALL)

    if not candidates:
        candidates = re.findall(r"\[\s*\[[\d\s,\[\]]+\]\s*\]", text, re.DOTALL)

    decoder = json.JSONDecoder()
    for candidate in candidates:
        # Collapse whitespace runs to a single space rather than deleting them:
        # deleting would fuse space-separated cells ("1 2" -> "12") and make up
        # values the model never wrote. Collapsing still turns any exotic
        # whitespace the regex accepted into something JSON accepts.
        normalized = re.sub(r"\s+", " ", candidate)
        try:
            # The greedy pattern can capture several grids separated only by
            # whitespace as one candidate; raw_decode reads the first complete
            # JSON value and ignores the trailing remainder.
            grid, _ = decoder.raw_decode(normalized)
        except ValueError:
            # JSONDecodeError is a ValueError; the broader class also covers
            # the error raised for digit runs beyond Python's int parsing limit.
            continue

        if (
            isinstance(grid, list)
            and len(grid) > 0
            and all(isinstance(row, list) and all(isinstance(cell, int) for cell in row) for row in grid)
            and len(grid[0]) > 0
        ):
            return grid

    return None
90+
91+
92+
class ARCAGIResourcesServer(SimpleResourcesServer):
    # Resources server that grades ARC-AGI answers by exact grid match.

    config: ARCAGIResourcesServerConfig

    # Returns the app built by the parent class unchanged; nothing extra is registered here.
    def setup_webserver(self) -> FastAPI:
        return super().setup_webserver()

    # Parses a grid out of the assistant's text and rewards 1.0 only when it
    # equals the expected output exactly; a missing or different grid gets 0.0.
    async def verify(self, body: ARCAGIVerifyRequest) -> ARCAGIVerifyResponse:
        grid = _parse_grid(_extract_assistant_text(body))
        parsed_ok = grid is not None
        is_correct = parsed_ok and grid == body.expected_output

        # The request fields are echoed back alongside the grading details.
        return ARCAGIVerifyResponse(
            **body.model_dump(),
            reward=1.0 if is_correct else 0.0,
            predicted_output=grid,
            extraction_successful=parsed_ok,
        )
112+
113+
114+
if __name__ == "__main__":
    # Runs only when this file is executed as a script, not when it is imported.
    # run_webserver is not defined in this file; it comes from the base class.
    ARCAGIResourcesServer.run_webserver()
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
arc_agi_resources_server:
2+
resources_servers:
3+
arc_agi:
4+
entrypoint: app.py
5+
domain: knowledge
6+
verified: false
7+
arc_agi_simple_agent:
8+
responses_api_agents:
9+
simple_agent:
10+
entrypoint: app.py
11+
resources_server:
12+
type: resources_servers
13+
name: arc_agi_resources_server
14+
model_server:
15+
type: responses_api_models
16+
name: policy_model
17+
datasets:
18+
- name: example
19+
type: example
20+
jsonl_fpath: resources_servers/arc_agi/data/example.jsonl
21+
- name: training_1
22+
type: validation
23+
jsonl_fpath: resources_servers/arc_agi/data/arc_agi_1_training.jsonl
24+
gitlab_identifier:
25+
dataset_name: arc_agi
26+
version: 0.0.1
27+
artifact_fpath: arc_agi_1_training.jsonl
28+
license: Apache 2.0
29+
- name: evaluation_1
30+
type: validation
31+
jsonl_fpath: resources_servers/arc_agi/data/arc_agi_1_evaluation.jsonl
32+
gitlab_identifier:
33+
dataset_name: arc_agi
34+
version: 0.0.1
35+
artifact_fpath: arc_agi_1_evaluation.jsonl
36+
license: Apache 2.0
37+
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
arc_agi_2:
2+
resources_servers:
3+
arc_agi:
4+
entrypoint: app.py
5+
domain: knowledge
6+
verified: false
7+
arc_agi_2_simple_agent:
8+
responses_api_agents:
9+
simple_agent:
10+
entrypoint: app.py
11+
host: 127.0.0.1
12+
port: 15215
13+
resources_server:
14+
type: resources_servers
15+
name: arc_agi_2
16+
model_server:
17+
type: responses_api_models
18+
name: policy_model
19+
datasets:
20+
- name: example_2
21+
type: example
22+
jsonl_fpath: resources_servers/arc_agi/data/example_2.jsonl
23+
- name: training_2
24+
type: validation
25+
jsonl_fpath: resources_servers/arc_agi/data/arc_agi_2_training.jsonl
26+
- name: evaluation_2
27+
type: validation
28+
jsonl_fpath: resources_servers/arc_agi/data/arc_agi_2_evaluation.jsonl
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
import json
17+
from pathlib import Path
18+
19+
20+
def format_grid(grid):
    """Render a grid as text: cells separated by spaces, rows separated by newlines."""
    rendered_rows = (" ".join(str(cell) for cell in row) for row in grid)
    return "\n".join(rendered_rows)
22+
23+
24+
def create_arc_prompt(task_data, task_id, version=1):
    """Build the user prompt for a single ARC task.

    The prompt lists every training example as text grids, then the input of
    the first test case, and ends with the instruction to answer as a 2D
    array wrapped in \\boxed{}. The task name carries a "-<version>" suffix
    for every version other than 1.
    """
    suffix = "" if version == 1 else "-" + str(version)
    pieces = [
        f"You are solving ARC-AGI{suffix} task {task_id}.\n\n",
        "Here are the training examples that demonstrate the pattern:\n\n",
    ]

    for number, pair in enumerate(task_data["train"], start=1):
        pieces.append(f"Example {number}:\n")
        pieces.append("Input:\n")
        pieces.append(format_grid(pair["input"]))
        pieces.append("\n\nOutput:\n")
        pieces.append(format_grid(pair["output"]))
        pieces.append("\n\n")

    # Only the first test case of the task is presented to the model.
    first_test_input = task_data["test"][0]["input"]
    pieces.append("Now solve this test case following the same pattern:\n")
    pieces.append("Test Input:\n")
    pieces.append(format_grid(first_test_input))
    pieces.append(
        "\n\nProvide your solution as a 2D array inside \\boxed{} in this exact format: \\boxed{[[row1],[row2],...]}"
    )
    pieces.append("\nFor example: \\boxed{[[1,2,3],[4,5,6],[7,8,9]]}")

    return "".join(pieces)
46+
47+
48+
def _build_split_entries(split_dir, split_name, version):
    """Turn every ``*.json`` task file in ``split_dir`` into a dataset entry, in sorted filename order."""
    task_files = sorted(split_dir.glob("*.json"))
    print(f"Processing {len(task_files)} {split_name} tasks...")

    entries = []
    for task_file in task_files:
        task_id = task_file.stem

        with open(task_file, encoding="utf-8") as f:
            task_data = json.load(f)

        prompt = create_arc_prompt(task_data, task_id, version)
        # Only the first test case of each task is used, matching the prompt.
        first_test = task_data["test"][0]

        entries.append(
            {
                "responses_create_params": {"input": [{"role": "user", "content": prompt}]},
                "train": task_data["train"],
                "test_input": first_test["input"],
                "expected_output": first_test["output"],
                "task_id": task_id,
            }
        )

    return entries


def _write_jsonl(output_file, entries):
    """Write ``entries`` to ``output_file`` as one JSON object per line."""
    with open(output_file, "w") as f:
        f.writelines(json.dumps(entry) + "\n" for entry in entries)


def create_dataset(version=1):
    """Convert an ARC-AGI checkout into the JSONL datasets under ``data/``.

    Paths are relative to the current working directory: tasks are read from
    ``../../ARC-AGI/data/{training,evaluation}`` (``../../ARC-AGI-<version>``
    for versions other than 1) and three files are written to ``data/``:
    ``arc_agi_<version>_training.jsonl``, ``arc_agi_<version>_evaluation.jsonl``
    and ``example_<version>.jsonl`` (the first five evaluation entries).

    Args:
        version: ARC-AGI release to convert.

    Raises:
        FileNotFoundError: If the training or evaluation directory is missing.
    """
    data_base = f"../../ARC-AGI{'-' + str(version) if version != 1 else ''}"
    training_dir = Path(f"{data_base}/data/training")
    evaluation_dir = Path(f"{data_base}/data/evaluation")

    # Globbing a missing directory yields nothing, which would silently replace
    # existing datasets with empty files; fail before anything is written.
    for split_dir in (training_dir, evaluation_dir):
        if not split_dir.is_dir():
            raise FileNotFoundError(f"ARC-AGI task directory not found: {split_dir}")

    Path("data").mkdir(exist_ok=True)

    training_dataset = _build_split_entries(training_dir, "training", version)
    training_output_file = Path(f"data/arc_agi_{version}_training.jsonl")
    _write_jsonl(training_output_file, training_dataset)
    print(f"Created training dataset with {len(training_dataset)} tasks at {training_output_file}")

    evaluation_dataset = _build_split_entries(evaluation_dir, "evaluation", version)
    evaluation_output_file = Path(f"data/arc_agi_{version}_evaluation.jsonl")
    _write_jsonl(evaluation_output_file, evaluation_dataset)
    print(f"Created evaluation dataset with {len(evaluation_dataset)} tasks at {evaluation_output_file}")

    example_dataset = evaluation_dataset[:5]
    example_output_file = Path(f"data/example_{version}.jsonl")
    _write_jsonl(example_output_file, example_dataset)
    # Report the real count; fewer than five evaluation tasks may exist.
    print(f"Created example dataset with {len(example_dataset)} tasks at {example_output_file}")
121+
122+
123+
if __name__ == "__main__":
    # Command-line entry point: pick which ARC-AGI release to convert.
    cli = argparse.ArgumentParser(description="Create ARC-AGI dataset")
    cli.add_argument(
        "--version",
        type=int,
        default=1,
        choices=[1, 2],
        help="ARC-AGI version (1 or 2)",
    )
    create_dataset(version=cli.parse_args().version)

0 commit comments

Comments
 (0)