|
| 1 | +################################### imports ###################################### |
| 2 | +from typing import Any, Callable |
| 3 | +import asyncio |
| 4 | +from enum import Enum |
| 5 | +from torch.utils.data import DataLoader |
| 6 | +import threading |
| 7 | + |
| 8 | +from xtuner.v1.ray.rollout.controller import SampleParams |
| 9 | +from xtuner.v1.data_proto.rl_data import SampleParams # TODO: 删掉一个? |
| 10 | +from xtuner.v1.data_proto.sequence_context import SequenceContext |
| 11 | +from xtuner.v1.loss.base_loss_ctx import BaseLossContext |
| 12 | + |
def load_tokenizer(hf_checkpoint, trust_remote_code=True):
    """Load the tokenizer for *hf_checkpoint* (stub in this design sketch)."""
    ...
def load_processor(hf_checkpoint, trust_remote_code=True):
    """Load the (multimodal) processor for *hf_checkpoint* (stub in this design sketch)."""
    ...
| 15 | + |
class PlacementGroup:
    """Placeholder for a resource placement group (stub in this design sketch)."""
    ...
| 17 | + |
def log_metrics(metrics: dict):
    """Record a dict of metrics (stub in this design sketch)."""
    ...
| 19 | + |
class TrainItem:
    """One training example handed to TrainController (produced by DataManager.get_batch)."""

    seq_ctx: SequenceContext
    loss_ctx: BaseLossContext
| 23 | + |
| 24 | + |
| 25 | +################################### Main components ###################################### |
class Status(Enum):
    """Lifecycle status of a rollout sample / group."""

    INIT = "init"            # freshly created (default in RolloutState), not yet generated
    COMPLETED = "completed"  # group generation finished (returned by Agent.generate_group)
    ABORTED = "aborted"
    FAILED = "failed"
    ARCHIVED = "archived"
    EXPIRED = "expired"
    SKIPPED = "skipped"
| 34 | + |
| 35 | + |
class RolloutState:
    """Mutable state of one trajectory as it flows through rollout and judging.

    NOTE(review): this is an annotation-only class — presumably intended to be
    a dataclass / pydantic model; as written, fields without defaults are mere
    annotations and do not exist at runtime. Confirm the intended base.
    """

    tokens: list[int]  # actual input tokens of each generation call

    uid: int
    session_id: int | None = None
    prompt_ids: list[int]
    response: str
    response_ids: list[int]  # actual output of each generation call; overwritten each time
    logprobs: list[float]
    routed_experts: list[int] | None = None
    reward: float | list[float] | list[dict] | None = None
    loss_mask: list[int] | None = None  # length of tokens + response_ids
    state: Status = Status.INIT
    sample_parms: SampleParams | None = None  # NOTE(review): likely a typo for "sample_params"
    tools: list | None = None
    tool_choice: str | None = None
    mm_infer_info: dict[str, Any]  # multimodal inference-side info — TODO confirm schema
    mm_train_info: dict[str, Any]  # multimodal training-side info — TODO confirm schema
    finish_reason: str | None = None
    staleness: int = 0
| 57 | + |
| 58 | + |
class RolloutController:
    """Drives the inference engine; generates one trajectory per call (stub)."""

    async def generate_sample(self, sample: RolloutState) -> RolloutState: ...
| 61 | + |
| 62 | + |
class Judge:
    """Post-rollout judging step, applied by agents when a judge_cfg is given (stub)."""

    def judge(self, sample: RolloutState) -> RolloutState: ...
| 65 | + |
| 66 | + |
class Agent:
    """Abstract interface: an Agent is responsible for generating one trajectory sample.

    ``generate_group`` drives a whole prompt group and reports a :class:`Status`
    so schedulers can count completed groups.
    """

    async def generate_sample(self, sample: RolloutState) -> RolloutState:
        """Generate (and possibly judge) a single trajectory."""
        ...

    # BUG FIX: the original annotation promised ``list[RolloutState]``, but the
    # concrete implementation (SingleTurnAgent) returns a ``Status`` that
    # SyncProduceStrategy compares against Status.COMPLETED. Align the contract.
    async def generate_group(
        self, sample_fn: Callable[[], list[RolloutState]], data_mgr: "DataManager"
    ) -> Status:
        """Generate a full group sampled via *sample_fn*; return a completion Status."""
        ...
| 70 | + |
class SingleTurnAgent(Agent):
    """Agent that produces a trajectory with a single rollout call, optionally judged.

    Holds the persistent state needed by the generation process: tokenizer,
    processor, sampling parameters and (optionally) a Judge.
    """

    def __init__(
        self,
        rollout_ctl: RolloutController,
        hf_checkpoint,
        sample_params: SampleParams | None = None,
        judge_cfg: dict | None = None,
    ) -> None:
        # persistent state for the generation process
        self.rollout_ctl = rollout_ctl
        self.hf_checkpoint = hf_checkpoint
        self.tokenizer = load_tokenizer(hf_checkpoint, trust_remote_code=True)
        self.processor = load_processor(hf_checkpoint, trust_remote_code=True)
        # BUG FIX: the original used a mutable default argument
        # (sample_params=SampleParams()), sharing one object across every agent
        # instance; build a fresh default per instance instead.
        self.sample_params = sample_params if sample_params is not None else SampleParams()
        self.judge = Judge() if judge_cfg is not None else None

    async def generate_sample(self, sample: RolloutState) -> RolloutState:
        """Generate one trajectory via the rollout controller, then judge it if configured."""
        sample = await self.rollout_ctl.generate_sample(sample)
        if self.judge is not None:
            sample = self.judge.judge(sample)
        return sample

    async def generate_group(
        self, sample_fn: Callable[[], list[RolloutState]], data_mgr: "DataManager"
    ) -> Status:
        """Generate a whole prompt group concurrently and push it to the replay buffer.

        Returns ``Status.COMPLETED`` on success; exceptions raised by individual
        samples propagate out of ``asyncio.gather`` to the caller.
        """
        group_samples: list[RolloutState] = sample_fn()  # list of prompt_k samples
        tasks = [asyncio.create_task(self.generate_sample(s)) for s in group_samples]
        group_samples = await asyncio.gather(*tasks)
        data_mgr.add_to_replay_buffer(group_samples)
        return Status.COMPLETED
| 100 | + |
| 101 | + |
class MultiTurnAgent(Agent):
    """Agent for multi-turn conversations — not yet implemented."""

    ...
| 104 | + |
| 105 | + |
class MultiTurnToolAgent(Agent):
    """Agent for multi-turn conversations with tool calls — not yet implemented."""

    ...
| 108 | + |
| 109 | + |
class DataManager:
    """Owns the dataloader and the replay buffer; bridges rollout and training."""

    dataloader: DataLoader
    replay_buffer: list[list[RolloutState]]  # one inner list per prompt group

    def sample_from_dataset(self) -> list[RolloutState]: ...  # get from dataloader

    def add_to_replay_buffer(self, samples: list[RolloutState]): ...

    def get_batch(self) -> list[TrainItem]: ...  # get from replay_buffer and convert to TrainItem
| 119 | + |
# The strategy schedules the generation of many samples; implementations may add
# optimizations such as over-subscription, asynchronous production, or
# reordering long/short samples.
class ProduceStrategy:
    async def produce_batch(self, batch_size: int, data_mgr: DataManager, agent: Agent): ...
| 122 | + |
class SyncProduceStrategy(ProduceStrategy):
    """Synchronous production: fan out ``batch_size`` group-generation tasks and
    wait until that many groups report ``Status.COMPLETED`` (or no tasks remain)."""

    async def produce_batch(self, batch_size: int, data_mgr: DataManager, agent: Agent):
        target = batch_size

        # One in-flight group-generation task per requested group.
        in_flight = [
            asyncio.create_task(agent.generate_group(data_mgr.sample_from_dataset, data_mgr))
            for _ in range(target)
        ]

        completed = 0
        while completed < target:
            if not in_flight:
                print("All tasks are done but not enough samples collected.")
                break
            # Poll once a second; FIRST_COMPLETED returns as soon as any task finishes.
            done, in_flight = await asyncio.wait(
                in_flight, timeout=1, return_when=asyncio.FIRST_COMPLETED
            )
            for finished in done:
                try:
                    # finished is done, so result() is equivalent to awaiting it.
                    if finished.result() == Status.COMPLETED:
                        completed += 1
                except Exception as e:
                    print(f"Error in generating trajectory: {e}")
| 145 | + |
| 146 | + |
class AsyncProduceStrategy(ProduceStrategy):
    """Asynchronous production with staleness control, partial rollout and
    tail-batch optimizations (skeleton)."""

    class _Buffer:
        """Internal sample buffer used for partial-rollout / tail-batch scheduling."""

        def __init__(
            self,
            enable_partial_rollout: bool,
            tail_batch_candidate_step: int,
            tail_batch_trigger_size: int,
        ):
            self.enable_partial_rollout = enable_partial_rollout
            self.tail_batch_candidate_step = tail_batch_candidate_step
            self.tail_batch_trigger_size = tail_batch_trigger_size

    def __init__(
        self,
        staleness_threshold: float = 0.0,
        enable_partial_rollout: bool = False,
        tail_batch_trigger_size: int = 0,
        tail_batch_candidate_step: int = 0,
    ):
        # BUG FIX: the original declared a bare local ``class _Buffer: ...`` with
        # no __init__ and then called it with three arguments, which raises
        # TypeError at runtime. _Buffer now actually stores its configuration.
        self.staleness_threshold = staleness_threshold  # previously dropped silently
        self.buffer = self._Buffer(
            enable_partial_rollout, tail_batch_candidate_step, tail_batch_trigger_size
        )

    async def produce_batch(self, batch_size: int, data_mgr: DataManager, agent: Agent):
        # TODO: build sample_fn by combining data_mgr.sample_from_dataset with
        # samples drawn from self.buffer.
        pass
| 161 | + |
| 162 | + |
class Environment:
    """Wires an Agent to a ProduceStrategy on top of a rollout controller."""

    def __init__(self, rollout_ctl: RolloutController, hf_checkpoint=None):
        # BUG FIX: SingleTurnAgent.__init__ requires hf_checkpoint (it has no
        # default), so the original ``SingleTurnAgent(rollout_ctl)`` raised
        # TypeError. Accept and forward it; the new parameter defaults to None
        # so existing call sites remain valid.
        self._agent: Agent = SingleTurnAgent(rollout_ctl, hf_checkpoint)
        self._scheduler: ProduceStrategy = SyncProduceStrategy()

    async def produce_batch(self, data_mgr: DataManager, batch_size: int):
        """Produce one batch of sample groups using the configured strategy."""
        await self._scheduler.produce_batch(batch_size, data_mgr, self._agent)

    def produce_loop(self, data_mgr: DataManager):
        """Continuously produce batches (separated-deployment mode) — TODO."""
        pass
| 173 | + |
| 174 | + |
class TrainController:
    """Training-side controller exposing a one-call API (fit) and a
    step-by-step low-level API for custom training loops."""

    # high level API
    def fit(self, batch: list[TrainItem]) -> dict: ...
    # low level API
    def compute_old_logprobs(self, batch: list[TrainItem]) -> list[TrainItem]: ...
    def compute_ref_logprobs(self, batch: list[TrainItem]) -> list[TrainItem]: ...
    def compute_values(self, batch: list[TrainItem]) -> list[TrainItem]: ...
    def compute_advantages(self, batch: list[TrainItem]) -> list[TrainItem]: ...
    def train(self, batch: list[TrainItem]) -> dict: ...
    def sync_weights(self, rollout_ctl: RolloutController): ...
| 185 | + |
| 186 | + |
class Evaluator:
    """Computes evaluation metrics from a batch of rollout outputs; it does not
    perform the rollout itself."""

    def evaluate(self, batch: list[RolloutState]) -> dict: ...
| 189 | + |
| 190 | + |
| 191 | +################################### Usage example with components ######################################### |
# Keep the Trainer thin: put as little code as possible in the Trainer and
# organize it out of components instead. Below are several typical Trainer layouts.
| 193 | + |
def main_colocate_with_train_highlevel():
    """Usage sketch: co-located rollout + training, high-level ``fit`` API.

    NOTE(review): annotation-only pseudocode — the annotated names are never
    bound, so this function is illustrative, not runnable.
    """
    # rollout_ctl, train_ctl, data_mgr, env, evaluator are all local objects in
    # the driver process, not Ray actors. This way:
    # 1. most data passing needs no cross-machine transfer and is easy to manage centrally
    # 2. less debugging and maintenance burden introduced by Ray
    pg: PlacementGroup
    rollout_ctl: RolloutController(pg)
    train_ctl: TrainController(pg)

    data_mgr: DataManager
    env: Environment(rollout_ctl)
    eval_data_mgr: DataManager
    evaluator: Evaluator
    total_rollouts: int

    for i in range(total_rollouts):
        # NOTE(review): produce_batch is async and also takes batch_size — the
        # sketch omits both; confirm the intended call shape.
        env.produce_batch(data_mgr)

        train_batch: list[TrainItem] = data_mgr.get_batch()
        metrics = train_ctl.fit(train_batch)
        log_metrics(metrics)

        train_ctl.sync_weights(rollout_ctl)

        env.produce_batch(eval_data_mgr)
        eval_metrics = evaluator.evaluate(eval_data_mgr.get_batch())
        log_metrics(eval_metrics)
| 220 | + |
| 221 | + |
class Packer:
    """Packs / pads / dispatches rollout samples for the training step."""

    # BUG FIX: the usage sketch calls Packer.pack_pad_dispatch(batch) on the
    # class, but the original was an instance method, so that call would raise
    # TypeError. A staticmethod supports both class-level and instance calls.
    # (Annotations are string-deferred; the sketch also passes list[TrainItem]
    # here — confirm the intended element type.)
    @staticmethod
    def pack_pad_dispatch(samples: "list[RolloutState]") -> "list[RolloutState]":
        """Pack, pad and dispatch *samples* (stub)."""
        ...
| 224 | + |
def main_colocate_with_train_lowlevel():
    """Usage sketch: co-located rollout + training, step-by-step low-level API.

    NOTE(review): annotation-only pseudocode — the annotated names are never
    bound, so this function is illustrative, not runnable.
    """
    data_mgr: DataManager
    pg: PlacementGroup
    rollout_ctl: RolloutController(pg)
    env: Environment(rollout_ctl)
    train_ctl: TrainController(pg)

    eval_data_mgr: DataManager
    evaluator: Evaluator
    total_rollouts: int

    for i in range(total_rollouts):
        env.produce_batch(data_mgr)

        batch: list[TrainItem] = data_mgr.get_batch()

        # below is equivalent to train_ctl.fit(batch)
        # NOTE(review): class-level call — Packer.pack_pad_dispatch must be a
        # static/class method for this to work; confirm its definition.
        batch = Packer.pack_pad_dispatch(batch)
        batch = train_ctl.compute_old_logprobs(batch)
        batch = train_ctl.compute_ref_logprobs(batch)
        batch = train_ctl.compute_values(batch)
        batch = train_ctl.compute_advantages(batch)  # TODO: AdvEstimator
        metrics = train_ctl.train(batch)

        log_metrics(metrics)

        train_ctl.sync_weights(rollout_ctl)

        env.produce_batch(eval_data_mgr)
        eval_metrics = evaluator.evaluate(eval_data_mgr.get_batch())
        log_metrics(eval_metrics)
| 256 | + |
| 257 | + |
def main_separate():
    """Usage sketch: separated deployment — rollout and training on different
    placement groups, with a background producer thread.

    NOTE(review): annotation-only pseudocode — the annotated names are never
    bound, so this function is illustrative, not runnable.
    """
    data_mgr: DataManager
    pg1: PlacementGroup
    rollout_ctl: RolloutController(pg1)
    pg1_2: PlacementGroup
    rollout_ctl_2: RolloutController(pg1_2)
    # NOTE(review): Environment.__init__ takes a single rollout controller —
    # the two-controller form sketched here is not implemented yet.
    env: Environment(rollout_ctl, rollout_ctl_2)

    pg2: PlacementGroup
    train_ctl: TrainController(pg2)

    eval_data_mgr: DataManager
    evaluator: Evaluator

    # Producer runs continuously in the background, filling the replay buffer.
    producer_thread = threading.Thread(target=env.produce_loop, args=(data_mgr,))
    producer_thread.start()

    total_rollouts: int
    for i in range(total_rollouts):
        batch: list[TrainItem] = data_mgr.get_batch()
        metrics = train_ctl.fit(batch)
        log_metrics(metrics)

        train_ctl.sync_weights(rollout_ctl)

        env.produce_batch(eval_data_mgr)  # takes priority over env.produce_loop
        eval_metrics = evaluator.evaluate(eval_data_mgr.get_batch())
        log_metrics(eval_metrics)
0 commit comments