Skip to content

Commit a44a52b

Browse files
switch to gripper target open amount (#18)
1 parent 87f9fc6 commit a44a52b

8 files changed

+86
-43
lines changed

examples/2_collect_teleop_data_with_neuracore.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
DAMPING_COST,
3636
FRAME_TASK_GAIN,
3737
GRIPPER_FRAME_NAME,
38+
GRIPPER_LOGGING_NAME,
3839
IK_SOLVER_RATE,
40+
JOINT_NAMES,
3941
JOINT_STATE_STREAMING_RATE,
4042
LM_DAMPING,
4143
NEUTRAL_JOINT_ANGLES,
@@ -104,18 +106,25 @@ def neuracore_logging_worker(queue: Queue, worker_id: int) -> None:
104106
if function_name == "log_joint_positions":
105107
data_value = np.radians(data_value)
106108
data_dict = {
107-
f"joint{i+1}": angle for i, angle in enumerate(data_value)
109+
joint_name: angle
110+
for joint_name, angle in zip(JOINT_NAMES, data_value)
108111
}
109112
nc.log_joint_positions(data_dict, timestamp=timestamp)
110113
elif function_name == "log_joint_target_positions":
111114
data_value = np.radians(data_value)
112115
data_dict = {
113-
f"joint{i+1}": angle for i, angle in enumerate(data_value)
116+
joint_name: angle
117+
for joint_name, angle in zip(JOINT_NAMES, data_value)
114118
}
115119
nc.log_joint_target_positions(data_dict, timestamp=timestamp)
116120
elif function_name == "log_parallel_gripper_open_amounts":
117-
data_dict = {"gripper": data_value}
121+
data_dict = {GRIPPER_LOGGING_NAME: data_value}
118122
nc.log_parallel_gripper_open_amounts(data_dict, timestamp=timestamp)
123+
elif function_name == "log_parallel_gripper_target_open_amounts":
124+
data_dict = {GRIPPER_LOGGING_NAME: data_value}
125+
nc.log_parallel_gripper_target_open_amounts(
126+
data_dict, timestamp=timestamp
127+
)
119128
elif function_name == "log_rgb":
120129
camera_name = "rgb"
121130
image_array = data_value
@@ -327,7 +336,11 @@ def on_button_rj_pressed() -> None:
327336

328337
# Initialize Meta Quest reader
329338
print("\n🎮 Initializing Meta Quest reader...")
330-
quest_reader = MetaQuestReader(ip_address=args.ip_address, port=5555, run=True)
339+
quest_reader = MetaQuestReader(
340+
ip_address=args.ip_address,
341+
port=5555,
342+
run=True,
343+
)
331344

332345
# Register button callbacks (after state and robot_controller are initialized)
333346
quest_reader.on("button_a_pressed", on_button_a_pressed)

examples/3_replay_neuracore_episodes.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def main() -> None:
2828
"""Main function for replaying a Neuracore dataset on the Piper robot."""
2929
parser = argparse.ArgumentParser()
3030
parser.add_argument("--dataset-name", type=str, required=True)
31-
parser.add_argument("--frequency", type=int, required=False, default=100)
31+
parser.add_argument("--frequency", type=int, required=True)
3232
parser.add_argument("--episode-index", type=int, required=False, default=0)
3333
args = parser.parse_args()
3434

@@ -57,8 +57,10 @@ def main() -> None:
5757
print("\n🔁 Building robot data spec for synchronization...")
5858
data_types_to_synchronize = [
5959
DataType.JOINT_POSITIONS,
60+
DataType.JOINT_TARGET_POSITIONS,
6061
DataType.RGB_IMAGES,
6162
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS,
63+
DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS,
6264
]
6365
robot_data_spec: RobotDataSpec = {}
6466
robot_ids_dataset = dataset.robot_ids
@@ -123,8 +125,8 @@ def main() -> None:
123125

124126
# Extract joint positions
125127
joint_positions_dict = {}
126-
if DataType.JOINT_POSITIONS in step.data:
127-
joint_data = step.data[DataType.JOINT_POSITIONS]
128+
if DataType.JOINT_TARGET_POSITIONS in step.data:
129+
joint_data = step.data[DataType.JOINT_TARGET_POSITIONS]
128130
for joint_name in JOINT_NAMES:
129131
if joint_name in joint_data:
130132
joint_positions_dict[joint_name] = joint_data[
@@ -134,8 +136,10 @@ def main() -> None:
134136

135137
# Extract gripper
136138
gripper_value = 0.0
137-
if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in step.data:
138-
gripper_data = step.data[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS]
139+
if DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS in step.data:
140+
gripper_data = step.data[
141+
DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS
142+
]
139143
if GRIPPER_LOGGING_NAME in gripper_data:
140144
gripper_value = gripper_data[GRIPPER_LOGGING_NAME].open_amount
141145
parallel_gripper_open_amounts.append(gripper_value)

examples/4_rollout_neuracore_policy.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def convert_predictions_to_horizon_dict(predictions: dict) -> dict[str, list[flo
8686
horizon[joint_name] = values
8787

8888
# Extract gripper open amounts
89-
if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in predictions:
90-
gripper_data = predictions[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS]
89+
if DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS in predictions:
90+
gripper_data = predictions[DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS]
9191
if GRIPPER_LOGGING_NAME in gripper_data:
9292
batched = gripper_data[GRIPPER_LOGGING_NAME]
9393
if isinstance(batched, BatchedParallelGripperOpenAmountData):
@@ -155,8 +155,8 @@ def run_policy(
155155
print("⚠️ No current joint angles available")
156156
return False
157157

158-
# Get target gripper open value because this is how the policy was trained
159-
gripper_open_value = data_manager.get_target_gripper_open_value()
158+
# Get current gripper open value
159+
gripper_open_value = data_manager.get_current_gripper_open_value()
160160
if gripper_open_value is None:
161161
print("⚠️ No gripper open value available")
162162
return False
@@ -170,7 +170,7 @@ def run_policy(
170170
# Prepare data for NeuraCore logging
171171
joint_angles_rad = np.radians(current_joint_angles)
172172
joint_positions_dict = {
173-
JOINT_NAMES[i]: angle for i, angle in enumerate(joint_angles_rad)
173+
joint_name: angle for joint_name, angle in zip(JOINT_NAMES, joint_angles_rad)
174174
}
175175

176176
# Log joint positions parallel gripper open amounts and RGB image to NeuraCore
@@ -455,10 +455,12 @@ def policy_execution_thread(
455455

456456
# Send current gripper open value to robot (if available)
457457
if GRIPPER_LOGGING_NAME in locked_horizon:
458-
current_gripper_open_value = locked_horizon[GRIPPER_LOGGING_NAME][
459-
execution_index
460-
]
461-
robot_controller.set_gripper_open_value(current_gripper_open_value)
458+
current_gripper_target_open_value = locked_horizon[
459+
GRIPPER_LOGGING_NAME
460+
][execution_index]
461+
robot_controller.set_gripper_open_value(
462+
current_gripper_target_open_value
463+
)
462464

463465
# Update execution index
464466
policy_state.increment_execution_action_index()
@@ -682,7 +684,7 @@ def update_visualization(
682684
}
683685
model_output_order = {
684686
DataType.JOINT_TARGET_POSITIONS: JOINT_NAMES,
685-
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
687+
DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
686688
}
687689

688690
print("\n📋 Model input order:")
@@ -721,8 +723,6 @@ def update_visualization(
721723
CONTROLLER_BETA,
722724
CONTROLLER_D_CUTOFF,
723725
)
724-
# Setting the target gripper so policy doesn't crash first time it runs
725-
data_manager.set_target_gripper_open_value(1.0)
726726

727727
# Initialize robot controller
728728
print("\n🤖 Initializing Piper robot controller...")

examples/5_rollout_neuracore_policy_minimal.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def convert_predictions_to_horizon_dict(predictions: dict) -> dict[str, list[flo
6060
horizon[joint_name] = values
6161

6262
# Extract gripper open amounts
63-
if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in predictions:
64-
gripper_data = predictions[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS]
63+
if DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS in predictions:
64+
gripper_data = predictions[DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS]
6565
if GRIPPER_LOGGING_NAME in gripper_data:
6666
batched = gripper_data[GRIPPER_LOGGING_NAME]
6767
if isinstance(batched, BatchedParallelGripperOpenAmountData):
@@ -79,8 +79,8 @@ def log_current_state(data_manager: DataManager) -> None:
7979
print("⚠️ No joint angles available")
8080
return
8181

82-
# Get target gripper open value because this is how the policy was trained
83-
gripper_open_value = data_manager.get_target_gripper_open_value()
82+
# Get current gripper open value
83+
gripper_open_value = data_manager.get_current_gripper_open_value()
8484
if gripper_open_value is None:
8585
print("⚠️ No gripper open value available")
8686
return
@@ -94,7 +94,7 @@ def log_current_state(data_manager: DataManager) -> None:
9494
# Prepare data for NeuraCore logging
9595
joint_angles_rad = np.radians(current_joint_angles)
9696
joint_positions_dict = {
97-
JOINT_NAMES[i]: angle for i, angle in enumerate(joint_angles_rad)
97+
joint_name: angle for joint_name, angle in zip(JOINT_NAMES, joint_angles_rad)
9898
}
9999

100100
# Log joint positions, parallel gripper open amounts, and RGB image to NeuraCore
@@ -123,8 +123,6 @@ def run_policy(
123123
horizon_length = policy_state.get_prediction_horizon_length()
124124
print(f"✓ Got {horizon_length} actions in {elapsed:.3f}s")
125125

126-
# Set execution ratio and save prediction horizon
127-
policy_state.set_execution_ratio(PREDICTION_HORIZON_EXECUTION_RATIO)
128126
policy_state.set_prediction_horizon(prediction_horizon)
129127
return True
130128

@@ -138,14 +136,15 @@ def execute_horizon(
138136
data_manager: DataManager,
139137
policy_state: PolicyState,
140138
robot_controller: PiperController,
139+
frequency: int,
141140
) -> None:
142141
"""Execute prediction horizon."""
143142
policy_state.start_policy_execution()
144143
data_manager.set_robot_activity_state(RobotActivityState.POLICY_CONTROLLED)
145144

146145
locked_horizon = policy_state.get_locked_prediction_horizon()
147146
horizon_length = policy_state.get_locked_prediction_horizon_length()
148-
dt = 1.0 / POLICY_EXECUTION_RATE
147+
dt = 1.0 / frequency
149148

150149
for i in range(horizon_length):
151150
start_time = time.time()
@@ -162,8 +161,8 @@ def execute_horizon(
162161

163162
# Send current gripper open value to robot (if available)
164163
if GRIPPER_LOGGING_NAME in locked_horizon:
165-
current_gripper_open_value = locked_horizon[GRIPPER_LOGGING_NAME][i]
166-
robot_controller.set_gripper_open_value(current_gripper_open_value)
164+
current_gripper_target_open_value = locked_horizon[GRIPPER_LOGGING_NAME][i]
165+
robot_controller.set_gripper_open_value(current_gripper_target_open_value)
167166

168167
# Log current state for visualization
169168
log_current_state(data_manager)
@@ -191,6 +190,18 @@ def execute_horizon(
191190
default=None,
192191
help="Path to local model file to load policy from. Mutually exclusive with --train-run-name.",
193192
)
193+
parser.add_argument(
194+
"--frequency",
195+
type=int,
196+
default=POLICY_EXECUTION_RATE,
197+
help="Frequency of policy execution",
198+
)
199+
parser.add_argument(
200+
"--execution-ratio",
201+
type=float,
202+
default=PREDICTION_HORIZON_EXECUTION_RATIO,
203+
help="Execution ratio of the policy",
204+
)
194205
args = parser.parse_args()
195206

196207
# Validate that exactly one of train-run-name or model-path is provided
@@ -223,7 +234,7 @@ def execute_horizon(
223234
}
224235
model_output_order = {
225236
DataType.JOINT_TARGET_POSITIONS: JOINT_NAMES,
226-
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
237+
DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
227238
}
228239

229240
print("\n📋 Model input order:")
@@ -252,8 +263,8 @@ def execute_horizon(
252263

253264
# Initialize state
254265
data_manager = DataManager()
255-
data_manager.set_target_gripper_open_value(1.0)
256266
policy_state = PolicyState()
267+
policy_state.set_execution_ratio(args.execution_ratio)
257268

258269
# Initialize robot controller
259270
print("\n🤖 Initializing robot controller...")
@@ -318,7 +329,9 @@ def execute_horizon(
318329
continue
319330

320331
# Execute horizon
321-
execute_horizon(data_manager, policy_state, robot_controller)
332+
execute_horizon(
333+
data_manager, policy_state, robot_controller, args.frequency
334+
)
322335

323336
except KeyboardInterrupt:
324337
print("\n\n👋 Interrupt received - shutting down...")

examples/6_visualize_policy_from_dataset.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
#!/usr/bin/env python3
2-
"""Simple policy visualization from dataset - single script, no classes."""
2+
"""Simple policy visualization from dataset.
3+
4+
Loads a policy and a dataset, and randomly selects a state
5+
from the dataset to run the policy with and visualize the results.
6+
"""
37

48
import argparse
59
import random
@@ -40,6 +44,9 @@
4044
"--train-run-name", type=str, default=None, help="Training run name"
4145
)
4246
parser.add_argument("--model-path", type=str, default=None, help="Model file path")
47+
parser.add_argument(
48+
"--frequency", type=int, default=100, help="Frequency of visualization"
49+
)
4350
args = parser.parse_args()
4451

4552
if (args.train_run_name is None) == (args.model_path is None):
@@ -58,7 +65,7 @@
5865
}
5966
model_output_order = {
6067
DataType.JOINT_TARGET_POSITIONS: JOINT_NAMES,
61-
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
68+
DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
6269
}
6370

6471
if args.train_run_name:
@@ -90,7 +97,7 @@
9097

9198
print("🔁 Synchronizing dataset...")
9299
synced_dataset = dataset.synchronize(
93-
frequency=100,
100+
frequency=args.frequency,
94101
robot_data_spec=robot_data_spec,
95102
prefetch_videos=True,
96103
max_prefetch_workers=2,
@@ -126,8 +133,8 @@ def convert_predictions_to_horizon(
126133
if isinstance(batched, BatchedJointData):
127134
values = batched.value[0, :, 0].cpu().numpy().tolist()
128135
horizon[joint_name] = values
129-
if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in predictions:
130-
gripper_data = predictions[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS]
136+
if DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS in predictions:
137+
gripper_data = predictions[DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS]
131138
if GRIPPER_LOGGING_NAME in gripper_data:
132139
batched = gripper_data[GRIPPER_LOGGING_NAME]
133140
if isinstance(batched, BatchedParallelGripperOpenAmountData):
@@ -221,7 +228,7 @@ def select_random_state() -> None:
221228
# Add frequency control
222229
frequency_handle = server.gui.add_number(
223230
"Visualization Frequency (Hz)",
224-
initial_value=100.0,
231+
initial_value=args.frequency,
225232
min=1.0,
226233
max=500.0,
227234
step=1.0,

examples/common/configs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
CAMERA_FRAME_STREAMING_RATE = 60.0 # Data collection rate for camera frame
4545

4646
# # Initial neutral pose for robot (degrees)
47-
NEUTRAL_JOINT_ANGLES = [-5.251, 21.356, -41.386, -4.323, 53.374, 0.0]
47+
NEUTRAL_JOINT_ANGLES = [-1.003, 80.167, -51.064, -4.127, 16.548, 2.619]
4848

4949
# Posture task cost vector (one weight per joint)
5050
POSTURE_COST_VECTOR = [0.0, 0.0, 0.0, 0.05, 0.0, 0.0]

examples/common/data_manager.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,12 @@ def set_current_gripper_open_value(self, value: float) -> None:
406406
"""
407407
with self._robot_state._lock:
408408
self._robot_state.current_gripper_open_value = value
409+
if self._on_change_callback:
410+
self._on_change_callback(
411+
"log_parallel_gripper_open_amounts",
412+
value,
413+
time.time(),
414+
)
409415

410416
def get_target_gripper_open_value(self) -> float | None:
411417
"""Get target gripper open value (thread-safe).
@@ -426,7 +432,7 @@ def set_target_gripper_open_value(self, value: float) -> None:
426432
self._robot_state.target_gripper_open_value = value
427433
if self._on_change_callback:
428434
self._on_change_callback(
429-
"log_parallel_gripper_open_amounts",
435+
"log_parallel_gripper_target_open_amounts",
430436
self._robot_state.target_gripper_open_value,
431437
time.time(),
432438
)

0 commit comments

Comments
 (0)