import unittest

import torch

from xtuner.v1.data_proto.sequence_context import SequenceContext
from xtuner.v1.rl.base.pack import RLDataPacker


class TestDataBatchPacker(unittest.TestCase):
    def setUp(self):
        self.pack_max_length = 3072
        self.split_size = 1024

    def _create_dummy_item(self, length: int, val=1):
        """Build a single-sequence data item of `length` tokens, all filled with `val`."""
        input_ids = torch.full((1, length), val, dtype=torch.long)
        cu_seq_lens_q = torch.tensor([0, length], dtype=torch.int32)
        cu_seq_lens_k = torch.tensor([0, length], dtype=torch.int32)
        max_length_q = torch.tensor(length, dtype=torch.int32)
        max_length_k = torch.tensor(length, dtype=torch.int32)
        seq_ctx = SequenceContext(
            input_ids=input_ids,
            cu_seq_lens_q=cu_seq_lens_q,
            cu_seq_lens_k=cu_seq_lens_k,
            max_length_q=max_length_q,
            max_length_k=max_length_k,
            num_padding=0,
            device="cpu",
        )
        return {
            "seq_ctx": seq_ctx,
            "shifted_labels": torch.full((1, length), val, dtype=torch.long),
            "advantages": torch.full((1, length), float(val), dtype=torch.float),
            "rollout_logprobs": torch.full((1, length), float(val), dtype=torch.float),
        }

    def _run_strategy_test(self, strategy, world_size, optimizer_steps, lengths, pack_max_length, expected_padding=None):
        data_batches = [self._create_dummy_item(length, val=7) for length in lengths]
        total_data_tokens = sum(lengths)

        packer = RLDataPacker(
            pack_max_length=pack_max_length,
            world_size=world_size,
            data_replicate_size=1,
            optimizer_steps=optimizer_steps,
            pack_strategy=strategy,
        )

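        # packed_res is nested as [rank][optimizer_step][pack]; padding_tokens
        # is the total number of pad tokens added across all packs.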
        packed_res, padding_tokens = packer.pack(data_batches)

        # Balance check: ideally the balance strategy keeps the per-rank totals
        # of valid tokens within one maximum sample length of each other.
        if strategy == "balance":
            rank_token_counts = []
            for rank_data in packed_res:
                rank_total_valid_tokens = 0
                for step_data in rank_data:
                    for pack in step_data:
                        # Count valid (non-zero, i.e. non-padding) tokens.
                        valid_tokens = (pack["seq_ctx"].input_ids != 0).sum().item()
                        rank_total_valid_tokens += valid_tokens
                rank_token_counts.append(rank_total_valid_tokens)

            max_tokens = max(rank_token_counts)
            min_tokens = min(rank_token_counts)
            diff = max_tokens - min_tokens
            max_sample_len = max(lengths) if lengths else 0
            self.assertLessEqual(
                diff,
                max_sample_len,
                f"Balance strategy failed: token distribution is too skewed. "
                f"Rank counts: {rank_token_counts}, max diff: {diff}",
            )
        # For a fixed input, checking padding_tokens against the expected value
        # validates the packing logic.
        if expected_padding is not None:
            self.assertEqual(
                padding_tokens,
                expected_padding,
                f"Strategy {strategy} padding mismatch. Expected {expected_padding}, got {padding_tokens}",
            )

        all_packs = []
        for rank_data in packed_res:
            for step_data in rank_data:
                for pack in step_data:
                    self.assertEqual(
                        pack["seq_ctx"].input_ids.numel(),
                        pack_max_length,
                        f"Strategy {strategy} pack length mismatch.",
                    )
                    all_packs.append(pack)

        # Check that the total number of valid tokens is conserved by packing.
        total_capacity = len(all_packs) * pack_max_length
        self.assertEqual(total_capacity, total_data_tokens + padding_tokens)

        all_input_ids = torch.cat([p["seq_ctx"].input_ids for p in all_packs], dim=1)
        valid_token_count = (all_input_ids != 0).sum().item()
        all_labels = torch.cat([p["shifted_labels"] for p in all_packs], dim=1)
        valid_label_count = (all_labels != -100).sum().item()
        all_advantages = torch.cat([p["advantages"] for p in all_packs], dim=1)
        valid_adv_count = (all_advantages != -100).sum().item()

        self.assertEqual(valid_token_count, total_data_tokens)
        self.assertEqual(valid_label_count, total_data_tokens)
        self.assertEqual(valid_adv_count, total_data_tokens)

    def test_variable_packs(self):
        """Variable-length inputs, dp=2, optimizer_steps=2.

        - Native:
          1. Preprocess so the sample count divides evenly across ranks: a dummy
             1024-token sample is appended so it can be packed with the real ones.
             [1500, 1000, 2800, 3000, 1500, 2000, 2100, 1000, 800] -> padded: [1500, 1000, 2800, 3000, 1500, 2000, 2100, 1000, 800, 1024]
          2. Split across DP ranks:
             rank0: [1500, 1000, 2800, 3000, 1500]
             rank1: [2000, 2100, 1000, 800, 1024]
          3. Split across optimizer steps:
             rank0: [1500, 1000, 2800], [3000, 1500]
             rank1: [2000, 2100, 1000], [ 800, 1024]
          4. Pack and pad:
             rank0: step0: [2500 -> 3072], [2800 -> 3072], step1: [3000 -> 3072], [1500 -> 3072]
             rank1: step0: [2000 -> 3072], [2100 -> 3072], [1000 -> 3072], step1: [1824 -> 3072]
          5. Align pack counts across ranks:
             rank0: step0: [2500 -> 3072], [2800 -> 3072], [0 -> 3072], step1: [3000 -> 3072], [1500 -> 3072]
             rank1: step0: [2100 -> 3072], [2000 -> 3072], [1000 -> 3072], step1: [1824 -> 3072], [0 -> 3072]
             padding_tokens: 1024 + (3072 - 2500) + (3072 - 2800) + 3072 + (3072 - 3000) + (3072 - 1500) + (3072 - 2100) + (3072 - 2000) + (3072 - 1000) + (3072 - 1824) + 3072 = 15020
        - Balance:
          1. Sort the raw inputs by length:
             [1500, 1000, 2800, 3000, 1500, 2000, 2100, 1000, 800] -> [3000, 2800, 2100, 2000, 1500, 1500, 1000, 1000, 800]
          2. Spread samples of similar length over the N ranks; every group of N
             samples becomes one optimizer step's data for the N ranks:
             rank0: [3000, 1500, 800], [2100, 1000]
             rank1: [2800, 1500], [2000, 1000]
          3. Pack and pad:
             rank0: step0: [3000 -> 3072], [2300 -> 3072], step1: [2100 -> 3072], [1000 -> 3072]
             rank1: step0: [2800 -> 3072], [1500 -> 3072], step1: [3000 -> 3072], [0 -> 3072]
          4. Align pack counts across ranks:
             skip (already aligned)
             padding_tokens: (3072 - 3000) + (3072 - 2300) + (3072 - 2100) + (3072 - 1000) + (3072 - 2800) + (3072 - 1500) + (3072 - 3000) + 3072 = 8876
        - Greedy: maximizes pack fill rate.
          1. Pack and pad:
             Pack 1: [1500, 1000] -> [2500 -> 3072]
             Pack 2: [2800] -> [2800 -> 3072]
             Pack 3: [3000] -> [3000 -> 3072]
             Pack 4: [1500] -> [1500 -> 3072]
             Pack 5: [2000] -> [2000 -> 3072]
             Pack 6: [2100] -> [2100 -> 3072]
             Pack 7: [1000, 800] -> [1800 -> 3072]
             Pack 8: [] -> [0 -> 3072] (padding)
          2. Split across DP ranks:
             rank0: [Pack 1, Pack 2, Pack 3, Pack 4]
             rank1: [Pack 5, Pack 6, Pack 7, Pack 8]
          3. Split across optimizer steps:
             rank0: step0: [Pack 1, Pack 2], step1: [Pack 3, Pack 4]
             rank1: step0: [Pack 5, Pack 6], step1: [Pack 7, Pack 8]
          4. Align pack counts across ranks:
             skip (already aligned)
             padding_tokens: (3072 - 2500) + (3072 - 2800) + (3072 - 3000) + (3072 - 1500) + (3072 - 2000) + (3072 - 2100) + (3072 - 1800) + 3072 = 8876
        """
        lengths = [1500, 1000, 2800, 3000, 1500, 2000, 2100, 1000, 800]
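        # Cross-check of the walkthroughs above: padding is whatever capacity
        # remains once the 15700 real tokens are packed (10 packs for native,
        # 8 each for balance and greedy).
        assert sum(lengths) == 15700
        assert 10 * self.pack_max_length - 15700 == 15020  # native
        assert 8 * self.pack_max_length - 15700 == 8876  # balance and greedy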
        self._run_strategy_test("native", 2, 2, lengths, self.pack_max_length, 15020)
        self._run_strategy_test("balance", 2, 2, lengths, self.pack_max_length, 8876)
        self._run_strategy_test("greedy", 2, 2, lengths, self.pack_max_length, 8876)

    def test_imbalance_dp_size(self):
        lengths = [500]
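        # A single 500-token sample still forces one pack on each of the two
        # ranks, so every strategy pays 2 * 3072 - 500 = 5644 padding tokens.
        assert 2 * self.pack_max_length - sum(lengths) == 5644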
        for strat in ["native", "balance", "greedy"]:
            self._run_strategy_test(strat, 2, 1, lengths, self.pack_max_length, 5644)

    def test_imbalanced_steps(self):
        lengths = [100, 200, 2500, 3000, 50, 400, 1000, 1500]
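        # native/balance land one sample, hence one pack, in each of the
        # 2 ranks x 4 steps slots: 8 * 3072 - 8750 = 15826 padding tokens.
        # greedy packs the same 8750 tokens into 4 packs: 4 * 3072 - 8750 = 3538.
        assert sum(lengths) == 8750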
        self._run_strategy_test("native", 2, 4, lengths, self.pack_max_length, 15826)
        self._run_strategy_test("balance", 2, 4, lengths, self.pack_max_length, 15826)
        self._run_strategy_test("greedy", 2, 4, lengths, self.pack_max_length, 3538)

    def test_random_lengths(self):
        import random

        random.seed(0)  # seed so the fuzz case is reproducible
        lengths = [random.randint(1, 32768) for _ in range(1024)]
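        # No expected_padding here: with random draws only the structural
        # invariants are checked (pack length, token conservation, and the
        # balance-skew bound inside _run_strategy_test).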
        for strat in ["native", "balance", "greedy"]:
            self._run_strategy_test(strat, 8, 16, lengths, 32768)


if __name__ == "__main__":
    unittest.main()