Skip to content

Commit 3c9890e

Browse files
author
Your Name
committed
frame by frame
1 parent 06ffe9e commit 3c9890e

File tree

2 files changed

+52
-48
lines changed

2 files changed

+52
-48
lines changed

examples/droid_h5/simple_vlm_processing.py

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -318,56 +318,48 @@ def process_single_trajectory(
318318
else:
319319
selected_images = list(images)
320320

321-
# Create image grid for VLM analysis
322-
if num_frames_to_use <= 4:
323-
# Create 2x2 grid
324-
rows = 2
325-
cols = 2
326-
# Pad with copies if needed
327-
while len(selected_images) < 4:
328-
selected_images.append(selected_images[-1])
329-
else:
330-
# Create 2x3 grid
331-
rows = 2
332-
cols = 3
333-
# Pad with copies if needed
334-
while len(selected_images) < 6:
335-
selected_images.append(selected_images[-1])
336-
337-
resized_images = []
321+
# Prepare individual frames for VLM analysis
322+
processed_frames = []
338323
for img in selected_images:
339324
if len(img.shape) == 3: # RGB image
340-
# resized = cv2.resize(img, (target_width, target_height))
341-
resized_images.append(img)
325+
processed_frames.append(img)
342326
else:
343-
# Handle grayscale or other formats
344-
resized_images.append(np.zeros((target_height, target_width, 3), dtype=np.uint8))
345-
346-
# Create grid
347-
grid_rows = []
348-
for r in range(rows):
349-
row_images = resized_images[r * cols:(r + 1) * cols]
350-
grid_row = np.hstack(row_images)
351-
grid_rows.append(grid_row)
352-
353-
grid_image = np.vstack(grid_rows)
327+
# Handle grayscale or other formats - convert to RGB
328+
if len(img.shape) == 2: # Grayscale
329+
rgb_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
330+
processed_frames.append(rgb_img)
331+
else:
332+
# Default fallback
333+
processed_frames.append(np.zeros((480, 640, 3), dtype=np.uint8))
354334

355335
# Initialize VLM tools
356336
tools_manager = ToolsManager(config=tools_config)
357337

358338
# Get the VLM tool
359339
vlm_tool = tools_manager.get_tool("robo2vlm")
360340

361-
# Prepare VLM prompt aligned with droid_vlm_demo.py
341+
# Prepare VLM prompt for frame-by-frame analysis
362342
context = f"\nLanguage instruction: '{language_instruction}'" if language_instruction else ""
363343
traj_name = os.path.splitext(os.path.basename(trajectory_path))[0]
364344

345+
# Process frames individually and collect responses
346+
frame_responses = []
347+
for i, frame in enumerate(processed_frames):
348+
frame_prompt = f"""This is frame {i+1} of {len(processed_frames)} from a robot trajectory. Analyze what the robot is doing in this frame.{context}"""
349+
frame_response = vlm_tool(frame, frame_prompt)
350+
frame_responses.append(frame_response)
351+
print(f" 📸 Frame {i+1}/{len(processed_frames)} analyzed")
352+
353+
# Final analysis prompt combining all frame insights
354+
combined_analysis = "\n".join([f"Frame {i+1}: {resp}" for i, resp in enumerate(frame_responses)])
355+
final_prompt = f"""Based on the analysis of {len(processed_frames)} individual frames from this robot trajectory, does this trajectory look successful? First answer yes or no, then explain why.
365356
366-
# Align with droid_vlm_demo.py pattern for image analysis
367-
full_prompt = f"""These are {num_frames_to_use} frames from a robot trajectory. Does this trajectory look successful? First answer yes or no, then explain why.{context}"""
368-
369-
# Call VLM
370-
vlm_response = vlm_tool(grid_image, full_prompt)
357+
Frame-by-frame analysis:
358+
{combined_analysis}
359+
{context}"""
360+
361+
# Use the first frame for the final analysis call (the actual analysis is in the prompt)
362+
vlm_response = vlm_tool(processed_frames[0], final_prompt)
371363

372364
# Extract success prediction from VLM response (aligned with droid_vlm_demo.py)
373365
response_lower = vlm_response.lower()
@@ -390,24 +382,29 @@ def process_single_trajectory(
390382
os.makedirs(output_dir, exist_ok=True)
391383
results_dir = Path(output_dir)
392384

393-
# Save input image
394-
image_filename = results_dir / f"{traj_name}_input.jpg"
395-
cv2.imwrite(str(image_filename), cv2.cvtColor(grid_image, cv2.COLOR_RGB2BGR))
385+
# Save individual frames
386+
for i, frame in enumerate(processed_frames):
387+
frame_filename = results_dir / f"{traj_name}_frame_{i+1}.jpg"
388+
cv2.imwrite(str(frame_filename), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
396389

397390
# Save detailed results
398391
results_filename = results_dir / f"{traj_name}_results.txt"
399392
with open(results_filename, 'w') as f:
400-
f.write(f"VLM Processing Results\n")
401-
f.write(f"===================\n")
393+
f.write(f"VLM Processing Results (Frame-by-Frame)\n")
394+
f.write(f"======================================\n")
402395
f.write(f"Trajectory: {traj_name}\n")
403396
f.write(f"File path: {trajectory_path}\n")
404397
f.write(f"VLM prediction (success): {vlm_prediction}\n")
405398
f.write(f"Language instruction: {language_instruction or 'N/A'}\n")
406399
f.write(f"Frames analyzed: {num_frames_to_use}/{len(images)}\n")
407400
f.write(f"Used state visualization: {use_state_visualization}\n")
408-
f.write(f"\nVLM Prompt:\n{full_prompt}\n")
409-
f.write(f"\nVLM Response:\n{vlm_response}\n")
410-
f.write(f"\nInput image saved as: {traj_name}_input.jpg\n")
401+
f.write(f"\n--- Frame-by-Frame Analysis ---\n")
402+
for i, frame_resp in enumerate(frame_responses):
403+
f.write(f"\nFrame {i+1} Analysis:\n{frame_resp}\n")
404+
f.write(f"\n--- Final Analysis ---\n")
405+
f.write(f"Final Prompt:\n{final_prompt}\n")
406+
f.write(f"\nFinal VLM Response:\n{vlm_response}\n")
407+
f.write(f"\nFrames saved as: {traj_name}_frame_1.jpg to {traj_name}_frame_{len(processed_frames)}.jpg\n")
411408

412409
return {
413410
"trajectory_path": trajectory_path,
@@ -418,7 +415,9 @@ def process_single_trajectory(
418415
"language_instruction": language_instruction,
419416
"frames_analyzed": num_frames_to_use,
420417
"total_frames": len(images),
421-
"used_state_visualization": use_state_visualization
418+
"used_state_visualization": use_state_visualization,
419+
"frame_responses": frame_responses,
420+
"processing_method": "frame_by_frame"
422421
}
423422

424423
except Exception as e:
@@ -699,7 +698,7 @@ def main():
699698
if args.output_dir:
700699
print(f"\n📁 Detailed results saved to: {args.output_dir}/")
701700
print(f" - Individual result files: *_results.txt")
702-
print(f" - Input images: *_input.jpg")
701+
print(f" - Individual frame images: *_frame_N.jpg")
703702
print(f" - Processing summary: processing_summary.txt")
704703

705704
return 0

examples/droid_h5/validate_vlm_responses.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,10 +260,11 @@ def validate_vlm_responses(
260260
# Process each result
261261
validated_results = []
262262
skipped_count = 0
263+
failed_processing_count = 0
263264

264265
for trajectory_path, result in results.items():
265266
if not result["success"]:
266-
skipped_count += 1
267+
failed_processing_count += 1
267268
continue
268269

269270
# Extract ground truth
@@ -316,12 +317,14 @@ def validate_vlm_responses(
316317
})
317318

318319
print(f"✅ Validated: {len(validated_results)}")
319-
print(f"⏩ Skipped: {skipped_count}")
320+
print(f"❌ Failed processing: {failed_processing_count}")
321+
print(f"⏩ Skipped (no ground truth): {skipped_count}")
320322

321323
if len(validated_results) == 0:
322324
return {
323325
"error": "No valid comparisons found",
324326
"total_processed": len(results),
327+
"failed_processing": failed_processing_count,
325328
"skipped": skipped_count
326329
}
327330

@@ -333,6 +336,7 @@ def validate_vlm_responses(
333336
return {
334337
"total_processed": len(results),
335338
"validated": len(validated_results),
339+
"failed_processing": failed_processing_count,
336340
"skipped": skipped_count,
337341
"metrics": metrics,
338342
"detailed_results": validated_results
@@ -434,6 +438,7 @@ def main():
434438
print("=" * 50)
435439
print(f"Total trajectories: {validation_results['total_processed']}")
436440
print(f"Successfully validated: {validation_results['validated']}")
441+
print(f"Failed processing: {validation_results['failed_processing']}")
437442
print(f"Skipped (no ground truth or prediction): {validation_results['skipped']}")
438443

439444
print(f"\n🎯 Accuracy Metrics:")

0 commit comments

Comments
 (0)