Skip to content

Commit 87f9fc6

Browse files
added policy visualisation from a dataset (#15)
1 parent b5b22d1 commit 87f9fc6

File tree

1 file changed

+274
-0
lines changed

1 file changed

+274
-0
lines changed
Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
#!/usr/bin/env python3
2+
"""Simple policy visualization from dataset - single script, no classes."""
3+
4+
import argparse
5+
import random
6+
import sys
7+
import time
8+
from pathlib import Path
9+
10+
import neuracore as nc
11+
import numpy as np
12+
import viser
13+
import yourdfpy
14+
from neuracore_types import (
15+
BatchedJointData,
16+
BatchedNCData,
17+
BatchedParallelGripperOpenAmountData,
18+
DataType,
19+
RobotDataSpec,
20+
)
21+
from PIL import Image
22+
from viser.extras import ViserUrdf
23+
24+
# Add parent directory to path
25+
sys.path.insert(0, str(Path(__file__).parent.parent))
26+
27+
from common.configs import (
28+
CAMERA_LOGGING_NAME,
29+
GRIPPER_LOGGING_NAME,
30+
JOINT_NAMES,
31+
URDF_PATH,
32+
)
33+
34+
# Parse arguments
35+
parser = argparse.ArgumentParser(
36+
description="Visualize policy predictions from dataset"
37+
)
38+
parser.add_argument("--dataset-name", type=str, required=True, help="Dataset name")
39+
parser.add_argument(
40+
"--train-run-name", type=str, default=None, help="Training run name"
41+
)
42+
parser.add_argument("--model-path", type=str, default=None, help="Model file path")
43+
args = parser.parse_args()
44+
45+
if (args.train_run_name is None) == (args.model_path is None):
46+
parser.error("Exactly one of --train-run-name or --model-path must be provided")
47+
48+
# Connect to NeuraCore
49+
print("🔧 Initializing NeuraCore...")
50+
nc.login()
51+
nc.connect_robot(robot_name="AgileX PiPER", urdf_path=str(URDF_PATH), overwrite=False)
52+
53+
# Load policy
54+
model_input_order = {
55+
DataType.JOINT_POSITIONS: JOINT_NAMES,
56+
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
57+
DataType.RGB_IMAGES: [CAMERA_LOGGING_NAME],
58+
}
59+
model_output_order = {
60+
DataType.JOINT_TARGET_POSITIONS: JOINT_NAMES,
61+
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
62+
}
63+
64+
if args.train_run_name:
65+
print(f"🤖 Loading policy from training run: {args.train_run_name}...")
66+
policy = nc.policy(
67+
train_run_name=args.train_run_name,
68+
device="cuda",
69+
model_input_order=model_input_order,
70+
model_output_order=model_output_order,
71+
)
72+
else:
73+
print(f"🤖 Loading policy from model file: {args.model_path}...")
74+
policy = nc.policy(
75+
model_file=args.model_path,
76+
device="cuda",
77+
model_input_order=model_input_order,
78+
model_output_order=model_output_order,
79+
)
80+
print(" ✓ Policy loaded")
81+
82+
# Load and synchronize dataset
83+
print(f"🔍 Loading dataset: {args.dataset_name}...")
84+
dataset = nc.get_dataset(args.dataset_name)
85+
print(f" ✓ Dataset loaded: {len(dataset)} episodes")
86+
87+
robot_data_spec: RobotDataSpec = {
88+
robot_id: dataset.get_full_data_spec(robot_id) for robot_id in dataset.robot_ids
89+
}
90+
91+
print("🔁 Synchronizing dataset...")
92+
synced_dataset = dataset.synchronize(
93+
frequency=100,
94+
robot_data_spec=robot_data_spec,
95+
prefetch_videos=True,
96+
max_prefetch_workers=2,
97+
)
98+
print(f" ✓ Dataset synchronized: {len(synced_dataset)} episodes")
99+
100+
# Setup Viser
101+
print("🖥️ Starting Viser...")
102+
server = viser.ViserServer()
103+
server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1)
104+
105+
# Load URDF
106+
urdf = yourdfpy.URDF.load(str(URDF_PATH))
107+
urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot")
108+
urdf_vis.update_cfg(np.zeros(len(JOINT_NAMES)))
109+
110+
# State variables
111+
current_horizon = None
112+
current_action_idx = 0
113+
playing = False
114+
115+
116+
def convert_predictions_to_horizon(
117+
predictions: dict[DataType, dict[str, BatchedNCData]],
118+
) -> dict[str, list[float]]:
119+
"""Convert predictions to horizon dict."""
120+
horizon = {}
121+
if DataType.JOINT_TARGET_POSITIONS in predictions:
122+
joint_data = predictions[DataType.JOINT_TARGET_POSITIONS]
123+
for joint_name in JOINT_NAMES:
124+
if joint_name in joint_data:
125+
batched = joint_data[joint_name]
126+
if isinstance(batched, BatchedJointData):
127+
values = batched.value[0, :, 0].cpu().numpy().tolist()
128+
horizon[joint_name] = values
129+
if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in predictions:
130+
gripper_data = predictions[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS]
131+
if GRIPPER_LOGGING_NAME in gripper_data:
132+
batched = gripper_data[GRIPPER_LOGGING_NAME]
133+
if isinstance(batched, BatchedParallelGripperOpenAmountData):
134+
values = batched.open_amount[0, :, 0].cpu().numpy().tolist()
135+
horizon[GRIPPER_LOGGING_NAME] = values
136+
return horizon
137+
138+
139+
def select_random_state() -> None:
140+
"""Select random state and run policy."""
141+
global current_horizon, current_action_idx, playing
142+
143+
# Select random episode and step
144+
episode_idx = random.randint(0, len(synced_dataset) - 1)
145+
episode = synced_dataset[episode_idx]
146+
if len(episode) == 0:
147+
print(f"⚠️ Episode {episode_idx} is empty")
148+
return
149+
150+
step_idx = random.randint(0, len(episode) - 1)
151+
step = episode[step_idx]
152+
print(f"📊 Selected episode {episode_idx}, step {step_idx}")
153+
154+
# Extract joint positions
155+
joint_positions_dict = {}
156+
if DataType.JOINT_POSITIONS in step.data:
157+
joint_data = step.data[DataType.JOINT_POSITIONS]
158+
for joint_name in JOINT_NAMES:
159+
if joint_name in joint_data:
160+
joint_positions_dict[joint_name] = joint_data[joint_name].value
161+
162+
# Extract gripper
163+
gripper_value = 1.0
164+
if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in step.data:
165+
gripper_data = step.data[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS]
166+
if GRIPPER_LOGGING_NAME in gripper_data:
167+
gripper_value = gripper_data[GRIPPER_LOGGING_NAME].open_amount
168+
169+
# Extract RGB image
170+
rgb_image = None
171+
if DataType.RGB_IMAGES in step.data:
172+
rgb_data = step.data[DataType.RGB_IMAGES]
173+
if CAMERA_LOGGING_NAME in rgb_data:
174+
rgb_image = np.array(rgb_data[CAMERA_LOGGING_NAME].frame)
175+
176+
if rgb_image is None:
177+
print("⚠️ No RGB image found")
178+
return
179+
180+
# Log to NeuraCore
181+
nc.log_joint_positions(joint_positions_dict)
182+
nc.log_parallel_gripper_open_amount(GRIPPER_LOGGING_NAME, gripper_value)
183+
nc.log_rgb(CAMERA_LOGGING_NAME, rgb_image)
184+
185+
# Get policy prediction
186+
print("🎯 Getting policy prediction...")
187+
predictions = policy.predict(timeout=5)
188+
current_horizon = convert_predictions_to_horizon(predictions)
189+
current_action_idx = 0
190+
playing = True
191+
print("FINISHED PREDICTION")
192+
193+
# Save image to file
194+
image_pil = Image.fromarray(rgb_image)
195+
image_pil.save("current_image.png")
196+
print("💾 Saved image to current_image.png")
197+
198+
# Update robot to initial pose
199+
joint_positions = np.array([joint_positions_dict[jn] for jn in JOINT_NAMES])
200+
urdf_vis.update_cfg(joint_positions)
201+
202+
print(
203+
f"✅ Prediction received: {len(current_horizon.get(JOINT_NAMES[0], []))} actions"
204+
)
205+
206+
207+
# Add button
208+
random_button = server.gui.add_button("Random Selection")
209+
random_button.on_click(lambda _: select_random_state())
210+
211+
# Add gripper value display
212+
gripper_handle = server.gui.add_slider(
213+
"Gripper Open Amount",
214+
min=0.0,
215+
max=1.0,
216+
step=0.01,
217+
initial_value=0.0,
218+
disabled=True, # Read-only
219+
)
220+
221+
# Add frequency control
222+
frequency_handle = server.gui.add_number(
223+
"Visualization Frequency (Hz)",
224+
initial_value=100.0,
225+
min=1.0,
226+
max=500.0,
227+
step=1.0,
228+
)
229+
230+
# Select initial state
231+
select_random_state()
232+
# Main loop
233+
try:
234+
while True:
235+
start_time = time.time()
236+
237+
# Update robot visualization
238+
if (
239+
playing
240+
and current_horizon
241+
and len(current_horizon.get(JOINT_NAMES[0], [])) > 0
242+
):
243+
horizon_length = len(current_horizon[JOINT_NAMES[0]])
244+
if current_action_idx < horizon_length:
245+
# Get current action
246+
joint_config = np.array(
247+
[
248+
current_horizon[joint_name][current_action_idx]
249+
for joint_name in JOINT_NAMES
250+
]
251+
)
252+
urdf_vis.update_cfg(joint_config)
253+
254+
# Update gripper value
255+
gripper_value = current_horizon[GRIPPER_LOGGING_NAME][
256+
current_action_idx
257+
]
258+
gripper_handle.value = round(
259+
gripper_value, 2
260+
) # Round to 2 decimal places
261+
262+
# Advance to next action
263+
current_action_idx = (current_action_idx + 1) % horizon_length
264+
265+
# Sleep to control update rate
266+
elapsed = time.time() - start_time
267+
frequency = max(frequency_handle.value, 0.1) # Avoid division by zero
268+
time.sleep(max(0, 1.0 / frequency - elapsed))
269+
270+
except KeyboardInterrupt:
271+
print("\n👋 Shutting down...")
272+
finally:
273+
policy.disconnect()
274+
nc.logout()

0 commit comments

Comments
 (0)