Fix Energon Support in Qwen3-VL
Summary
The current Qwen3-VL implementation has incomplete energon dataloader support that prevents proper distributed training with the energon data pipeline. This issue tracks the work needed to fix energon compatibility, particularly around multimodal data handling, parallelism modes (TP, SP, CP), and data distribution across DP ranks.
Problem Description
1. Task Encoder Limitations
The current QwenVLTaskEncoder (src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py) has several gaps compared to the reference implementation in Pai-Megatron-Patch:
Missing Features:
- Video Decoding: No proper videohandler for decoding video data from webdataset format
- Audio Support: Missing audiohandler for audio modality (needed for omni-modal support)
- Robust Sample Format: The current ChatMLSample doesn't fully leverage energon's DefaultDecoderWebdatasetFactory
Current Implementation Issues:
# Current cook_chatml_sample only handles basic pickle deserialization
def cook_chatml_sample(sample: dict) -> ChatMLSample:
    imgs = sample.get("jpgs", None)
    if imgs:
        imgs = pickle.loads(imgs)
        # ...
    videos = sample.get("videos", None)
    if videos:
        videos = pickle.loads(videos)
        # ...

Reference Implementation (Pai-Megatron-Patch):
# Uses proper decoder handlers for different media types
class ChatMLWebdataset(DefaultDecoderWebdatasetFactory[ChatMLSample]):
    __sample_type__ = ChatMLSample

    def __init__(self, path: EPath, *, auto_decode: bool = True, **kwargs):
        super().__init__(path, auto_decode=auto_decode, **kwargs)
        if auto_decode:
            self._decoder = Decoder(
                [
                    imagehandler(self.image_decode),
                    audiohandler(),
                    videohandler(self.image_decode),
                ]
            )

2. Parallelism Issues (TP, SP, CP with Sequence Packing)
Context Parallelism (CP) Not Supported:
The split_deepstack_embs function in src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py explicitly blocks CP:
def split_deepstack_embs(..., cp_size: int = 1, cp_rank: int = 0):
    # first split by cp (zigzag)
    assert cp_size == 1 and cp_rank == 0, "no support cp now"  # <-- BLOCKING

Tensor Parallelism with Visual Embeddings:
When sequence_parallel=True, embeddings are scattered after vision-text merge:
# In Qwen3VLModel.forward()
if self.config.sequence_parallel:
    combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings)

However, this may cause issues when visual embeddings need to be split across TP ranks with sequence packing.
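For context, the SP scatter requires the merged sequence length to be divisible by the TP size. A minimal sketch of the kind of padding needed before the scatter (the helper name and the [seq, batch, hidden] layout are assumptions, not current code):

# Sketch only: pad the merged vision-text embeddings on the sequence dim so it
# divides evenly by the TP size before scatter_to_sequence_parallel_region.
import torch
import torch.nn.functional as F

def pad_for_sequence_parallel(combined_embeddings: torch.Tensor, tp_size: int):
    seq_len = combined_embeddings.size(0)
    remainder = seq_len % tp_size
    if remainder == 0:
        return combined_embeddings, 0
    pad = tp_size - remainder
    # F.pad pads from the last dim backwards: (hidden, batch, seq)
    padded = F.pad(combined_embeddings, (0, 0, 0, 0, 0, pad))
    return padded, pad  # caller must also extend labels / loss mask by `pad`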
Sequence Packing + CP Issues in qwen3vl_step.py:
# Current implementation forces bshd format for Qwen3-VL
if pack_sequences_in_batch:
    data_format = "thd"
    if is_qwen3vl:
        data_format = "bshd"  # Forces bshd even with packing

This workaround may not properly handle CP slicing of packed sequences.
3. Data Distribution Across DP Ranks
Current Implementation in base_energon_datamodule.py:
def train_dataloader(self) -> Any:
    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()
    data_parallel_group = parallel_state.get_data_parallel_group()
    worker_config = WorkerConfig(
        rank=rank,
        world_size=world_size,
        num_workers=self.num_workers,
        data_parallel_group=data_parallel_group,
        # ...
    )

Required Behavior:
- ✅ Different DP ranks should receive different data (sharded by energon WorkerConfig)
- ✅ The same DP rank should receive the same data across runs with the same seed (deterministic with seed + rank)
- ⚠️ Need to verify: when CP > 1, ranks in the same DP replica but on different CP ranks should see the SAME data (they process different sequence portions of the same batch)
Potential Issue:
The current implementation uses get_data_parallel_rank() which may not account for CP correctly. With CP, the "effective" DP group for data loading should consider CP ranks within the same DP replica.
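One thing to check (hedged, since the exact Megatron-Core signature should be confirmed): the parallel_state helpers accept a with_context_parallel flag, and the dataloader should use the DP rank that excludes CP so that all CP ranks of one replica land on the same shard:

# Assumption: megatron.core.parallel_state exposes with_context_parallel on these
# helpers; if not, the equivalent "DP without CP" rank must be derived manually.
from megatron.core import parallel_state

dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=False)
dp_world_size = parallel_state.get_data_parallel_world_size(with_context_parallel=False)
# These are the values WorkerConfig(rank=..., world_size=...) should receive.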
Proposed Solution
Phase 1: Update Task Encoder
- Add Video Handler
class videohandler:
    def __init__(self, imagespec):
        self.extensions = ['jpgs', 'mp4s']
        self.extensions_mapping = {"jpgs": "jpg", "mp4s": "jpg"}
        self.image_handler = imagehandler(imagespec)

    def __call__(self, key, data):
        extension = re.sub(r".*[.]", "", key)
        if extension.lower() not in self.extensions:
            return None
        data = pickle.loads(data)
        key = self.extensions_mapping[extension]
        if extension.lower() == 'jpgs':
            data = [self.image_handler(key, d) for d in data]
        else:
            data = [[self.image_handler(key, d) for d in video] for video in data]
        return data

- Add Audio Handler
class audiohandler:
    def __init__(self):
        self.extensions = ['wavs', 'mp3s']

    def __call__(self, key, data):
        extension = re.sub(r".*[.]", "", key)
        if extension not in self.extensions:
            return None
        data_list = pickle.loads(data)
        audio_list = []
        for audio_bytes in data_list:
            audio_list.append(torchaudio.load(io.BytesIO(audio_bytes)))
        return audio_list

- Create ChatMLWebdataset factory
class ChatMLWebdataset(DefaultDecoderWebdatasetFactory[ChatMLSample]):
    __sample_type__ = ChatMLSample

    def __init__(self, path: EPath, *, auto_decode: bool = True, **kwargs):
        super().__init__(path, auto_decode=auto_decode, **kwargs)
        if auto_decode:
            self._decoder = Decoder([
                imagehandler(self.image_decode),
                audiohandler(),
                videohandler(self.image_decode),
            ])

Phase 2: Fix Parallelism Support
- Enable CP in split_deepstack_embs (see the sketch after this list)
  - Implement zigzag splitting pattern for visual embeddings across CP ranks
  - Ensure visual token positions are consistent across CP ranks
- Fix Sequence Packing + CP
  - Properly partition packed sequences for CP using thd_get_partitioned_indices
  - Ensure cu_seqlens are correctly adjusted for each CP rank
- Verify TP + SP with Visual Embeddings
  - Ensure visual embeddings are correctly scattered when SP is enabled
  - Handle variable-length visual token spans across TP shards
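A hedged sketch of what lifting the assert might look like, reusing the illustrative zigzag_slice_packed helper sketched in section 2; the real split_deepstack_embs has more arguments (elided above as "..."), so the parameter list below is an assumption, not the actual signature:

# Sketch only: dispatch on cp_size instead of asserting it away.
def split_deepstack_embs(visual_embs, cu_seqlens, cp_size: int = 1, cp_rank: int = 0):
    if cp_size == 1:
        return visual_embs, cu_seqlens
    # Split the visual embeddings the same way the text tokens are split, so the
    # deepstack features stay aligned with their token positions on each rank.
    return zigzag_slice_packed(visual_embs, cu_seqlens, cp_size, cp_rank)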
Phase 3: Fix DP Rank Data Distribution
- Update WorkerConfig for CP Awareness
def train_dataloader(self) -> Any:
    # For data loading, we want the same data across CP ranks within a DP group,
    # i.e. use the DP rank (not a combined DP x CP rank) when CP > 1.
    if parallel_state.get_context_parallel_world_size() > 1:
        # All CP ranks in the same DP group should see the same data
        rank = parallel_state.get_data_parallel_rank()  # DP rank only
        world_size = parallel_state.get_data_parallel_world_size()
    else:
        rank = parallel_state.get_data_parallel_rank()
        world_size = parallel_state.get_data_parallel_world_size()
    # ...

- Add Verification Tests (see the sketch after this list)
- Test that DP rank 0 and DP rank 1 receive different batches
- Test that CP rank 0 and CP rank 1 within same DP group receive same input_ids
- Test reproducibility with same seed
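A minimal sketch of such a check, assuming Megatron-Core's parallel_state groups and torch.distributed object gathers; the helper names are illustrative:

# Illustrative check: hash the first batch on every rank, gather the hashes,
# and assert different data across DP ranks but identical data within a CP group.
import hashlib
import torch.distributed as dist
from megatron.core import parallel_state

def batch_fingerprint(batch: dict) -> str:
    return hashlib.sha256(batch["input_ids"].cpu().numpy().tobytes()).hexdigest()

def check_sharding(batch: dict):
    fp = batch_fingerprint(batch)
    # same fingerprint expected within a CP group (same DP replica)
    cp_group = parallel_state.get_context_parallel_group()
    cp_fps = [None] * dist.get_world_size(cp_group)
    dist.all_gather_object(cp_fps, fp, group=cp_group)
    assert len(set(cp_fps)) == 1, "CP ranks in one DP replica saw different data"
    # different fingerprints expected across DP ranks
    dp_group = parallel_state.get_data_parallel_group()
    dp_fps = [None] * dist.get_world_size(dp_group)
    dist.all_gather_object(dp_fps, fp, group=dp_group)
    assert len(set(dp_fps)) == len(dp_fps), "some DP ranks saw identical data"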
Testing Plan
- Unit Tests
  - Test video/audio handler decoding (see the sketch after this plan)
  - Test ChatMLWebdataset with a mock webdataset
  - Test task encoder encode_sample with multimodal inputs
- Integration Tests
  - Test energon dataloader with TP=2, SP=True
  - Test energon dataloader with CP=2
  - Test sequence packing with CP enabled
- Data Distribution Verification
  - Log batch keys/hashes to verify DP rank sharding
  - Verify same input_ids across CP ranks
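For the first unit-test item, a hedged pytest-style sketch that exercises only the key-routing and pickle logic of the proposed videohandler (the stub subclass avoids pulling in webdataset's real image decoder):

# Illustrative test only; videohandler refers to the class proposed in Phase 1.
import pickle

class _StubVideoHandler(videohandler):
    def __init__(self):
        self.extensions = ['jpgs', 'mp4s']
        self.extensions_mapping = {"jpgs": "jpg", "mp4s": "jpg"}
        # stand-in for webdataset's imagehandler: just records what it was given
        self.image_handler = lambda key, data: ("decoded", key, data)

def test_videohandler_routes_jpgs_and_skips_unknown_keys():
    handler = _StubVideoHandler()
    payload = pickle.dumps([b"frame0", b"frame1"])
    assert handler("sample.jpgs", payload) == [
        ("decoded", "jpg", b"frame0"),
        ("decoded", "jpg", b"frame1"),
    ]
    assert handler("sample.txt", b"ignored") is None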
Files to Modify
- src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py - Add handlers, update sample format
- src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py - Enable CP in split_deepstack_embs
- src/megatron/bridge/training/qwen3vl_step.py - Fix CP handling with packing
- src/megatron/bridge/data/energon/base_energon_datamodule.py - Add CP-aware data distribution
- src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py - Update recipe to use new task encoder
Priority
High - This blocks energon-based training for Qwen3-VL
Labels
bug, enhancement, data, qwen-vl, multimodal