|
| 1 | +#!/usr/bin/env python3 |
| 2 | +"""Simple policy visualization from dataset - single script, no classes.""" |
| 3 | + |
| 4 | +import argparse |
| 5 | +import random |
| 6 | +import sys |
| 7 | +import time |
| 8 | +from pathlib import Path |
| 9 | + |
| 10 | +import neuracore as nc |
| 11 | +import numpy as np |
| 12 | +import viser |
| 13 | +import yourdfpy |
| 14 | +from neuracore_types import ( |
| 15 | + BatchedJointData, |
| 16 | + BatchedNCData, |
| 17 | + BatchedParallelGripperOpenAmountData, |
| 18 | + DataType, |
| 19 | + RobotDataSpec, |
| 20 | +) |
| 21 | +from PIL import Image |
| 22 | +from viser.extras import ViserUrdf |
| 23 | + |
| 24 | +# Add parent directory to path |
| 25 | +sys.path.insert(0, str(Path(__file__).parent.parent)) |
| 26 | + |
| 27 | +from common.configs import ( |
| 28 | + CAMERA_LOGGING_NAME, |
| 29 | + GRIPPER_LOGGING_NAME, |
| 30 | + JOINT_NAMES, |
| 31 | + URDF_PATH, |
| 32 | +) |
| 33 | + |
| 34 | +# Parse arguments |
| 35 | +parser = argparse.ArgumentParser( |
| 36 | + description="Visualize policy predictions from dataset" |
| 37 | +) |
| 38 | +parser.add_argument("--dataset-name", type=str, required=True, help="Dataset name") |
| 39 | +parser.add_argument( |
| 40 | + "--train-run-name", type=str, default=None, help="Training run name" |
| 41 | +) |
| 42 | +parser.add_argument("--model-path", type=str, default=None, help="Model file path") |
| 43 | +args = parser.parse_args() |
| 44 | + |
| 45 | +if (args.train_run_name is None) == (args.model_path is None): |
| 46 | + parser.error("Exactly one of --train-run-name or --model-path must be provided") |
| 47 | + |
| 48 | +# Connect to NeuraCore |
| 49 | +print("🔧 Initializing NeuraCore...") |
| 50 | +nc.login() |
| 51 | +nc.connect_robot(robot_name="AgileX PiPER", urdf_path=str(URDF_PATH), overwrite=False) |
| 52 | + |
| 53 | +# Load policy |
| 54 | +model_input_order = { |
| 55 | + DataType.JOINT_POSITIONS: JOINT_NAMES, |
| 56 | + DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME], |
| 57 | + DataType.RGB_IMAGES: [CAMERA_LOGGING_NAME], |
| 58 | +} |
| 59 | +model_output_order = { |
| 60 | + DataType.JOINT_TARGET_POSITIONS: JOINT_NAMES, |
| 61 | + DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME], |
| 62 | +} |
| 63 | + |
| 64 | +if args.train_run_name: |
| 65 | + print(f"🤖 Loading policy from training run: {args.train_run_name}...") |
| 66 | + policy = nc.policy( |
| 67 | + train_run_name=args.train_run_name, |
| 68 | + device="cuda", |
| 69 | + model_input_order=model_input_order, |
| 70 | + model_output_order=model_output_order, |
| 71 | + ) |
| 72 | +else: |
| 73 | + print(f"🤖 Loading policy from model file: {args.model_path}...") |
| 74 | + policy = nc.policy( |
| 75 | + model_file=args.model_path, |
| 76 | + device="cuda", |
| 77 | + model_input_order=model_input_order, |
| 78 | + model_output_order=model_output_order, |
| 79 | + ) |
| 80 | +print(" ✓ Policy loaded") |
| 81 | + |
| 82 | +# Load and synchronize dataset |
| 83 | +print(f"🔍 Loading dataset: {args.dataset_name}...") |
| 84 | +dataset = nc.get_dataset(args.dataset_name) |
| 85 | +print(f" ✓ Dataset loaded: {len(dataset)} episodes") |
| 86 | + |
| 87 | +robot_data_spec: RobotDataSpec = { |
| 88 | + robot_id: dataset.get_full_data_spec(robot_id) for robot_id in dataset.robot_ids |
| 89 | +} |
| 90 | + |
| 91 | +print("🔁 Synchronizing dataset...") |
| 92 | +synced_dataset = dataset.synchronize( |
| 93 | + frequency=100, |
| 94 | + robot_data_spec=robot_data_spec, |
| 95 | + prefetch_videos=True, |
| 96 | + max_prefetch_workers=2, |
| 97 | +) |
| 98 | +print(f" ✓ Dataset synchronized: {len(synced_dataset)} episodes") |
| 99 | + |
| 100 | +# Setup Viser |
| 101 | +print("🖥️ Starting Viser...") |
| 102 | +server = viser.ViserServer() |
| 103 | +server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1) |
| 104 | + |
| 105 | +# Load URDF |
| 106 | +urdf = yourdfpy.URDF.load(str(URDF_PATH)) |
| 107 | +urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot") |
| 108 | +urdf_vis.update_cfg(np.zeros(len(JOINT_NAMES))) |
| 109 | + |
| 110 | +# State variables |
| 111 | +current_horizon = None |
| 112 | +current_action_idx = 0 |
| 113 | +playing = False |
| 114 | + |
| 115 | + |
| 116 | +def convert_predictions_to_horizon( |
| 117 | + predictions: dict[DataType, dict[str, BatchedNCData]], |
| 118 | +) -> dict[str, list[float]]: |
| 119 | + """Convert predictions to horizon dict.""" |
| 120 | + horizon = {} |
| 121 | + if DataType.JOINT_TARGET_POSITIONS in predictions: |
| 122 | + joint_data = predictions[DataType.JOINT_TARGET_POSITIONS] |
| 123 | + for joint_name in JOINT_NAMES: |
| 124 | + if joint_name in joint_data: |
| 125 | + batched = joint_data[joint_name] |
| 126 | + if isinstance(batched, BatchedJointData): |
| 127 | + values = batched.value[0, :, 0].cpu().numpy().tolist() |
| 128 | + horizon[joint_name] = values |
| 129 | + if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in predictions: |
| 130 | + gripper_data = predictions[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS] |
| 131 | + if GRIPPER_LOGGING_NAME in gripper_data: |
| 132 | + batched = gripper_data[GRIPPER_LOGGING_NAME] |
| 133 | + if isinstance(batched, BatchedParallelGripperOpenAmountData): |
| 134 | + values = batched.open_amount[0, :, 0].cpu().numpy().tolist() |
| 135 | + horizon[GRIPPER_LOGGING_NAME] = values |
| 136 | + return horizon |
| 137 | + |
| 138 | + |
| 139 | +def select_random_state() -> None: |
| 140 | + """Select random state and run policy.""" |
| 141 | + global current_horizon, current_action_idx, playing |
| 142 | + |
| 143 | + # Select random episode and step |
| 144 | + episode_idx = random.randint(0, len(synced_dataset) - 1) |
| 145 | + episode = synced_dataset[episode_idx] |
| 146 | + if len(episode) == 0: |
| 147 | + print(f"⚠️ Episode {episode_idx} is empty") |
| 148 | + return |
| 149 | + |
| 150 | + step_idx = random.randint(0, len(episode) - 1) |
| 151 | + step = episode[step_idx] |
| 152 | + print(f"📊 Selected episode {episode_idx}, step {step_idx}") |
| 153 | + |
| 154 | + # Extract joint positions |
| 155 | + joint_positions_dict = {} |
| 156 | + if DataType.JOINT_POSITIONS in step.data: |
| 157 | + joint_data = step.data[DataType.JOINT_POSITIONS] |
| 158 | + for joint_name in JOINT_NAMES: |
| 159 | + if joint_name in joint_data: |
| 160 | + joint_positions_dict[joint_name] = joint_data[joint_name].value |
| 161 | + |
| 162 | + # Extract gripper |
| 163 | + gripper_value = 1.0 |
| 164 | + if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in step.data: |
| 165 | + gripper_data = step.data[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS] |
| 166 | + if GRIPPER_LOGGING_NAME in gripper_data: |
| 167 | + gripper_value = gripper_data[GRIPPER_LOGGING_NAME].open_amount |
| 168 | + |
| 169 | + # Extract RGB image |
| 170 | + rgb_image = None |
| 171 | + if DataType.RGB_IMAGES in step.data: |
| 172 | + rgb_data = step.data[DataType.RGB_IMAGES] |
| 173 | + if CAMERA_LOGGING_NAME in rgb_data: |
| 174 | + rgb_image = np.array(rgb_data[CAMERA_LOGGING_NAME].frame) |
| 175 | + |
| 176 | + if rgb_image is None: |
| 177 | + print("⚠️ No RGB image found") |
| 178 | + return |
| 179 | + |
| 180 | + # Log to NeuraCore |
| 181 | + nc.log_joint_positions(joint_positions_dict) |
| 182 | + nc.log_parallel_gripper_open_amount(GRIPPER_LOGGING_NAME, gripper_value) |
| 183 | + nc.log_rgb(CAMERA_LOGGING_NAME, rgb_image) |
| 184 | + |
| 185 | + # Get policy prediction |
| 186 | + print("🎯 Getting policy prediction...") |
| 187 | + predictions = policy.predict(timeout=5) |
| 188 | + current_horizon = convert_predictions_to_horizon(predictions) |
| 189 | + current_action_idx = 0 |
| 190 | + playing = True |
| 191 | + print("FINISHED PREDICTION") |
| 192 | + |
| 193 | + # Save image to file |
| 194 | + image_pil = Image.fromarray(rgb_image) |
| 195 | + image_pil.save("current_image.png") |
| 196 | + print("💾 Saved image to current_image.png") |
| 197 | + |
| 198 | + # Update robot to initial pose |
| 199 | + joint_positions = np.array([joint_positions_dict[jn] for jn in JOINT_NAMES]) |
| 200 | + urdf_vis.update_cfg(joint_positions) |
| 201 | + |
| 202 | + print( |
| 203 | + f"✅ Prediction received: {len(current_horizon.get(JOINT_NAMES[0], []))} actions" |
| 204 | + ) |
| 205 | + |
| 206 | + |
| 207 | +# Add button |
| 208 | +random_button = server.gui.add_button("Random Selection") |
| 209 | +random_button.on_click(lambda _: select_random_state()) |
| 210 | + |
| 211 | +# Add gripper value display |
| 212 | +gripper_handle = server.gui.add_slider( |
| 213 | + "Gripper Open Amount", |
| 214 | + min=0.0, |
| 215 | + max=1.0, |
| 216 | + step=0.01, |
| 217 | + initial_value=0.0, |
| 218 | + disabled=True, # Read-only |
| 219 | +) |
| 220 | + |
| 221 | +# Add frequency control |
| 222 | +frequency_handle = server.gui.add_number( |
| 223 | + "Visualization Frequency (Hz)", |
| 224 | + initial_value=100.0, |
| 225 | + min=1.0, |
| 226 | + max=500.0, |
| 227 | + step=1.0, |
| 228 | +) |
| 229 | + |
| 230 | +# Select initial state |
| 231 | +select_random_state() |
| 232 | +# Main loop |
| 233 | +try: |
| 234 | + while True: |
| 235 | + start_time = time.time() |
| 236 | + |
| 237 | + # Update robot visualization |
| 238 | + if ( |
| 239 | + playing |
| 240 | + and current_horizon |
| 241 | + and len(current_horizon.get(JOINT_NAMES[0], [])) > 0 |
| 242 | + ): |
| 243 | + horizon_length = len(current_horizon[JOINT_NAMES[0]]) |
| 244 | + if current_action_idx < horizon_length: |
| 245 | + # Get current action |
| 246 | + joint_config = np.array( |
| 247 | + [ |
| 248 | + current_horizon[joint_name][current_action_idx] |
| 249 | + for joint_name in JOINT_NAMES |
| 250 | + ] |
| 251 | + ) |
| 252 | + urdf_vis.update_cfg(joint_config) |
| 253 | + |
| 254 | + # Update gripper value |
| 255 | + gripper_value = current_horizon[GRIPPER_LOGGING_NAME][ |
| 256 | + current_action_idx |
| 257 | + ] |
| 258 | + gripper_handle.value = round( |
| 259 | + gripper_value, 2 |
| 260 | + ) # Round to 2 decimal places |
| 261 | + |
| 262 | + # Advance to next action |
| 263 | + current_action_idx = (current_action_idx + 1) % horizon_length |
| 264 | + |
| 265 | + # Sleep to control update rate |
| 266 | + elapsed = time.time() - start_time |
| 267 | + frequency = max(frequency_handle.value, 0.1) # Avoid division by zero |
| 268 | + time.sleep(max(0, 1.0 / frequency - elapsed)) |
| 269 | + |
| 270 | +except KeyboardInterrupt: |
| 271 | + print("\n👋 Shutting down...") |
| 272 | +finally: |
| 273 | + policy.disconnect() |
| 274 | + nc.logout() |
0 commit comments