9 changes: 9 additions & 0 deletions README.md
@@ -205,6 +205,15 @@ The downloaded dataset can be placed in the `/data` folder. The overall director
<br/>
<br/>

### :floppy_disk: Custom Data Preparation

If you want to fine-tune on your own dataset (non-RLDS), you can modify `vla-scripts/finetune.py`.
We provide a commented block in `finetune.py` (around line 932) that demonstrates how to swap the RLDS dataset for a standard PyTorch `Dataset`.
You will need to implement a dataset class that returns the appropriate dictionary format (`input_ids`, `pixel_values`, `labels`, etc.), as in the sketch below.
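
Below is a minimal, hypothetical sketch of such a dataset class. The class name, prompt format, and the `processor`/`action_tokenizer` interfaces are illustrative assumptions, not the exact code in `finetune.py`; match the returned keys to whatever the collator in `finetune.py` expects.

```python
# Illustrative sketch only -- adapt the keys and preprocessing to the collator used in finetune.py.
from torch.utils.data import Dataset


class CustomVLADataset(Dataset):
    """Hypothetical dataset yielding (image, instruction, action) samples for fine-tuning."""

    def __init__(self, samples, processor, action_tokenizer):
        self.samples = samples                    # your own list of (image, instruction, action) records
        self.processor = processor                # VLM image/text processor (assumed interface)
        self.action_tokenizer = action_tokenizer  # maps continuous actions to label token ids (assumed)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image, instruction, action = self.samples[idx]
        prompt = f"What action should the robot take to {instruction.lower()}?"
        inputs = self.processor(text=prompt, images=image, return_tensors="pt")
        labels = self.action_tokenizer(action)    # token ids supervising the action prediction
        return dict(
            input_ids=inputs["input_ids"][0],
            attention_mask=inputs["attention_mask"][0],
            pixel_values=inputs["pixel_values"][0],
            labels=labels,
        )
```

A dataset like this can then be passed to the same `DataLoader`/collator setup that `finetune.py` already uses, in place of the RLDS dataset.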

<br/>
<br/>

## ⚓ VLM backbone <a name="vlm"></a>
We use the `Prismatic-VLMs` architecture. Since the file is large, please download it from [here](https://huggingface.co/Stanford-ILIAD/prism-qwen25-extra-dinosiglip-224px-0_5b). Then put it in the `/pretrained_models` folder. The file structure is:

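For reference, one way to fetch the backbone programmatically is with `huggingface_hub`; this is a sketch, and the `local_dir` folder name is an assumption that should be adjusted to match the file structure above.

```python
# Sketch: download the Prismatic VLM backbone (target folder name below is an assumption).
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Stanford-ILIAD/prism-qwen25-extra-dinosiglip-224px-0_5b",
    local_dir="pretrained_models/prism-qwen25-extra-dinosiglip-224px-0_5b",
)
```
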
8 changes: 7 additions & 1 deletion experiments/robot/libero/run_libero_eval.py
@@ -17,7 +17,13 @@
import draccus
import numpy as np
import tqdm
from libero.libero import benchmark
try:
    from libero.libero import benchmark
except ImportError:
    print("ERROR: Failed to import 'libero'. Please ensure you have installed the LIBERO benchmark dependencies.")
    print("See README.md for installation instructions: https://github.com/Lifelong-Robot-Learning/LIBERO")
    print("Run: pip install -e LIBERO and pip install -r experiments/robot/libero/libero_requirements.txt")
    sys.exit(1)

import wandb

1 change: 1 addition & 0 deletions prismatic/vla/datasets/rlds/dataset.py
@@ -563,6 +563,7 @@ def make_interleaved_dataset(
# Validation =>> fix a single shuffle buffer of data and cache it in RAM; prevents gradual memory increase!
if not train:
    dataset = dataset.take(shuffle_buffer_size).cache()
    # Keep the reported dataset length consistent with the truncated, cached validation set
    dataset_len = min(dataset_len, shuffle_buffer_size)

# Shuffle the Dataset
# =>> IMPORTANT :: Shuffle AFTER .cache(), or else memory will still leak!
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -44,9 +44,9 @@ dependencies = [
"sentencepiece==0.1.99",
"timm==0.9.10",
"tokenizers==0.19.1",
"torch==2.2.0",
"torchvision==0.17.0",
"torchaudio==2.2.0",
"torch>=2.2.0",
"torchvision>=0.17.0",
"torchaudio>=2.2.0",
"transformers @ git+https://github.com/moojink/transformers-openvla-oft.git", # IMPORTANT: Use this fork for bidirectional attn (for parallel decoding)
"wandb",
"tensorflow==2.15.0",
4 changes: 3 additions & 1 deletion vla-scripts/finetune.py
@@ -88,7 +88,7 @@ class FinetuneConfig:

# Training configuration
batch_size: int = 8 # Batch size per device (total batch size = batch_size * num GPUs)
learning_rate: float = 5e-4 # Learning rate
learning_rate: float = 2e-4 # Learning rate
lr_warmup_steps: int = 0.1 # Number of steps to warm up learning rate (from 10% to 100%)
num_steps_before_decay: int = 100000 # Number of steps before LR decays by 10x
grad_accumulation_steps: int = 1 # Number of gradient accumulation steps
@@ -991,6 +991,7 @@ def rename_state_dict_keys(state_dict, replace_map):
sampler=None,
collate_fn=collator,
num_workers=0, # Important: Set to 0 if using RLDS, which uses its own parallelism
pin_memory=True, # Pin host memory to speed up CPU-to-GPU transfers
)
print('Len of dataloader: ', len(dataloader))
if cfg.use_val_set:
@@ -1001,6 +1002,7 @@
sampler=None,
collate_fn=collator,
num_workers=0, # Important: Set to 0 if using RLDS, which uses its own parallelism
pin_memory=True, # Pin host memory to speed up CPU-to-GPU transfers
)

# Deque to store recent train metrics (used for computing smoothened metrics for gradient accumulation)
Expand Down