"""
Demo script using the Llama 3.2-Vision model to classify DROID trajectories as successful or failed.

This script:
1. Downloads sample DROID trajectories (both success and failure)
2. Converts them to RoboDM format
3. Uses the Llama 3.2-Vision model to analyze trajectories
4. Demonstrates how to detect success/failure patterns
"""
1010
from typing import Dict, List, Tuple

import numpy as np
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
from download_droid import DROIDDownloader
from droid_to_robodm import DROIDToRoboDMConverter

import robodm
2224
2325class DROIDSuccessDetector :
    """Detect success/failure in DROID trajectories using Llama 3.2-Vision."""
2527
2628 def __init__ (self ):
27- # Initialize tools manager with vision config
28- self .manager = ToolsManager (config = create_vision_config ())
29- self .vlm_tool = self .manager .get_tool ("robo2vlm" )
29+ # Initialize Llama 3.2-Vision model directly
30+ print ("Loading Llama 3.2-Vision model..." )
31+ self .model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
32+
33+ # Load model and processor
34+ self .model = MllamaForConditionalGeneration .from_pretrained (
35+ self .model_name ,
36+ torch_dtype = torch .bfloat16 ,
37+ device_map = "auto" ,
38+ trust_remote_code = True
39+ )
40+
41+ self .processor = AutoProcessor .from_pretrained (
42+ self .model_name ,
43+ trust_remote_code = True
44+ )
45+
46+ print ("Model loaded successfully!" )
47+
48+ def analyze_frame_with_llama_vision (self , image : np .ndarray , prompt : str ) -> str :
49+ """
50+ Analyze a single frame using Llama 3.2-Vision.
51+
52+ Args:
53+ image: Frame as numpy array (H, W, C)
54+ prompt: Text prompt for analysis
55+
56+ Returns:
57+ Model response
58+ """
59+ try :
60+ # Convert numpy array to PIL Image
61+ if image .dtype != np .uint8 :
62+ image = (image * 255 ).astype (np .uint8 )
63+ pil_image = Image .fromarray (image )
64+
65+ # Create conversation format for Llama 3.2-Vision
66+ messages = [
67+ {
68+ "role" : "user" ,
69+ "content" : [
70+ {"type" : "image" },
71+ {"type" : "text" , "text" : prompt }
72+ ]
73+ }
74+ ]
75+
76+ # Process inputs
77+ text = self .processor .apply_chat_template (
78+ messages , add_generation_prompt = True
79+ )
80+
81+ inputs = self .processor (
82+ images = [pil_image ],
83+ text = text ,
84+ return_tensors = "pt"
85+ ).to (self .model .device )
86+
87+ # Generate response
88+ with torch .no_grad ():
89+ output = self .model .generate (
90+ ** inputs ,
91+ max_new_tokens = 100 ,
92+ do_sample = False ,
93+ temperature = 0.1
94+ )
95+
96+ # Decode response (skip the input tokens)
97+ generated_ids = output [0 ][inputs .input_ids .shape [1 ]:]
98+ response = self .processor .decode (generated_ids , skip_special_tokens = True )
99+
100+ print (f"Response: { response .strip ()} " )
101+ return response .strip ()
102+
103+ except Exception as e :
104+ print (f"Error analyzing frame: { e } " )
105+ return "Error"
30106
31107 def analyze_trajectory_frames (self ,
32108 trajectory_path : str ,
33109 sample_rate : int = 10 ) -> Dict :
34110 """
35- Analyze frames from a trajectory using VLM .
111+ Analyze frames from a trajectory using Llama 3.2-Vision .
36112
37113 Args:
38114 trajectory_path: Path to RoboDM trajectory file
@@ -81,14 +157,8 @@ def analyze_trajectory_frames(self,
81157 frame_analysis = {"frame_idx" : idx , "analyses" : {}}
82158
83159 for prompt in prompts :
84- try :
85- response = self .vlm_tool (frame , prompt )
86- frame_analysis ["analyses" ][prompt ] = response
87- except Exception as e :
88- print (
89- f"Error analyzing frame { idx } with prompt '{ prompt } ': { e } "
90- )
91- frame_analysis ["analyses" ][prompt ] = "Error"
160+ response = self .analyze_frame_with_llama_vision (frame , prompt )
161+ frame_analysis ["analyses" ][prompt ] = response
92162
93163 results ["frame_analyses" ].append (frame_analysis )
94164
@@ -269,8 +339,8 @@ def main():
269339 else :
270340 print (f"Using existing RoboDM trajectories in { robodm_dir } " )
271341
272- # Step 3: Analyze trajectories with VLM
273- print ("\n 3. Analyzing trajectories with robo2vlm ..." )
342+ # Step 3: Analyze trajectories with Llama 3.2-Vision
343+ print ("\n 3. Analyzing trajectories with Llama 3.2-Vision ..." )
274344 detector = DROIDSuccessDetector ()
275345
276346 # Get converted trajectory paths
@@ -285,11 +355,11 @@ def main():
285355
286356 print ("\n " + "=" * 60 )
287357 print (
288- "Demo complete! The robo2vlm tool successfully analyzed DROID trajectories."
358+ "Demo complete! The Llama 3.2-Vision model successfully analyzed DROID trajectories."
289359 )
290360 print ("\n Key insights:" )
291361 print (
292- "- VLM can detect task completion indicators in robotic trajectories" )
362+ "- Llama 3.2-Vision can detect task completion indicators in robotic trajectories" )
293363 print ("- Success/failure patterns can be identified from visual analysis" )
294364 print ("- Frame-by-frame analysis provides detailed task understanding" )
295365
0 commit comments