Skip to content

Commit e78a921

Browse files
committed
at least it runs
1 parent 15326f3 commit e78a921

File tree

7 files changed

+128
-370
lines changed

7 files changed

+128
-370
lines changed

examples/droid/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ python droid_to_robodm.py
3838

3939
- gsutil (for downloading from Google Cloud Storage)
4040
- RoboDM with vision tools enabled
41-
- VLM model (qwen2.5-7b by default)
41+
- VLM model (Llama-3.2-11B-Vision-Instruct by default)
4242

4343
## Sample Output
4444

examples/droid/droid_vlm_demo.py

Lines changed: 90 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""
2-
Demo script using robo2vlm tool to classify DROID trajectories as successful or failed.
2+
Demo script using Llama 3.2-Vision model to classify DROID trajectories as successful or failed.
33
44
This script:
55
1. Downloads sample DROID trajectories (both success and failure)
66
2. Converts them to RoboDM format
7-
3. Uses the robo2vlm vision-language model to analyze trajectories
7+
3. Uses the Llama 3.2-Vision model to analyze trajectories
88
4. Demonstrates how to detect success/failure patterns
99
"""
1010

@@ -13,26 +13,102 @@
1313
from typing import Dict, List, Tuple
1414

1515
import numpy as np
16+
import torch
17+
from PIL import Image
18+
from transformers import MllamaForConditionalGeneration, AutoProcessor
1619
from download_droid import DROIDDownloader
1720
from droid_to_robodm import DROIDToRoboDMConverter
1821

1922
import robodm
20-
from robodm.agent.tools import ToolsManager, create_vision_config
2123

2224

2325
class DROIDSuccessDetector:
24-
"""Detect success/failure in DROID trajectories using VLM."""
26+
"""Detect success/failure in DROID trajectories using Llama 3.2-Vision."""
2527

2628
def __init__(self, model_name: str = "meta-llama/Llama-3.2-11B-Vision-Instruct"):
    """Load the Llama 3.2-Vision model and its processor.

    Args:
        model_name: Hugging Face model id to load. Defaults to the
            11B instruct-tuned vision model used by this demo; made a
            parameter so callers can swap in a smaller/other checkpoint.
    """
    print("Loading Llama 3.2-Vision model...")
    self.model_name = model_name

    # bfloat16 + device_map="auto" lets transformers place/shard the
    # weights on whatever accelerators are available.
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # checkpoint. It is not required for official meta-llama repos and is
    # a security risk if an untrusted model id is passed — consider
    # removing once confirmed unnecessary.
    self.model = MllamaForConditionalGeneration.from_pretrained(
        self.model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    self.processor = AutoProcessor.from_pretrained(
        self.model_name,
        trust_remote_code=True,
    )

    print("Model loaded successfully!")
def analyze_frame_with_llama_vision(self, image: np.ndarray, prompt: str) -> str:
    """
    Analyze a single frame using Llama 3.2-Vision.

    Args:
        image: Frame as a numpy array (H, W, C). Non-uint8 arrays are
            rescaled by 255 before conversion.
        prompt: Text prompt for analysis.

    Returns:
        The model's text response (stripped), or "Error" if analysis failed.
    """
    try:
        # Convert numpy array to PIL Image.
        # NOTE(review): assumes non-uint8 frames are floats normalized to
        # [0, 1] — confirm against the trajectory loader.
        if image.dtype != np.uint8:
            image = (image * 255).astype(np.uint8)
        pil_image = Image.fromarray(image)

        # Llama 3.2-Vision chat format: an image placeholder followed by
        # the text prompt in a single user turn.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt}
                ]
            }
        ]

        text = self.processor.apply_chat_template(
            messages, add_generation_prompt=True
        )

        inputs = self.processor(
            images=[pil_image],
            text=text,
            return_tensors="pt"
        ).to(self.model.device)

        # Greedy decoding for reproducible answers.
        # BUGFIX: dropped temperature=0.1 — transformers ignores (and warns
        # about) sampling parameters when do_sample=False.
        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=False
            )

        # Decode only the newly generated tokens (skip the prompt tokens).
        generated_ids = output[0][inputs.input_ids.shape[1]:]
        response = self.processor.decode(generated_ids, skip_special_tokens=True)

        print(f"Response: {response.strip()}")
        return response.strip()

    except Exception as e:
        # Best-effort: keep the demo running even if a single frame fails;
        # caller treats the literal "Error" as a failed analysis.
        print(f"Error analyzing frame: {e}")
        return "Error"
30106

31107
def analyze_trajectory_frames(self,
32108
trajectory_path: str,
33109
sample_rate: int = 10) -> Dict:
34110
"""
35-
Analyze frames from a trajectory using VLM.
111+
Analyze frames from a trajectory using Llama 3.2-Vision.
36112
37113
Args:
38114
trajectory_path: Path to RoboDM trajectory file
@@ -81,14 +157,8 @@ def analyze_trajectory_frames(self,
81157
frame_analysis = {"frame_idx": idx, "analyses": {}}
82158

83159
for prompt in prompts:
84-
try:
85-
response = self.vlm_tool(frame, prompt)
86-
frame_analysis["analyses"][prompt] = response
87-
except Exception as e:
88-
print(
89-
f"Error analyzing frame {idx} with prompt '{prompt}': {e}"
90-
)
91-
frame_analysis["analyses"][prompt] = "Error"
160+
response = self.analyze_frame_with_llama_vision(frame, prompt)
161+
frame_analysis["analyses"][prompt] = response
92162

93163
results["frame_analyses"].append(frame_analysis)
94164

@@ -269,8 +339,8 @@ def main():
269339
else:
270340
print(f"Using existing RoboDM trajectories in {robodm_dir}")
271341

272-
# Step 3: Analyze trajectories with VLM
273-
print("\n3. Analyzing trajectories with robo2vlm...")
342+
# Step 3: Analyze trajectories with Llama 3.2-Vision
343+
print("\n3. Analyzing trajectories with Llama 3.2-Vision...")
274344
detector = DROIDSuccessDetector()
275345

276346
# Get converted trajectory paths
@@ -285,11 +355,11 @@ def main():
285355

286356
print("\n" + "=" * 60)
287357
print(
288-
"Demo complete! The robo2vlm tool successfully analyzed DROID trajectories."
358+
"Demo complete! The Llama 3.2-Vision model successfully analyzed DROID trajectories."
289359
)
290360
print("\nKey insights:")
291361
print(
292-
"- VLM can detect task completion indicators in robotic trajectories")
362+
"- Llama 3.2-Vision can detect task completion indicators in robotic trajectories")
293363
print("- Success/failure patterns can be identified from visual analysis")
294364
print("- Frame-by-frame analysis provides detailed task understanding")
295365

0 commit comments

Comments
 (0)