Skip to content

Commit 08df53a

Browse files
authored
design draft for rl components API (#1477)
* first version * modify based on refactor_rollout_demo * add more comments * move generate_group from Env to Agent * rename to RolloutState and Environment to be same with doc
1 parent e368d87 commit 08df53a

File tree

1 file changed

+285
-0
lines changed

1 file changed

+285
-0
lines changed

design/component_rl.py

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
################################### imports ######################################
2+
from typing import Any, Callable
3+
import asyncio
4+
from enum import Enum
5+
from torch.utils.data import DataLoader
6+
import threading
7+
8+
from xtuner.v1.ray.rollout.controller import SampleParams
9+
from xtuner.v1.data_proto.rl_data import SampleParams # TODO: 删掉一个?
10+
from xtuner.v1.data_proto.sequence_context import SequenceContext
11+
from xtuner.v1.loss.base_loss_ctx import BaseLossContext
12+
13+
def load_tokenizer(hf_checkpoint, trust_remote_code=True): ...
14+
def load_processor(hf_checkpoint, trust_remote_code=True): ...
15+
16+
class PlacementGroup: ...
17+
18+
def log_metrics(metrics: dict): ...
19+
20+
class TrainItem:
21+
seq_ctx: SequenceContext
22+
loss_ctx: BaseLossContext
23+
24+
25+
################################### Main components ######################################
26+
class Status(Enum):
27+
INIT = "init"
28+
COMPLETED = "completed"
29+
ABORTED = "aborted"
30+
FAILED = "failed"
31+
ARCHIVED = "archived"
32+
EXPIRED = "expired"
33+
SKIPPED = "skipped"
34+
35+
36+
class RolloutState: # RolloutState:
37+
# message: list
38+
tokens: list[int] # 每一次实际输入
39+
40+
uid: int
41+
session_id: int | None = None
42+
prompt_ids: list[int]
43+
response: str
44+
response_ids: list[int] # 每一次实际输出,覆盖写
45+
logprobs: list[float]
46+
routed_experts: list[int] | None = None
47+
reward: float | list[float] | list[dict] | None = None
48+
loss_mask: list[int] | None = None # tokens + response_ids的长度
49+
state: Status = Status.INIT
50+
sample_parms: SampleParams | None = None
51+
tools: list | None = None
52+
tool_choice: str | None = None
53+
mm_infer_info: dict[str, Any]
54+
mm_train_info: dict[str, Any]
55+
finish_reason: str | None = None
56+
staleness: int = 0
57+
58+
59+
class RolloutController:
60+
async def generate_sample(self, sample: RolloutState) -> RolloutState: ...
61+
62+
63+
class Judge:
64+
def judge(self, sample: RolloutState) -> RolloutState: ...
65+
66+
67+
class Agent: # Agent负责一条轨迹样本的生成
68+
async def generate_sample(self, sample: RolloutState) -> RolloutState: ...
69+
async def generate_group(self, sample_fn: Callable[[], list[RolloutState]], data_mgr: "DataManager") -> list[RolloutState]: ...
70+
71+
class SingleTurnAgent(Agent):
72+
def __init__(self, rollout_ctl: RolloutController, hf_checkpoint, sample_params=SampleParams(), judge_cfg: dict = None) -> None:
73+
# persistent state for the generation process
74+
self.rollout_ctl = rollout_ctl
75+
self.hf_checkpoint = hf_checkpoint
76+
self.tokenizer = load_tokenizer(hf_checkpoint, trust_remote_code=True)
77+
self.processor = load_processor(hf_checkpoint, trust_remote_code=True)
78+
self.sample_params = sample_params
79+
self.judge = Judge() if judge_cfg is not None else None
80+
81+
async def generate_sample(self, sample: RolloutState) -> RolloutState:
82+
sample = await self.rollout_ctl.generate_sample(sample)
83+
if self.judge is not None:
84+
sample = self.judge.judge(sample)
85+
return sample
86+
87+
async def generate_group(self, sample_fn: Callable[[], list[RolloutState]], data_mgr: "DataManager") -> Status:
88+
pending_tasks = []
89+
90+
group_samples: list[RolloutState] = sample_fn() # list of prompt_k Sample
91+
for sample in group_samples:
92+
task = asyncio.create_task(self.generate_sample(sample))
93+
pending_tasks.append(task)
94+
95+
generated_samples = asyncio.gather(*pending_tasks)
96+
97+
group_samples = await generated_samples
98+
data_mgr.add_to_replay_buffer(group_samples)
99+
return Status.COMPLETED
100+
101+
102+
class MultiTurnAgent(Agent):
103+
...
104+
105+
106+
class MultiTurnToolAgent(Agent):
107+
...
108+
109+
110+
class DataManager:
111+
dataloader: DataLoader
112+
replay_buffer: list[list[RolloutState]]
113+
114+
def sample_from_dataset(self) -> list[RolloutState]: ... # get from dataloader
115+
116+
def add_to_replay_buffer(self, samples: list[RolloutState]): ...
117+
118+
def get_batch(self) -> list[TrainItem]: ... # get from replay_buffer and convert to TrainItem
119+
120+
class ProduceStrategy: # Scheduler负责调度多个样本的生成,里面可以有超发、异步、重排长短样本等优化
121+
async def produce_batch(self, batch_size: int, data_mgr: DataManager, agent: Agent): ...
122+
123+
class SyncProduceStrategy(ProduceStrategy):
124+
async def produce_batch(self, batch_size: int, data_mgr: DataManager, agent: Agent):
125+
data_concurrency = batch_size
126+
127+
pending_tasks = []
128+
for _ in range(data_concurrency):
129+
task = asyncio.create_task(agent.generate_group(data_mgr.sample_from_dataset, data_mgr))
130+
pending_tasks.append(task)
131+
132+
completed_sample_count = 0
133+
while completed_sample_count < data_concurrency:
134+
if not pending_tasks:
135+
print("All tasks are done but not enough samples collected.")
136+
break
137+
done_tasks, pending_tasks = await asyncio.wait(pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED)
138+
for task in done_tasks:
139+
try:
140+
status: Status = await task
141+
if status == Status.COMPLETED:
142+
completed_sample_count += 1
143+
except Exception as e:
144+
print(f"Error in generating trajectory: {e}")
145+
146+
147+
class AsyncProduceStrategy(ProduceStrategy):
148+
def __init__(
149+
self,
150+
staleness_threshold: float = 0.0,
151+
enable_partial_rollout: bool = False,
152+
tail_batch_trigger_size: int = 0,
153+
tail_batch_candidate_step: int = 0,
154+
):
155+
class _Buffer: ...
156+
self.buffer = _Buffer(enable_partial_rollout, tail_batch_candidate_step, tail_batch_trigger_size)
157+
158+
async def produce_batch(self, batch_size: int, data_mgr: DataManager, agent: Agent):
159+
# hack sample_fn from data_mgr.sample_from_dataset and self.buffer.sample()
160+
pass
161+
162+
163+
class Environment:
164+
def __init__(self, rollout_ctl: RolloutController):
165+
self._agent: Agent = SingleTurnAgent(rollout_ctl)
166+
self._scheduler: ProduceStrategy = SyncProduceStrategy()
167+
168+
async def produce_batch(self, data_mgr: DataManager, batch_size: int):
169+
await self._scheduler.produce_batch(batch_size, data_mgr, self._agent)
170+
171+
def produce_loop(self, data_mgr: DataManager):
172+
pass
173+
174+
175+
class TrainController:
176+
# high level API
177+
def fit(self, batch: list[TrainItem]) -> dict: ...
178+
# low level API
179+
def compute_old_logprobs(self, batch: list[TrainItem]) -> list[TrainItem]: ...
180+
def compute_ref_logprobs(self, batch: list[TrainItem]) -> list[TrainItem]: ...
181+
def compute_values(self, batch: list[TrainItem]) -> list[TrainItem]: ...
182+
def compute_advantages(self, batch: list[TrainItem]) -> list[TrainItem]: ...
183+
def train(self, batch: list[TrainItem]) -> dict: ...
184+
def sync_weights(self, rollout_ctl: RolloutController): ...
185+
186+
187+
class Evaluator: # 根据rollout输出的batch,计算评估指标。本身并不负责rollout。
188+
def evaluate(self, batch: list[RolloutState]) -> dict: ...
189+
190+
191+
################################### Usage example with components #########################################
192+
# 弱化Trainer:Trainer中代码尽量少,尽量用componet来组织代码。下面是几种典型Trainer的组织方式。
193+
194+
def main_colocate_with_train_highlevel():
195+
# rollout_ctl, train_ctl, data_mgr, env, evaluator等对象都是主进程中本地对象,并不是ray actor。这样:
196+
# 1. 保证一大部分的数据传递无需跨机传输,方便统一管理
197+
# 2. 减少ray引入的debug和维护难度
198+
pg: PlacementGroup
199+
rollout_ctl: RolloutController(pg)
200+
train_ctl: TrainController(pg)
201+
202+
data_mgr: DataManager
203+
env: Environment(rollout_ctl)
204+
eval_data_mgr: DataManager
205+
evaluator: Evaluator
206+
total_rollouts: int
207+
208+
for i in range(total_rollouts):
209+
env.produce_batch(data_mgr)
210+
211+
train_batch: list[TrainItem] = data_mgr.get_batch()
212+
metrics = train_ctl.fit(train_batch)
213+
log_metrics(metrics)
214+
215+
train_ctl.sync_weights(rollout_ctl)
216+
217+
env.produce_batch(eval_data_mgr)
218+
eval_metrics = evaluator.evaluate(eval_data_mgr.get_batch())
219+
log_metrics(eval_metrics)
220+
221+
222+
class Packer:
223+
def pack_pad_dispatch(self, samples: list[RolloutState]) -> list[RolloutState]: ...
224+
225+
def main_colocate_with_train_lowlevel():
226+
data_mgr: DataManager
227+
pg: PlacementGroup
228+
rollout_ctl: RolloutController(pg)
229+
env: Environment(rollout_ctl)
230+
train_ctl: TrainController(pg)
231+
232+
eval_data_mgr: DataManager
233+
evaluator: Evaluator
234+
total_rollouts: int
235+
236+
for i in range(total_rollouts):
237+
env.produce_batch(data_mgr)
238+
239+
batch: list[TrainItem] = data_mgr.get_batch()
240+
241+
# below is equivalent to train_ctl.fit(batch)
242+
batch = Packer.pack_pad_dispatch(batch)
243+
batch = train_ctl.compute_old_logprobs(batch)
244+
batch = train_ctl.compute_ref_logprobs(batch)
245+
batch = train_ctl.compute_values(batch)
246+
batch = train_ctl.compute_advantages(batch) # TODO: AdvEstimator
247+
metrics = train_ctl.train(batch)
248+
249+
log_metrics(metrics)
250+
251+
train_ctl.sync_weights(rollout_ctl)
252+
253+
env.produce_batch(eval_data_mgr)
254+
eval_metrics = evaluator.evaluate(eval_data_mgr.get_batch())
255+
log_metrics(eval_metrics)
256+
257+
258+
def main_separate():
259+
data_mgr: DataManager
260+
pg1: PlacementGroup
261+
rollout_ctl: RolloutController(pg1)
262+
pg1_2: PlacementGroup
263+
rollout_ctl_2: RolloutController(pg1_2)
264+
env: Environment(rollout_ctl, rollout_ctl_2)
265+
266+
pg2: PlacementGroup
267+
train_ctl: TrainController(pg2)
268+
269+
eval_data_mgr: DataManager
270+
evaluator: Evaluator
271+
272+
producer_thread = threading.Thread(target=env.produce_loop, args=(data_mgr,))
273+
producer_thread.start()
274+
275+
total_rollouts: int
276+
for i in range(total_rollouts):
277+
batch: list[TrainItem] = data_mgr.get_batch()
278+
metrics = train_ctl.fit(batch)
279+
log_metrics(metrics)
280+
281+
train_ctl.sync_weights(rollout_ctl)
282+
283+
env.produce_batch(eval_data_mgr) # 优先级高于env.produce_loop
284+
eval_metrics = evaluator.evaluate(eval_data_mgr.get_batch())
285+
log_metrics(eval_metrics)

0 commit comments

Comments
 (0)