
Commit 62ae9fc

[Refactor] refactor packing in RL train controller and train worker
1 parent 7c8f82c commit 62ae9fc

File tree

3 files changed: +544 −202 lines changed


xtuner/v1/rl/base/controller.py

Lines changed: 234 additions & 80 deletions
@@ -1,14 +1,16 @@
 import math
-from typing import Literal, TypedDict
+from typing import Literal, TypedDict, cast
 
+import numpy as np
 import ray
 import torch
 from ray.actor import ActorProxy
 
 from xtuner.v1.data_proto.sequence_context import SequenceContext
 from xtuner.v1.model.compose.base import BaseComposeConfig
+from xtuner.v1.rl.utils import get_seqlen_balanced_partitions
 from xtuner.v1.train.trainer import LoadCheckpointConfig
-from xtuner.v1.utils import ray_method
+from xtuner.v1.utils import get_logger, ray_method
 
 from .worker import TrainingWorker
 
@@ -23,6 +25,12 @@ class ColateItem(TypedDict):
 class RawTrainingController:
     def __init__(self, workers: list[TrainingWorker]) -> None:
         self.workers = workers
+        refs = [
+            self.workers[0].get_model_cfg.remote(),
+            self.workers[0].get_worker_cfg.remote(),
+            self.workers[0].get_data_replicate_size.remote(),
+        ]
+        self.model_cfg, self.worker_cfg, self.data_replicate_size = ray.get(refs)
 
     # TODO(hha): This logic is not generic enough; it should reuse the sft function to support expand soft pack
     def _get_pack_infos(self, dataset, num_tokens, target, random=None):
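Note: the __init__ hunk above fetches model_cfg, worker_cfg, and data_replicate_size once at construction time, issuing the three remote getter calls together and resolving them with a single ray.get over the list of ObjectRefs, rather than re-fetching configs inside fit on every rollout. A minimal sketch of that Ray pattern follows, with a hypothetical ConfigHolder actor standing in for xtuner's TrainingWorker (the actor and its return values are illustrative assumptions, not the real worker API):

import ray

ray.init(ignore_reinit_error=True)


@ray.remote
class ConfigHolder:
    """Hypothetical stand-in for a worker exposing remote config getters."""

    def get_model_cfg(self):
        return {"n_routed_experts": 64}

    def get_worker_cfg(self):
        return {"optimizer_steps": 4}

    def get_data_replicate_size(self):
        return 2


worker = ConfigHolder.remote()
# A single ray.get over a list of ObjectRefs resolves all three results in one wait
# instead of three sequential round trips.
refs = [
    worker.get_model_cfg.remote(),
    worker.get_worker_cfg.remote(),
    worker.get_data_replicate_size.remote(),
]
model_cfg, worker_cfg, data_replicate_size = ray.get(refs)
print(model_cfg, worker_cfg, data_replicate_size)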
@@ -164,95 +172,241 @@ def _grouped_by_max_length(self, packed_data_batches):
         # After sorting, this pack is placed first, so the first step on rank0 often consumes fewer valid tokens than the other ranks; this is expected.
         return sorted(packed_data_batches, key=lambda x: x["seq_ctx"].max_length_q, reverse=True)
 
-    @ray_method
-    def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int):
-        has_rollout_routed_experts = False
-        language_cfg = None
-        if data_batches[0]["seq_ctx"].rollout_routed_experts is not None:
-            model_cfg = ray.get(self.workers[0].get_model_cfg.remote())  # type: ignore[attr-defined]
-            has_rollout_routed_experts = True
-            language_cfg = model_cfg
-            if isinstance(model_cfg, BaseComposeConfig):
-                language_cfg = model_cfg.text_config
-
-        packed_data_batches = self._packing(data_batches, pack_max_length, language_cfg)
-        # packed_data_batches = self._grouped_by_max_length(packed_data_batches)
-
-        # TODO(hha): This logic is not generic enough; it is tied to a specific model
-        is_qwen3_vl = False
-        if len(packed_data_batches[0]["seq_ctx"].position_ids.shape) == 3:
-            is_qwen3_vl = True
+    def _balance_split_batch(self, data_batches, partition_size):
+        """Reorder the data on single controller such that each dp rank gets
+        similar total tokens."""
+        global_seqlen_lst = [data["seq_ctx"].input_ids.numel() for data in data_batches]
+        global_partition_lst = get_seqlen_balanced_partitions(
+            global_seqlen_lst, k_partitions=partition_size, equal_size=True
+        )
+        balanced_batches = []
+        tokens_in_partition = []
+        for partition in global_partition_lst:
+            partition_batch = [data_batches[i] for i in partition]
+            tokens_in_partition.append(sum(data["seq_ctx"].input_ids.numel() for data in partition_batch))
+            balanced_batches.append(partition_batch)
+        get_logger().info(f"Balanced split into {partition_size} partitions with tokens: {tokens_in_partition}")
+        return balanced_batches
+
+    def _create_padding_sample(
+        self,
+        pad_len: int,
+        pack_max_length: int,
+        is_qwen3_vl: bool = False,
+        has_rollout_routed_experts: bool = False,
+        has_rollout_logprobs: bool = True,
+        n_routed_experts: int | None = None,
+        split_size: int = 1024,
+    ):
+        # padding input_ids
+        pad_tokens = tuple(
+            torch.zeros(1, split_size, dtype=torch.long, device="cpu") for _ in range(pad_len // split_size)
+        )
+        if pad_len % split_size > 0:
+            pad_tokens = pad_tokens + (torch.zeros(1, pad_len % split_size, dtype=torch.long, device="cpu"),)
+        pad_tokens = cast(tuple[torch.LongTensor, ...], pad_tokens)
+        pad_seq_ctx = SequenceContext.from_input_ids(pad_tokens, device="cpu")
+        pad_seq_ctx.num_padding = pad_len
+
+        # padding mm position_ids
+        if is_qwen3_vl:
+            _position_ids_list = []
+            for pad_token in pad_tokens:
+                _position_ids = torch.arange(pad_token.size(-1)).view(1, 1, -1).expand(3, 1, -1)
+                _position_ids_list.append(_position_ids)
+            position_ids = torch.cat(_position_ids_list, dim=-1)
+            position_ids = cast(torch.LongTensor, position_ids)
+            pad_seq_ctx.position_ids = position_ids
+
+        # padding rollout routed experts
+        if has_rollout_routed_experts:
+            assert n_routed_experts, "n_routed_experts must be provided when has_rollout_routed_experts is True"
+            if pad_len == pack_max_length:
+                pad_rand_index = torch.randint(
+                    low=0, high=1, size=(1, 1, 1)
+                )  # add dummy data, true data will be initialized in train worker.fit
+            else:
+                pad_rand_index = torch.randint(low=0, high=n_routed_experts, size=(pad_len, 1, 1))
+            pad_seq_ctx.rollout_routed_experts = pad_rand_index
+
+        pad_labels = torch.full((1, pad_len), -100, dtype=torch.long, device="cpu")
+        pad_advantage_length = pack_max_length if pad_len == pack_max_length else math.ceil(pad_len / 1024)
+        pad_advantage = torch.full(
+            (1, pad_advantage_length),
+            -100,
+            dtype=torch.float32,
+            device="cpu",
+        )
+        pad_rollout_logprobs = (
+            torch.zeros(1, pad_len, dtype=torch.float32, device="cpu") if has_rollout_logprobs else None
+        )
 
-        # todo: support round up
-        num_packed_data_batches = len(packed_data_batches)
-        data_replicate_size = ray.get(self.workers[0].get_data_replicate_size.remote())  # type: ignore[attr-defined]
-        dp_size = len(self.workers) // data_replicate_size
-        pad_num = math.ceil(num_packed_data_batches / dp_size) * dp_size - num_packed_data_batches
-        if pad_num > 0:
-            # Reduce the attn calculation time by using multiple short sequence packs
-            assert data_batches[0]["seq_ctx"].input_ids is not None
-            pad_tokens = tuple(
-                torch.zeros(1, 1024, dtype=data_batches[0]["seq_ctx"].input_ids.dtype, device="cpu")
-                for _ in range(pack_max_length // 1024)
-            )
-            if pack_max_length % 1024 > 0:
-                assert data_batches[0]["seq_ctx"].input_ids is not None
-                pad_tokens = pad_tokens + (
-                    torch.zeros(
-                        1, pack_max_length % 1024, dtype=data_batches[0]["seq_ctx"].input_ids.dtype, device="cpu"
-                    ),
-                )
-            pad_seq_ctx = SequenceContext.from_input_ids(pad_tokens, device="cpu")  # type: ignore
-            pad_seq_ctx.num_padding = pack_max_length
-            if is_qwen3_vl:
-                _position_ids_list = []
-                for pad_token in pad_tokens:
-                    _position_ids = torch.arange(pad_token.size(-1)).view(1, 1, -1).expand(3, 1, -1)
-                    _position_ids_list.append(_position_ids)
-                pad_seq_ctx.position_ids = torch.cat(_position_ids_list, dim=-1)  # type: ignore
-
-            pad_shifted_labels = torch.full(
-                (1, pack_max_length),
-                -100,
-                dtype=packed_data_batches[0]["shifted_labels"].dtype,
-                device="cpu",
-            )
-            pad_advantages = torch.full(
-                (1, pack_max_length),
-                -100,
-                dtype=packed_data_batches[0]["advantages"].dtype,
-                device="cpu",
+        return {
+            "seq_ctx": pad_seq_ctx,
+            "shifted_labels": pad_labels,
+            "advantages": pad_advantage,
+            "rollout_logprobs": pad_rollout_logprobs,
+        }
+
+    def _pack(self, mini_batch, pack_max_length):
+        assert len(mini_batch) > 0, "mini_batch should not be empty"
+        seqlen_list = []
+        for data in mini_batch:
+            assert data["seq_ctx"].input_ids.numel() <= pack_max_length, (
+                f"Single sample seq len {data['seq_ctx'].input_ids.numel()} exceeds pack_max_length {pack_max_length}"
             )
+            seqlen_list.append(data["seq_ctx"].input_ids.numel())
+        total_length = sum(seqlen_list)
 
-            if has_rollout_routed_experts:
-                pad_rand_index = torch.randint(
-                    low=0,
-                    high=1,
-                    size=(1, 1, 1),  # add dummy data, true data will be initialized in train worker.fit
-                )
-                pad_seq_ctx.rollout_routed_experts = pad_rand_index
+        if total_length <= pack_max_length:
+            return [mini_batch]  # No packing needed
 
-            pad_rollout_logprobs = None
-            if "rollout_logprobs" in packed_data_batches[0] and packed_data_batches[0]["rollout_logprobs"] is not None:
-                pad_rollout_logprobs = torch.zeros(
-                    1, pack_max_length, dtype=packed_data_batches[0]["rollout_logprobs"].dtype, device="cpu"
-                )
-            pad_data = {
-                "seq_ctx": pad_seq_ctx,
-                "shifted_labels": pad_shifted_labels,
-                "advantages": pad_advantages,
-                "rollout_logprobs": pad_rollout_logprobs,
+        num_packs = math.ceil(total_length / pack_max_length)
+        partitions_indices = get_seqlen_balanced_partitions(
+            seqlen_list=seqlen_list, k_partitions=num_packs, equal_size=False
+        )
+
+        packed_mini_batches = []
+        for partition in partitions_indices:
+            packed_batch = [mini_batch[i] for i in partition]
+            packed_mini_batches.append(packed_batch)
+        return packed_mini_batches
+
+    def _get_data_batches_properties(self, data_batches: list[ColateItem]):
+        """Extract properties from the first element of data_batches."""
+        if not data_batches:
+            return {
+                "is_qwen3_vl": False,
+                "has_rollout_routed_experts": False,
+                "has_rollout_logprobs": False,
+                "n_routed_experts": None,
             }
-        pad_data_samples = [pad_data for _ in range(pad_num)]
-        packed_data_batches = packed_data_batches + pad_data_samples
 
-        print(f"len(packed_data_batches): {len(packed_data_batches)}")
+        first_item = data_batches[0]
+        seq_ctx = first_item["seq_ctx"]
+
+        is_qwen3_vl = seq_ctx.position_ids is not None and len(seq_ctx.position_ids.shape) == 3
+        has_rollout_logprobs = "rollout_logprobs" in first_item and first_item["rollout_logprobs"] is not None
+        has_rollout_routed_experts = seq_ctx.rollout_routed_experts is not None
+
+        language_cfg = None
+        if has_rollout_routed_experts:
+            language_cfg = self.model_cfg
+            if isinstance(self.model_cfg, BaseComposeConfig):
+                language_cfg = self.model_cfg.text_config
+
+        return {
+            "is_qwen3_vl": is_qwen3_vl,
+            "has_rollout_routed_experts": has_rollout_routed_experts,
+            "has_rollout_logprobs": has_rollout_logprobs,
+            "n_routed_experts": language_cfg.n_routed_experts if language_cfg is not None else None,
+        }
+
+    @ray_method
+    def fit(
+        self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int, enable_dp_balance: bool = True
+    ):
+        batch_props = self._get_data_batches_properties(data_batches)
+        is_qwen3_vl = batch_props["is_qwen3_vl"]
+        has_rollout_routed_experts = batch_props["has_rollout_routed_experts"]
+        has_rollout_logprobs = batch_props["has_rollout_logprobs"]
+        n_routed_experts = batch_props["n_routed_experts"]
+
+        world_size = len(self.workers)
+        dp_size = world_size // self.data_replicate_size
+        assert world_size % self.data_replicate_size == 0, "world_size must be divisible by data_replicate_size"
+        optimizer_steps = self.worker_cfg.optimizer_steps
+
+        if enable_dp_balance:
+            # Re-distribute the data across dp_size partitions so each dp rank gets roughly the same number of tokens
+            batches_per_dp_group = self._balance_split_batch(data_batches, dp_size)
+        else:
+            batches_per_dp_group = np.array_split(data_batches, dp_size)
+            tokens_in_partition = []
+            for batch in batches_per_dp_group:
+                tokens_in_partition.append(sum(data["seq_ctx"].input_ids.numel() for data in batch))
+            get_logger().info(f"default split into {dp_size} partitions with tokens: {tokens_in_partition}")
+
+        packed_data_batches: list[list[list[dict]]] = [[[] for _ in range(optimizer_steps)] for _ in range(dp_size)]
+        max_packs_per_card = [0] * optimizer_steps
+
+        for dp_rank, dp_worker_data_batches in enumerate(batches_per_dp_group):
+            # Within each worker, split the tokens evenly across optimizer_steps
+            mini_batch_for_steps = self._balance_split_batch(dp_worker_data_batches, optimizer_steps)
+
+            for step_idx, step_mini_batch in enumerate(mini_batch_for_steps):
+                # pack
+                pack_mini_batch = self._pack(step_mini_batch, pack_max_length)
+                if len(pack_mini_batch) > max_packs_per_card[step_idx]:
+                    max_packs_per_card[step_idx] = len(pack_mini_batch)
+
+                for pack in pack_mini_batch:
+                    seq_ctx_list = [item["seq_ctx"] for item in pack]
+                    label_list = [item["shifted_labels"] for item in pack]
+                    advantage_list = [torch.tensor([item["advantage"]]).float().unsqueeze(0) for item in pack]
+                    rollout_logprobs_list = [
+                        item["rollout_logprobs"] if has_rollout_logprobs else None for item in pack
+                    ]
+                    padding_len = pack_max_length - sum([item["seq_ctx"].input_ids.numel() for item in pack])
+                    if padding_len > 0:
+                        padding_sample = self._create_padding_sample(
+                            padding_len,
+                            pack_max_length,
+                            is_qwen3_vl=is_qwen3_vl,
+                            has_rollout_routed_experts=has_rollout_routed_experts,
+                            has_rollout_logprobs=has_rollout_logprobs,
+                            n_routed_experts=n_routed_experts,
+                        )
+                        seq_ctx_list.append(padding_sample["seq_ctx"])
+                        label_list.append(padding_sample["shifted_labels"])
+                        advantage_list.append(padding_sample["advantages"])
+                        rollout_logprobs_list.append(padding_sample["rollout_logprobs"])
+
+                    packed_seq_ctx = SequenceContext.pack(seq_ctx_list)
+                    packed_shifted_labels = torch.cat(label_list, dim=1)
+                    cu_seq_lens_q = packed_seq_ctx.cu_seq_lens_q
+                    packed_num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
+                    packed_advantages = torch.cat(advantage_list, dim=1)
+                    packed_advantages = torch.repeat_interleave(packed_advantages, packed_num_tokens, dim=1)
+                    if has_rollout_logprobs:
+                        cast_rollout_logprobs_list = [cast(torch.Tensor, item) for item in rollout_logprobs_list]
+                        packed_rollout_logprobs = torch.cat(cast_rollout_logprobs_list, dim=1)
+                    else:
+                        packed_rollout_logprobs = None
+                    packed_data_batches[dp_rank][step_idx].append(
+                        {
+                            "seq_ctx": packed_seq_ctx,
+                            "shifted_labels": packed_shifted_labels,
+                            "advantages": packed_advantages,
+                            "rollout_logprobs": packed_rollout_logprobs,
+                        }
+                    )
+
+        get_logger().info(f"Gradient accumulation steps: {max_packs_per_card}")
+        # padding for each worker to have same number of packs
+        for dp_rank in range(dp_size):
+            for step_idx in range(optimizer_steps):
+                max_packs = max_packs_per_card[step_idx]
+                num_current_packs = len(packed_data_batches[dp_rank][step_idx])
+                num_padding_packs = max_packs - num_current_packs
+
+                if num_padding_packs > 0:
+                    padding_sample = self._create_padding_sample(
+                        pack_max_length,
+                        pack_max_length,
+                        is_qwen3_vl=is_qwen3_vl,
+                        has_rollout_routed_experts=has_rollout_routed_experts,
+                        has_rollout_logprobs=has_rollout_logprobs,
+                        n_routed_experts=n_routed_experts,
+                    )
+                    padding_samples = [padding_sample for _ in range(num_padding_packs)]
+                    packed_data_batches[dp_rank][step_idx].extend(padding_samples)
 
         handles = []
         for worker_idx, worker in enumerate(self.workers):
             handles.append(
                 worker.fit.remote(  # type: ignore[attr-defined]
-                    data_batches=packed_data_batches[(worker_idx // data_replicate_size) :: dp_size],
+                    data_batches=packed_data_batches[worker_idx // self.data_replicate_size],
                     rollout_idx=rollout_idx,
                 )
             )
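
Note: the refactored fit above replaces the old flow (pack everything once, pad to a multiple of dp_size, then stride the packs across ranks) with a two-level, token-balanced split: samples are partitioned across dp ranks, then across optimizer_steps within each rank, each step's mini-batch is packed into packs of at most pack_max_length tokens, and every (dp_rank, step) bucket is finally padded to the same number of packs so gradient-accumulation steps stay aligned across ranks. The self-contained sketch below reproduces that bookkeeping with a naive greedy longest-first partitioner standing in for xtuner's get_seqlen_balanced_partitions; the function names and the heuristic are illustrative assumptions, not the library's implementation:

import math


def greedy_balanced_partitions(seqlens: list[int], k: int) -> list[list[int]]:
    """Greedy longest-first assignment: each sample goes to the currently lightest bin."""
    bins: list[list[int]] = [[] for _ in range(k)]
    loads = [0] * k
    for idx in sorted(range(len(seqlens)), key=lambda i: seqlens[i], reverse=True):
        target = loads.index(min(loads))
        bins[target].append(idx)
        loads[target] += seqlens[idx]
    return bins


def plan_packs(seqlens: list[int], dp_size: int, optimizer_steps: int, pack_max_length: int):
    """Return, per dp rank and per optimizer step, lists of sample indices forming one pack each."""
    plan = [[[] for _ in range(optimizer_steps)] for _ in range(dp_size)]
    max_packs_per_step = [0] * optimizer_steps
    for dp_rank, rank_bin in enumerate(greedy_balanced_partitions(seqlens, dp_size)):
        rank_seqlens = [seqlens[i] for i in rank_bin]
        for step, step_bin in enumerate(greedy_balanced_partitions(rank_seqlens, optimizer_steps)):
            step_indices = [rank_bin[j] for j in step_bin]
            total = sum(seqlens[i] for i in step_indices)
            num_packs = max(1, math.ceil(total / pack_max_length))
            packs = greedy_balanced_partitions([seqlens[i] for i in step_indices], num_packs)
            plan[dp_rank][step] = [[step_indices[j] for j in pack] for pack in packs]
            max_packs_per_step[step] = max(max_packs_per_step[step], len(packs))
    # Pad every (dp_rank, step) bucket to the same pack count; None marks a full-length dummy pack.
    for dp_rank in range(dp_size):
        for step in range(optimizer_steps):
            plan[dp_rank][step] += [None] * (max_packs_per_step[step] - len(plan[dp_rank][step]))
    return plan, max_packs_per_step


if __name__ == "__main__":
    seqlens = [900, 300, 1200, 250, 700, 450, 1100, 600]
    plan, max_packs = plan_packs(seqlens, dp_size=2, optimizer_steps=2, pack_max_length=2048)
    print("packs per (optimizer) step:", max_packs)
    for dp_rank, steps in enumerate(plan):
        print(f"dp_rank {dp_rank}:", steps)

Running it prints the per-step pack counts and, for each rank, the sample indices that would end up in each pack, mirroring how the controller pads shorter buckets with full-length dummy packs.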
