
Commit 49e481e

[Feat] Support DataPacker for RL
1 parent 8cad75e commit 49e481e

2 files changed: 557 additions & 0 deletions

tests/ray/test_pack.py

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
import unittest
import torch
from xtuner.v1.data_proto.sequence_context import SequenceContext
from xtuner.v1.rl.base.pack import RLDataPacker


class TestDataBatchPacker(unittest.TestCase):
    def setUp(self):
        self.pack_max_length = 3072
        self.split_size = 1024

    def _create_dummy_item(self, length: int, val=1):
        input_ids = torch.full((1, length), val, dtype=torch.long)
        cu_seq_lens_q = torch.tensor([0, length], dtype=torch.int32)
        cu_seq_lens_k = torch.tensor([0, length], dtype=torch.int32)
        max_length_q = torch.tensor(length, dtype=torch.int32)
        max_length_k = torch.tensor(length, dtype=torch.int32)
        seq_ctx = SequenceContext(
            input_ids=input_ids,
            cu_seq_lens_q=cu_seq_lens_q,
            cu_seq_lens_k=cu_seq_lens_k,
            max_length_q=max_length_q,
            max_length_k=max_length_k,
            num_padding=0,
            device="cpu",
        )
        return {
            "seq_ctx": seq_ctx,
            "shifted_labels": torch.full((1, length), val, dtype=torch.long),
            "advantages": torch.full((1, length), float(val), dtype=torch.float),
            "rollout_logprobs": torch.full((1, length), float(val), dtype=torch.float),
        }

    def _run_strategy_test(self, strategy, world_size, optimizer_steps, lengths, pack_max_length, expected_padding=None):
        data_batches = [self._create_dummy_item(l, val=7) for l in lengths]
        total_data_tokens = sum(lengths)

        packer = RLDataPacker(
            pack_max_length=pack_max_length,
            world_size=world_size,
            data_replicate_size=1,
            optimizer_steps=optimizer_steps,
            pack_strategy=strategy,
        )

        packed_res, padding_tokens = packer.pack(data_batches)

        # Balance check: ideally, the total token counts the balance strategy
        # assigns to the ranks should differ by less than the longest single sample.
        if strategy == "balance":
            rank_token_counts = []
            for rank_data in packed_res:
                rank_total_valid_tokens = 0
                for step_data in rank_data:
                    for pack in step_data:
                        # Count the non-zero (non-padding) valid tokens.
                        valid_tokens = (pack["seq_ctx"].input_ids != 0).sum().item()
                        rank_total_valid_tokens += valid_tokens
                rank_token_counts.append(rank_total_valid_tokens)

            max_tokens = max(rank_token_counts)
            min_tokens = min(rank_token_counts)
            diff = max_tokens - min_tokens
            max_sample_len = max(lengths) if lengths else 0
            self.assertLessEqual(diff, max_sample_len,
                                 f"Balance strategy failed: Token distribution is too skewed. "
                                 f"Rank counts: {rank_token_counts}, Max diff: {diff}")

        # For fixed inputs, verify padding_tokens against the expected value to
        # validate the packing logic.
        if expected_padding is not None:
            self.assertEqual(padding_tokens, expected_padding, f"Strategy {strategy} padding mismatch. Expected {expected_padding}, got {padding_tokens}")

        all_packs = []
        for rank_data in packed_res:
            for step_data in rank_data:
                for pack in step_data:
                    self.assertEqual(pack["seq_ctx"].input_ids.numel(), pack_max_length, f"Strategy {strategy} pack length mismatch.")
                    all_packs.append(pack)

        # Verify that the total number of valid tokens is unchanged by packing.
        total_capacity = len(all_packs) * pack_max_length
        self.assertEqual(total_capacity, total_data_tokens + padding_tokens)

        all_input_ids = torch.cat([p["seq_ctx"].input_ids for p in all_packs], dim=1)
        valid_token_count = (all_input_ids != 0).sum().item()
        all_labels = torch.cat([p["shifted_labels"] for p in all_packs], dim=1)
        valid_label_count = (all_labels != -100).sum().item()
        all_advantages = torch.cat([p["advantages"] for p in all_packs], dim=1)
        valid_adv_count = (all_advantages != -100).sum().item()

        self.assertEqual(valid_token_count, total_data_tokens)
        self.assertEqual(valid_label_count, total_data_tokens)
        self.assertEqual(valid_adv_count, total_data_tokens)

    def test_variable_packs(self):
        """Variable-length token inputs, dp=2, optimizer_steps=2.

        - Native:
          1. Preprocess so the sample count divides evenly: append a dummy sample
             padded to 1024 tokens, which can be packed together with valid samples.
             [1500, 1000, 2800, 3000, 1500, 2000, 2100, 1000, 800] -> padded: [1500, 1000, 2800, 3000, 1500, 2000, 2100, 1000, 800, 1024]
          2. Split across DP ranks:
             rank0: [1500, 1000, 2800, 3000, 1500]
             rank1: [2000, 2100, 1000, 800, 1024]
          3. Split across optimizer steps:
             rank0: [1500, 1000, 2800], [3000, 1500]
             rank1: [2000, 2100, 1000], [ 800, 1024]
          4. Pack and pad:
             rank0: step0: [2500 -> 3072], [2800 -> 3072], step1: [3000 -> 3072], [1500 -> 3072]
             rank1: step0: [2000 -> 3072], [2100 -> 3072], [1000 -> 3072], step1: [1824 -> 3072]
          5. Align the pack count across ranks:
             rank0: step0: [2500 -> 3072], [2800 -> 3072], [0 -> 3072], step1: [3000 -> 3072], [1500 -> 3072]
             rank1: step0: [2100 -> 3072], [2000 -> 3072], [1000 -> 3072], step1: [1824 -> 3072], [0 -> 3072]
          padding_tokens: 1024 + 3072 - 2500 + 3072 - 2800 + 3072 + 3072 - 3000 + 3072 - 1500 + 3072 - 2100 + 3072 - 2000 + 3072 - 1000 + 3072 - 1824 + 3072 = 15020
        - Balance:
          1. Sort the raw inputs by length:
             [1500, 1000, 2800, 3000, 1500, 2000, 2100, 1000, 800] -> [3000, 2800, 2100, 2000, 1500, 1500, 1000, 1000, 800]
          2. Spread each group of N similar-length samples over the N ranks; every
             such group serves as one optimizer step's data for the N ranks:
             rank0: [3000, 1500, 800], [2100, 1000]
             rank1: [2800, 1500], [2000, 1000]
          3. Pack and pad:
             rank0: step0: [3000 -> 3072], [2300 -> 3072], step1: [2100 -> 3072], [1000 -> 3072]
             rank1: step0: [2800 -> 3072], [1500 -> 3072], step1: [3000 -> 3072], [0 -> 3072]
          4. Align the pack count across ranks:
             skip
          padding_tokens: 3072 - 3000 + 3072 - 2300 + 3072 - 2100 + 3072 - 1000 + 3072 - 2800 + 3072 - 1500 + 3072 - 3000 + 3072 = 8876
        - Greedy: maximize the pack fill rate.
          1. Pack and pad:
             Pack 1: [1500, 1000] -> [2500 -> 3072]
             Pack 2: [2800] -> [2800 -> 3072]
             Pack 3: [3000] -> [3000 -> 3072]
             Pack 4: [1500] -> [1500 -> 3072]
             Pack 5: [2000] -> [2000 -> 3072]
             Pack 6: [2100] -> [2100 -> 3072]
             Pack 7: [1000, 800] -> [1800 -> 3072]
             Pack 8: [ ] -> [0 -> 3072] (padding)
          2. Split across DP ranks:
             rank0: [Pack 1, Pack 2, Pack 3, Pack 4]
             rank1: [Pack 5, Pack 6, Pack 7, Pack 8]
          3. Split across optimizer steps:
             rank0: step0: [Pack 1, Pack 2], step1: [Pack 3, Pack 4]
             rank1: step0: [Pack 5, Pack 6], step1: [Pack 7, Pack 8]
          4. Align the pack count across ranks:
             skip
          padding_tokens: 3072 - 2500 + 3072 - 2800 + 3072 - 3000 + 3072 - 1500 + 3072 - 2000 + 3072 - 2100 + 3072 - 1800 + 3072 = 8876
        """
        lengths = [1500, 1000, 2800, 3000, 1500, 2000, 2100, 1000, 800]
        self._run_strategy_test("native", 2, 2, lengths, self.pack_max_length, 15020)
        self._run_strategy_test("balance", 2, 2, lengths, self.pack_max_length, 8876)
        self._run_strategy_test("greedy", 2, 2, lengths, self.pack_max_length, 8876)

    def test_imbalance_dp_size(self):
        lengths = [500]
        for strat in ["native", "balance", "greedy"]:
            self._run_strategy_test(strat, 2, 1, lengths, self.pack_max_length, 5644)

    def test_imbalanced_steps(self):
        lengths = [100, 200, 2500, 3000, 50, 400, 1000, 1500]
        self._run_strategy_test("native", 2, 4, lengths, self.pack_max_length, 15826)
        self._run_strategy_test("balance", 2, 4, lengths, self.pack_max_length, 15826)
        self._run_strategy_test("greedy", 2, 4, lengths, self.pack_max_length, 3538)

    def test_random_lengths(self):
        import random
        lengths = [random.randint(1, 32768) for _ in range(1024)]
        for strat in ["native", "balance", "greedy"]:
            self._run_strategy_test(strat, 8, 16, lengths, 32768)


if __name__ == "__main__":
    unittest.main()
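
The greedy walk-through in the test_variable_packs docstring corresponds to a simple sequential rule: keep appending samples to the current pack until the next one would overflow pack_max_length, then open a new pack, and finally append all-padding packs so the pack count divides evenly across ranks. A minimal sketch of that arithmetic (a hypothetical greedy_pack helper, not the actual RLDataPacker implementation):

# Hypothetical sketch of the sequential greedy packing described in the
# test_variable_packs docstring; not the actual RLDataPacker code.
def greedy_pack(lengths, pack_max_length):
    packs = [[]]
    for length in lengths:
        if sum(packs[-1]) + length > pack_max_length:
            packs.append([])  # the current pack is full; open a new one
        packs[-1].append(length)
    return packs

lengths = [1500, 1000, 2800, 3000, 1500, 2000, 2100, 1000, 800]
packs = greedy_pack(lengths, 3072)
# -> [[1500, 1000], [2800], [3000], [1500], [2000], [2100], [1000, 800]]
# One all-padding pack is appended so the 8 packs split evenly over the
# world_size=2 ranks (two packs per rank per optimizer step).
padding_tokens = (len(packs) + 1) * 3072 - sum(lengths)
assert padding_tokens == 8876  # the expected value asserted in the test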

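The balance branch of _run_strategy_test asserts only an invariant: per-rank valid-token totals may differ by at most one maximum sample length. One classic heuristic with exactly that guarantee is longest-processing-time assignment: sort samples by length descending and always hand the next sample to the currently lightest rank. A hedged sketch of that idea (an illustrative stand-in, not necessarily RLDataPacker's exact balance algorithm):

# LPT-style balancing sketch: after sorting descending, each sample goes to
# the lightest rank, so the rank that ends up heaviest received its last
# sample while it was lightest, bounding the final gap by one sample length.
def balance_assign(lengths, world_size):
    ranks = [[] for _ in range(world_size)]
    totals = [0] * world_size
    for length in sorted(lengths, reverse=True):
        lightest = totals.index(min(totals))
        ranks[lightest].append(length)
        totals[lightest] += length
    return ranks, totals

lengths = [1500, 1000, 2800, 3000, 1500, 2000, 2100, 1000, 800]
ranks, totals = balance_assign(lengths, world_size=2)
assert max(totals) - min(totals) <= max(lengths)  # the invariant the test checks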