Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions prismatic/extern/hf/modeling_prismatic.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,7 @@ def _regression_or_discrete_prediction(
multi_layer_hidden_states = []

for item in language_model_output.hidden_states[0:]:
item = item[:, 1:-1, :] # remove bos and eos token first
# last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
# Get hidden states for text portion of prompt+response (after the vision patches)
text_hidden_states = item
Expand Down
33 changes: 17 additions & 16 deletions vla-scripts/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,10 @@ def run_forward_pass(
multi_layer_hidden_states = []

for item in output.hidden_states[0:]:
item = item[:, 1:-1, :] # remove bos token and eos token
# last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
# Get hidden states for text portion of prompt+response (after the vision patches)
text_hidden_states = item[:, num_patches:-1]
text_hidden_states = item[:, num_patches:, :]
# Get hidden states for action portion of response
batch_size = batch["input_ids"].shape[0]
# actions_hidden_states = text_hidden_states[:, -1, :].reshape(batch_size, 1, -1).to(torch.bfloat16)
Expand Down Expand Up @@ -1081,21 +1082,21 @@ def rename_state_dict_keys(state_dict, replace_map):
optimizer.zero_grad()
progress.update()

# Save model checkpoint: either keep latest checkpoint only or all checkpoints
if gradient_step_idx > 0 and log_step % cfg.save_freq == 0:
save_training_checkpoint(
cfg=cfg,
run_dir=run_dir,
log_step=log_step,
vla=vla,
processor=processor,
proprio_projector=proprio_projector if cfg.use_proprio else None,
noisy_action_projector=None,
action_head=action_head,
train_dataset=train_dataset,
distributed_state=distributed_state,
new_state_dict=RAW_STATE_DICT,
)
# Save model checkpoint: either keep latest checkpoint only or all checkpoints
if gradient_step_idx > 0 and log_step % cfg.save_freq == 0:
save_training_checkpoint(
cfg=cfg,
run_dir=run_dir,
log_step=log_step,
vla=vla,
processor=processor,
proprio_projector=proprio_projector if cfg.use_proprio else None,
noisy_action_projector=None,
action_head=action_head,
train_dataset=train_dataset,
distributed_state=distributed_state,
new_state_dict=RAW_STATE_DICT,
)

# Test model on validation set
if cfg.use_val_set and log_step > 0 and log_step % cfg.val_freq == 0:
Expand Down