@@ -60,8 +60,8 @@ def convert_predictions_to_horizon_dict(predictions: dict) -> dict[str, list[flo
         horizon[joint_name] = values
 
     # Extract gripper open amounts
-    if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in predictions:
-        gripper_data = predictions[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS]
+    if DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS in predictions:
+        gripper_data = predictions[DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS]
         if GRIPPER_LOGGING_NAME in gripper_data:
             batched = gripper_data[GRIPPER_LOGGING_NAME]
             if isinstance(batched, BatchedParallelGripperOpenAmountData):
@@ -79,8 +79,8 @@ def log_current_state(data_manager: DataManager) -> None:
         print("⚠️ No joint angles available")
         return
 
-    # Get target gripper open value because this is how the policy was trained
-    gripper_open_value = data_manager.get_target_gripper_open_value()
+    # Get current gripper open value
+    gripper_open_value = data_manager.get_current_gripper_open_value()
     if gripper_open_value is None:
         print("⚠️ No gripper open value available")
         return
@@ -94,7 +94,7 @@ def log_current_state(data_manager: DataManager) -> None:
     # Prepare data for NeuraCore logging
     joint_angles_rad = np.radians(current_joint_angles)
     joint_positions_dict = {
-        JOINT_NAMES[i]: angle for i, angle in enumerate(joint_angles_rad)
+        joint_name: angle for joint_name, angle in zip(JOINT_NAMES, joint_angles_rad)
     }
 
     # Log joint positions, parallel gripper open amounts, and RGB image to NeuraCore
@@ -123,8 +123,6 @@ def run_policy(
     horizon_length = policy_state.get_prediction_horizon_length()
     print(f"✓ Got {horizon_length} actions in {elapsed:.3f}s")
 
-    # Set execution ratio and save prediction horizon
-    policy_state.set_execution_ratio(PREDICTION_HORIZON_EXECUTION_RATIO)
     policy_state.set_prediction_horizon(prediction_horizon)
     return True
 
@@ -138,14 +136,15 @@ def execute_horizon(
     data_manager: DataManager,
     policy_state: PolicyState,
     robot_controller: PiperController,
+    frequency: int,
 ) -> None:
     """Execute prediction horizon."""
     policy_state.start_policy_execution()
     data_manager.set_robot_activity_state(RobotActivityState.POLICY_CONTROLLED)
 
     locked_horizon = policy_state.get_locked_prediction_horizon()
     horizon_length = policy_state.get_locked_prediction_horizon_length()
-    dt = 1.0 / POLICY_EXECUTION_RATE
+    dt = 1.0 / frequency
 
     for i in range(horizon_length):
         start_time = time.time()
@@ -162,8 +161,8 @@ def execute_horizon(
 
         # Send current gripper open value to robot (if available)
         if GRIPPER_LOGGING_NAME in locked_horizon:
-            current_gripper_open_value = locked_horizon[GRIPPER_LOGGING_NAME][i]
-            robot_controller.set_gripper_open_value(current_gripper_open_value)
+            current_gripper_target_open_value = locked_horizon[GRIPPER_LOGGING_NAME][i]
+            robot_controller.set_gripper_open_value(current_gripper_target_open_value)
 
         # Log current state for visualization
         log_current_state(data_manager)
@@ -191,6 +190,18 @@ def execute_horizon(
         default=None,
         help="Path to local model file to load policy from. Mutually exclusive with --train-run-name.",
     )
+    parser.add_argument(
+        "--frequency",
+        type=int,
+        default=POLICY_EXECUTION_RATE,
+        help="Frequency of policy execution",
+    )
+    parser.add_argument(
+        "--execution-ratio",
+        type=float,
+        default=PREDICTION_HORIZON_EXECUTION_RATIO,
+        help="Execution ratio of the policy",
+    )
     args = parser.parse_args()
 
     # Validate that exactly one of train-run-name or model-path is provided
@@ -223,7 +234,7 @@ def execute_horizon(
     }
     model_output_order = {
         DataType.JOINT_TARGET_POSITIONS: JOINT_NAMES,
-        DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
+        DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
     }
 
     print("\n📋 Model input order:")
@@ -252,8 +263,8 @@ def execute_horizon(
 
     # Initialize state
     data_manager = DataManager()
-    data_manager.set_target_gripper_open_value(1.0)
     policy_state = PolicyState()
+    policy_state.set_execution_ratio(args.execution_ratio)
 
     # Initialize robot controller
     print("\n🤖 Initializing robot controller...")
@@ -318,7 +329,9 @@ def execute_horizon(
             continue
 
         # Execute horizon
-        execute_horizon(data_manager, policy_state, robot_controller)
+        execute_horizon(
+            data_manager, policy_state, robot_controller, args.frequency
+        )
 
     except KeyboardInterrupt:
         print("\n\n👋 Interrupt received - shutting down...")