From 409b09eae696e4b35d447aca82ccdb24111c40ba Mon Sep 17 00:00:00 2001
From: "huanghaian@pjlab.org.cn"
Date: Fri, 30 Jan 2026 13:11:41 +0000
Subject: [PATCH 1/6] add rl demo

---
 examples/v2/gsm8k_env_demo.py   | 68 ++++++++++++++++++++++++++++
 xtuner/v2/base_env_runner.py    | 75 +++++++++++++++++++++++++++++++++
 xtuner/v2/rollout_controller.py | 31 ++++++++++++
 xtuner/v2/rollout_state.py      | 67 +++++++++++++++++++++++++++
 xtuner/v2/utils.py              | 11 +++++
 5 files changed, 252 insertions(+)
 create mode 100644 examples/v2/gsm8k_env_demo.py
 create mode 100644 xtuner/v2/base_env_runner.py
 create mode 100644 xtuner/v2/rollout_controller.py
 create mode 100644 xtuner/v2/rollout_state.py
 create mode 100644 xtuner/v2/utils.py

diff --git a/examples/v2/gsm8k_env_demo.py b/examples/v2/gsm8k_env_demo.py
new file mode 100644
index 000000000..8490c1d4e
--- /dev/null
+++ b/examples/v2/gsm8k_env_demo.py
@@ -0,0 +1,68 @@
+import ray
+import os
+from uuid import uuid4
+
+os.environ["XTUNER_USE_LMDEPLOY"] = "1"
+
+from xtuner.v1.ray.config.worker import RolloutConfig
+from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
+from xtuner.v2.rollout_controller import RolloutController
+from xtuner.v2.base_env_runner import BaseEnvRunner
+from xtuner.v2.rollout_state import ProcessorUtilState, RolloutState
+from xtuner.v1.ray.rollout.controller import SampleParams
+from xtuner.v1.ray.judger.gsm8k import compute_reward
+
+
+if __name__ == '__main__':
+    ray.init(num_cpus=80, ignore_reinit_error=True)
+
+    # model_path = '/mnt/shared-storage-user/llmrazor-share/model/intern-s1-mini-hha-fix_tokenizer'
+    model_path = '/mnt/shared-storage-user/llmrazor-share/model/Qwen3-8B'
+
+    resources = AcceleratorResourcesConfig(
+        accelerator="GPU",
+        num_workers=1,  # 1 or 8
+        num_cpus_per_worker=12,
+        cpu_memory_per_worker=16 * 1024 ** 3,  # 16 GB
+    )
+    pg = AutoAcceleratorWorkers.build_placement_group(resources)
+
+    # 2. rollout
+    rollout_config = RolloutConfig(
+        device=resources.accelerator,
+        model_path=model_path,
+        dtype="bfloat16",
+        tensor_parallel_size=1,
+        expert_parallel_size=1,
+        gpu_memory_utilization=0.75
+    )
+
+    rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg)
+
+    async def gsm8k_generate(data_item, processor_utils_state: ProcessorUtilState, rollout_controller: RolloutController, judger):
+        input_ids = processor_utils_state.tokenizer.apply_chat_template(data_item["prompt"], return_tensors="pt", return_dict=True)["input_ids"][0].tolist()
+        rollout_state = RolloutState(uid=uuid4().int, prompt_ids=input_ids)
+        rollout_state = await rollout_controller.generate.remote(rollout_state)
+
+        # reward = compute_reward(processor_utils_state, data_item, rollout_state)
+        # rollout_state.rewards = reward
+        return rollout_state
+
+    processor_utils_state = ProcessorUtilState(hf_checkpoint=model_path, sample_params=SampleParams())
+    gsm8k_env_runner = ray.remote(BaseEnvRunner).remote(rollout_controller, processor_utils_state=processor_utils_state, generate_external=gsm8k_generate)
+
+    # prompt = [{"content": [{"image_url": {"image_wh": [404, 162],
+    #                                       "url": "images/test_2.jpg"},
+    #                         "type": "image_url"},
+    #                        {
+    #                            "text": "Find the area of the figure to the nearest tenth. You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within  tags. The final answer MUST BE put in \\boxed{}.",
+    #                            "type": "text"}],
+    #            "role": "user"}]
+    # extra_info = {'media_root': '/mnt/shared-storage-user/llmrazor-share/data/geometry3k/'}
+
+    prompt = [{"role": "user", "content": 'Calculate 13+24=', "type": "text"}]
+    data_item = {'prompt': prompt}
+
+    res1 = ray.get(gsm8k_env_runner.generate.remote(data_item))
+    print("Response from infer:", res1)
+    ray.get(rollout_controller.shutdown.remote(), timeout=300)
diff --git a/xtuner/v2/base_env_runner.py b/xtuner/v2/base_env_runner.py
new file mode 100644
index 000000000..3d2f2bdc2
--- /dev/null
+++ b/xtuner/v2/base_env_runner.py
@@ -0,0 +1,75 @@
+
+import asyncio
+from typing import Callable, Optional
+from xtuner.v1.datasets import DataloaderConfig
+
+from .utils import load_function
+from .rollout_controller import RolloutController
+from .rollout_state import Trajectory, ProcessorUtilState
+
+
+class BaseEnvRunner:
+    def __init__(self,
+                 rollout_controller: RolloutController,
+                 processor_utils_state: ProcessorUtilState,
+                 dataloader_cfg: DataloaderConfig | None = None,  # None so this env runner can also run standalone
+                 judger: Callable | None = None,  # None so this env runner can also run standalone
+                 generate_external: Callable | str | None = None,
+                 ):
+
+        self.dataloader = None
+        self.dataloader_iter = None
+        self.cur_epoch = 0
+        self.reduced_consumed_samples = 0
+        if dataloader_cfg is not None:
+            self.dataloader = dataloader_cfg.build()
+            self.dataloader_iter = iter(self.dataloader)
+        self.rollout_controller = rollout_controller
+        self.judger = judger
+        self.processor_utils_state = processor_utils_state
+
+        self.generate_external = generate_external
+        if isinstance(self.generate_external, str):
+            # a dotted path such as "pkg.module.func" is resolved to the callable
+            self.generate_external = load_function(self.generate_external)
+
+    def sample(self) -> dict:
+        try:
+            data = next(self.dataloader_iter)[0]
+        except StopIteration:
+            self.cur_epoch += 1
+            self.dataloader.set_epoch(self.cur_epoch)
+            self.dataloader_iter = iter(self.dataloader)
+            data = next(self.dataloader_iter)[0]
+        self.reduced_consumed_samples += 1
+        return data
+
+    async def generate(self, data_item) -> Optional[Trajectory]:
+        if self.generate_external is not None:
+            return await self.generate_external(data_item, self.processor_utils_state, self.rollout_controller, self.judger)
+        else:
+            raise NotImplementedError
+
+    async def generate_batch(self, batch_size: int) -> list[Trajectory]:
+        data_concurrency = batch_size
+        assert self.dataloader is not None, "Dataloader must be provided for batch generation."
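+        # Fan out one generation task per sample up front; the loop below then
+        # harvests results as they complete, so one slow rollout does not block
+        # the rest of the batch.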
+
+        pending_tasks = []
+        for _ in range(data_concurrency):
+            data_item = self.sample()
+            task = asyncio.create_task(self.generate(data_item))
+            pending_tasks.append(task)
+
+        completed_sample_count = 0
+        batch_trajectories = []
+        while completed_sample_count < batch_size:
+            if not pending_tasks:
+                print("All tasks are done but not enough samples collected.")
+                break
+            done_tasks, pending_tasks = await asyncio.wait(pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED)
+            for task in done_tasks:
+                try:
+                    traj = await task
+                    if traj is not None:
+                        batch_trajectories.append(traj)
+                        completed_sample_count += 1
+                except Exception as e:
+                    print(f"Error in generating trajectory: {e}")
+
+        # TODO: handle over-provisioned samples
+        return batch_trajectories
diff --git a/xtuner/v2/rollout_controller.py b/xtuner/v2/rollout_controller.py
new file mode 100644
index 000000000..7f31ec23b
--- /dev/null
+++ b/xtuner/v2/rollout_controller.py
@@ -0,0 +1,31 @@
+from xtuner.v1.ray.rollout import RolloutController as V1RolloutController
+from .rollout_state import RolloutState, Status
+
+reason_map = {
+    "length": Status.COMPLETED,
+    'aborted': Status.ABORTED,
+    "failed": Status.FAILED,
+}
+
+# Interim solution
+class RolloutController(V1RolloutController):
+
+    async def generate(self, rollout_state: RolloutState):
+
+        # Thin wrapper around the v1 rollout API
+        input_ids = rollout_state.prompt_ids
+        sample_params = rollout_state.sample_params
+        session_id = rollout_state.session_id
+
+        response = await super().rollout(
+            input_ids=input_ids,
+            sample_params=sample_params,
+            session_id=session_id
+        )
+
+        rollout_state.response = response.response
+        rollout_state.response_ids = response.response_ids
+        rollout_state.logprobs = response.logprobs
+        rollout_state.state = response.state
+
+        return rollout_state
diff --git a/xtuner/v2/rollout_state.py b/xtuner/v2/rollout_state.py
new file mode 100644
index 000000000..35d7d6695
--- /dev/null
+++ b/xtuner/v2/rollout_state.py
@@ -0,0 +1,67 @@
+from dataclasses import dataclass, field
+from enum import Enum
+from transformers import AutoTokenizer, AutoProcessor, PreTrainedTokenizerBase
+from transformers.processing_utils import ProcessorMixin
+from xtuner.v1.ray.rollout.controller import SampleParams
+
+
+class Status(Enum):
+    INIT = "init"
+    COMPLETED = "completed"
+    ABORTED = "aborted"
+    FAILED = "failed"
+    ARCHIVED = "archived"
+    EXPIRED = "expired"
+    SKIPPED = "skipped"
+
+
+@dataclass
+class RolloutState:
+    uid: int
+    session_id: int | None = None
+    prompt_ids: list[int] = field(default_factory=list)
+    response: str = ""
+    response_ids: list[int] = field(default_factory=list)
+    logprobs: list[float] = field(default_factory=list)
+    routed_experts: list[int] | None = None
+    state: Status = Status.INIT
+    sample_params: SampleParams | None = None
+    tools: list | None = None
+    tool_choice: str | None = None
+
+
+@dataclass
+class Trajectory:
+    uid: str
+    env: str
+    rollout_state: RolloutState | list[RolloutState]
+    reward: float | list[float] | list[dict]
+
+
+def load_tokenizer(name_or_path: str, **kwargs):
+    return AutoTokenizer.from_pretrained(name_or_path, **kwargs)
+
+
+def load_processor(name_or_path: str, **kwargs):
+    try:
+        proc = AutoProcessor.from_pretrained(name_or_path, **kwargs)
+    except (OSError, ValueError):
+        proc = None
+
+    # If HF returned a tokenizer, discard it: AutoProcessor falls back to the
+    # tokenizer class for text-only checkpoints.
+    if isinstance(proc, PreTrainedTokenizerBase) or not isinstance(proc, ProcessorMixin):
+        proc = None
+
+    return proc
+
+
+# TODO: rename
+# There must be an object that travels along with RolloutState; otherwise the
+# design cannot be extended.
+class ProcessorUtilState:
+    def __init__(self, hf_checkpoint, sample_params=None) -> None:
+        # persistent state for the generation process
+        self.hf_checkpoint = hf_checkpoint
+        self.tokenizer = load_tokenizer(hf_checkpoint, trust_remote_code=True)
+        self.processor = load_processor(hf_checkpoint, trust_remote_code=True)
+        self.sample_params = sample_params if sample_params is not None else SampleParams()
diff --git a/xtuner/v2/utils.py b/xtuner/v2/utils.py
new file mode 100644
index 000000000..05fde3426
--- /dev/null
+++ b/xtuner/v2/utils.py
@@ -0,0 +1,11 @@
+import importlib
+
+def load_function(path):
+    """
+    Load a function from a module.
+    :param path: The path to the function, e.g. "module.submodule.function".
+    :return: The function object.
+    """
+    module_path, _, attr = path.rpartition(".")
+    module = importlib.import_module(module_path)
+    return getattr(module, attr)
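+
+# Usage sketch: load_function("os.path.splitext") returns os.path.splitext.
+# In this repo the dotted path is expected to name a user-defined generation
+# coroutine, e.g. the hypothetical "my_envs.gsm8k.generate", which the env
+# runner receives as `generate_external`.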
From 1c572a2998f5c47020e74e4a7169bcff82b369f4 Mon Sep 17 00:00:00 2001
From: "huanghaian@pjlab.org.cn"
Date: Sat, 31 Jan 2026 03:53:15 +0000
Subject: [PATCH 2/6] add tool call demo

---
 ...8k_env_demo.py => single_turn_env_demo.py} |  32 ++--
 examples/v2/toolcall_env_demo.py              | 151 ++++++++++++++++++
 xtuner/v2/rollout_controller.py               |   2 +-
 xtuner/v2/rollout_state.py                    |  12 +-
 ...ase_env_runner.py => simple_env_runner.py} |  44 +++--
 5 files changed, 213 insertions(+), 28 deletions(-)
 rename examples/v2/{gsm8k_env_demo.py => single_turn_env_demo.py} (72%)
 create mode 100644 examples/v2/toolcall_env_demo.py
 rename xtuner/v2/{base_env_runner.py => simple_env_runner.py} (60%)

diff --git a/examples/v2/gsm8k_env_demo.py b/examples/v2/single_turn_env_demo.py
similarity index 72%
rename from examples/v2/gsm8k_env_demo.py
rename to examples/v2/single_turn_env_demo.py
index 8490c1d4e..509a3d630 100644
--- a/examples/v2/gsm8k_env_demo.py
+++ b/examples/v2/single_turn_env_demo.py
@@ -7,7 +7,7 @@
 from xtuner.v1.ray.config.worker import RolloutConfig
 from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
 from xtuner.v2.rollout_controller import RolloutController
-from xtuner.v2.base_env_runner import BaseEnvRunner
+from xtuner.v2.simple_env_runner import SimpleEnvRunner
 from xtuner.v2.rollout_state import ProcessorUtilState, RolloutState
 from xtuner.v1.ray.rollout.controller import SampleParams
 from xtuner.v1.ray.judger.gsm8k import compute_reward
@@ -39,17 +39,8 @@
     rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg)
 
-    async def gsm8k_generate(data_item, processor_utils_state: ProcessorUtilState, rollout_controller: RolloutController, judger):
-        input_ids = processor_utils_state.tokenizer.apply_chat_template(data_item["prompt"], return_tensors="pt", return_dict=True)["input_ids"][0].tolist()
-        rollout_state = RolloutState(uid=uuid4().int, prompt_ids=input_ids)
-        rollout_state = await rollout_controller.generate.remote(rollout_state)
-
-        # reward = compute_reward(processor_utils_state, data_item, rollout_state)
-        # rollout_state.rewards = reward
-        return rollout_state
-
-    processor_utils_state = ProcessorUtilState(hf_checkpoint=model_path, sample_params=SampleParams())
-    gsm8k_env_runner = ray.remote(BaseEnvRunner).remote(rollout_controller, processor_utils_state=processor_utils_state, generate_external=gsm8k_generate)
+    processor_utils_state = ProcessorUtilState(hf_checkpoint=model_path, sample_params=SampleParams())
+    simple_env_runner = ray.remote(SimpleEnvRunner).remote(rollout_controller, processor_utils_state=processor_utils_state)
 
     # prompt = [{"content": [{"image_url": {"image_wh": [404, 162],
     #                                       "url": "images/test_2.jpg"},
     #                         "type": "image_url"},
     #                        {
     #                            "text": "Find the area of the figure to the nearest tenth. You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within  tags. The final answer MUST BE put in \\boxed{}.",
     #                            "type": "text"}],
     #            "role": "user"}]
     # extra_info = {'media_root': '/mnt/shared-storage-user/llmrazor-share/data/geometry3k/'}
@@ -62,7 +53,22 @@
     prompt = [{"role": "user", "content": 'Calculate 13+24=', "type": "text"}]
     data_item = {'prompt': prompt}
+
+    input_ids = processor_utils_state.tokenizer.apply_chat_template(data_item["prompt"], return_tensors="pt", return_dict=True)["input_ids"][0].tolist()
+    rollout_state = RolloutState(uid=uuid4().int, tokens=input_ids)
 
-    res1 = ray.get(gsm8k_env_runner.generate.remote(data_item))
-    print("Response from infer:", res1)
+    # generate a single sample
+    res1 = ray.get(simple_env_runner.generate.remote(rollout_state))
+    print("Response from infer:", res1)
+
+    # generate one group
+    res1 = ray.get(simple_env_runner.generate_group.remote(rollout_state))
+    print("Response from infer:", res1)
+
+    # generate a batch
+    res1 = ray.get(simple_env_runner.generate_batch.remote(batch_size=64))
+    print("Response from infer:", res1)
+
     ray.get(rollout_controller.shutdown.remote(), timeout=300)
diff --git a/examples/v2/toolcall_env_demo.py b/examples/v2/toolcall_env_demo.py
new file mode 100644
index 000000000..cc9696ee0
--- /dev/null
+++ b/examples/v2/toolcall_env_demo.py
@@ -0,0 +1,151 @@
+import ray
+import os
+from uuid import uuid4
+from copy import deepcopy
+
+os.environ["XTUNER_USE_LMDEPLOY"] = "1"
+
+from xtuner.v1.ray.config.worker import RolloutConfig
+from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
+from xtuner.v2.rollout_controller import RolloutController
+from xtuner.v2.simple_env_runner import SimpleEnvRunner
+from xtuner.v2.rollout_state import ProcessorUtilState, RolloutState, Status
+from xtuner.v1.ray.rollout.controller import SampleParams
+from xtuner.v1.ray.judger.gsm8k import compute_reward
+
+
+TOOL_CONFIGS = {
+    "max_turns": 16,
+    "max_tool_calls": 16,
+    "tool_concurrency": 32,  # Aggressive: 32 concurrent processes
+    # Python interpreter settings
+    "python_timeout": 120,  # 2 minutes for complex calculations
+    "python_memory_limit": "4GB",  # 4GB per Python process
+    "python_cpu_limit": 1,
+    # Memory management settings
+    "max_memory_usage": 12288,  # 12GB total (75% of 16GB)
+    "cleanup_threshold": 6144,  # 6GB
+    "aggressive_cleanup_threshold": 3072,  # 3GB
+    "force_cleanup_threshold": 9216,  # 9GB
+}
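+
+# `execute_predictions` is assumed to be supplied by the user's environment: it
+# parses tool calls out of the model response, runs them, and returns
+# (observation, done). A minimal stand-in for a dry run could be:
+#
+#   async def execute_predictions(response: str) -> tuple[str, bool]:
+#       return "", True  # no tool call detected -> episode is done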
+
+# Once the generation function is user-defined, partial rollout becomes hard to support.
+async def gsm8k_with_tools_generate(rollout_state: RolloutState,
+                                    processor_utils_state: ProcessorUtilState,
+                                    rollout_controller: RolloutController,
+                                    judger) -> RolloutState:
+    # You can work with raw tokens here, or with messages + tools
+    prompt = processor_utils_state.tokenizer.apply_chat_template(
+        rollout_state.message,
+        tools=rollout_state.tools,
+        tokenize=False,
+        add_generation_prompt=True,
+    )
+    prompt_tokens_ids = processor_utils_state.tokenizer(prompt, add_special_tokens=False)["input_ids"]
+
+    response = ""
+    response_token_ids = []
+    loss_masks = []
+    tool_call_count = 0  # Track actual tool call rounds
+    rollout_log_probs = []
+    init_rollout_state = deepcopy(rollout_state)
+    for turn in range(TOOL_CONFIGS["max_turns"]):
+        current_token_ids = prompt_tokens_ids + response_token_ids
+        rollout_state.tokens = current_token_ids
+        init_rollout_state.tokens = rollout_state.tokens
+
+        # TODO: watch how much data is shipped between ray actors; in principle
+        # nothing the callee does not need should be sent.
+        # The object sent to the rollout_controller only needs the appended
+        # tokens; it ignores everything else.
+
+        # Compared with sending a raw POST request this is a bit more cumbersome
+        # and has a higher learning cost, but the management layer saves you
+        # from juggling n URLs and doing the routing yourself.
+        # TODO: switch to URL-based POST requests?
+        rollout_state = await rollout_controller.generate.remote(rollout_state)
+
+        init_rollout_state.state = rollout_state.state
+
+        response += rollout_state.response
+        response_token_ids += rollout_state.response_ids
+        rollout_log_probs += rollout_state.logprobs
+        loss_masks += [1] * len(rollout_state.response_ids)
+
+        if rollout_state.state == Status.ABORTED:
+            # Move the rollout_state contents into init_rollout_state, which is
+            # the globally kept copy.
+            init_rollout_state.state = rollout_state.state
+            init_rollout_state.response_ids = response_token_ids
+            init_rollout_state.logprobs = rollout_log_probs
+            init_rollout_state.loss_mask = loss_masks
+            return init_rollout_state
+
+        # execute the tool
+        next_obs, done = await execute_predictions(rollout_state.response)
+        if done:
+            break
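+        # Tool observations are appended as ordinary tokens but must not be
+        # trained on: they get loss_mask 0 and a placeholder logprob of 0.0.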
+        obs_tokens_ids = processor_utils_state.tokenizer(next_obs, add_special_tokens=False)["input_ids"]
+        response += next_obs
+        response_token_ids += obs_tokens_ids
+        rollout_log_probs += [0.0] * len(obs_tokens_ids)
+        loss_masks += [0] * len(obs_tokens_ids)
+        tool_call_count += 1
+
+        if tool_call_count >= TOOL_CONFIGS["max_tool_calls"]:
+            break
+
+    # Move the rollout_state contents into init_rollout_state, which is the
+    # globally kept copy.
+    init_rollout_state.response = response
+    init_rollout_state.response_ids = response_token_ids
+    init_rollout_state.logprobs = rollout_log_probs
+    init_rollout_state.loss_mask = loss_masks
+    return init_rollout_state
+
+
+if __name__ == '__main__':
+
+    ray.init(num_cpus=80, ignore_reinit_error=True)
+
+    # model_path = '/mnt/shared-storage-user/llmrazor-share/model/intern-s1-mini-hha-fix_tokenizer'
+    model_path = '/mnt/shared-storage-user/llmrazor-share/model/Qwen3-8B'
+
+    resources = AcceleratorResourcesConfig(
+        accelerator="GPU",
+        num_workers=1,  # 1 or 8
+        num_cpus_per_worker=12,
+        cpu_memory_per_worker=16 * 1024 ** 3,  # 16 GB
+    )
+    pg = AutoAcceleratorWorkers.build_placement_group(resources)
+
+    # 2. rollout
+    rollout_config = RolloutConfig(
+        device=resources.accelerator,
+        model_path=model_path,
+        dtype="bfloat16",
+        tensor_parallel_size=1,
+        expert_parallel_size=1,
+        gpu_memory_utilization=0.75
+    )
+
+    rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg)
+
+    processor_utils_state = ProcessorUtilState(hf_checkpoint=model_path, sample_params=SampleParams())
+    simple_env_runner = ray.remote(SimpleEnvRunner).remote(rollout_controller,
+                                                           processor_utils_state=processor_utils_state,
+                                                           generate_external=gsm8k_with_tools_generate)
+
+    prompt = [{"role": "user", "content": 'Calculate 13+24=', "type": "text"}]
+    data_item = {'prompt': prompt}
+
+    input_ids = processor_utils_state.tokenizer.apply_chat_template(data_item["prompt"], return_tensors="pt", return_dict=True)["input_ids"][0].tolist()
+    rollout_state = RolloutState(uid=uuid4().int, message=prompt, tokens=input_ids)
+
+    # generate a single sample
+    res1 = ray.get(simple_env_runner.generate.remote(rollout_state))
+    print("Response from infer:", res1)
+
+    # generate one group
+    res1 = ray.get(simple_env_runner.generate_group.remote(rollout_state))
+    print("Response from infer:", res1)
+
+    # generate a batch
+    res1 = ray.get(simple_env_runner.generate_batch.remote(batch_size=64))
+    print("Response from infer:", res1)
+
+    ray.get(rollout_controller.shutdown.remote(), timeout=300)
diff --git a/xtuner/v2/rollout_controller.py b/xtuner/v2/rollout_controller.py
index 7f31ec23b..b945bf8b5 100644
--- a/xtuner/v2/rollout_controller.py
+++ b/xtuner/v2/rollout_controller.py
@@ -13,7 +13,7 @@ class RolloutController(V1RolloutController):
     async def generate(self, rollout_state: RolloutState):
 
         # Thin wrapper around the v1 rollout API
-        input_ids = rollout_state.prompt_ids
+        input_ids = rollout_state.tokens
         sample_params = rollout_state.sample_params
         session_id = rollout_state.session_id
 
diff --git a/xtuner/v2/rollout_state.py b/xtuner/v2/rollout_state.py
index 35d7d6695..8b816524e 100644
--- a/xtuner/v2/rollout_state.py
+++ b/xtuner/v2/rollout_state.py
@@ -18,6 +19,10 @@ class Status(Enum):
 @dataclass
 class RolloutState:
     uid: int
+    # required dataset outputs
+    message: list = field(default_factory=list)
+    tokens: list[int] = field(default_factory=list)
+
     session_id: int | None = None
     prompt_ids: list[int] = field(default_factory=list)
@@ -25,18 +30,19 @@ class RolloutState:
     response_ids: list[int] = field(default_factory=list)
     logprobs: list[float] = field(default_factory=list)
     routed_experts: list[int] | None = None
+    reward: float | list[float] | list[dict] | None = None
+    loss_mask: list[int] | None = None
     state: Status = Status.INIT
     sample_params: SampleParams | None = None
     tools: list | None = None
     tool_choice: str | None = None
 
 
+# TODO: what is this object actually for? Unused for now; keeping it would make
+# the objects in the inner loop inconsistent, and partial rollout gets awkward.
 @dataclass
 class Trajectory:
-    uid: str
-    env: str
     rollout_state: RolloutState | list[RolloutState]
-    reward: float | list[float] | list[dict]
+    env: str = 'default'
 
 
 def load_tokenizer(name_or_path: str, **kwargs):
diff --git a/xtuner/v2/base_env_runner.py b/xtuner/v2/simple_env_runner.py
similarity index 60%
rename from xtuner/v2/base_env_runner.py
rename to xtuner/v2/simple_env_runner.py
index 3d2f2bdc2..434faadf0 100644
--- a/xtuner/v2/base_env_runner.py
+++ b/xtuner/v2/simple_env_runner.py
@@ -1,19 +1,19 @@
 
 import asyncio
-from typing import Callable, Optional
+from typing import Callable
 from xtuner.v1.datasets import DataloaderConfig
 
 from .utils import load_function
 from .rollout_controller import RolloutController
-from .rollout_state import Trajectory, ProcessorUtilState
+from .rollout_state import ProcessorUtilState, RolloutState
 
 
-class BaseEnvRunner:
+# TODO: this class is doing a bit too much; do we need a separate base env runner?
+class SimpleEnvRunner:
     def __init__(self,
                  rollout_controller: RolloutController,
                  processor_utils_state: ProcessorUtilState,
                  dataloader_cfg: DataloaderConfig | None = None,  # None so this env runner can also run standalone
-                 judger: Callable | None = None,  # None so this env runner can also run standalone
+                 judger: Callable | None = None,  # None so this env runner can also run standalone; can be a plain callable or an actor worker
                  generate_external: Callable | str | None = None,
                  ):
 
@@ -23,6 +23,8 @@ def __init__(self,
         self.rollout_controller = rollout_controller
         self.judger = judger
         self.processor_utils_state = processor_utils_state
+
+        self.prompt_repeat_k = 1  # passed in from outside
 
         self.generate_external = generate_external
@@ -39,20 +41,40 @@ def sample(self) -> dict:
             self.reduced_consumed_samples += 1
         return data
 
-    async def generate(self, data_item) -> Optional[Trajectory]:
+    async def generate(self, rollout_state: RolloutState) -> RolloutState:
         if self.generate_external is not None:
-            return await self.generate_external(data_item, self.processor_utils_state, self.rollout_controller, self.judger)
+            # TODO: this branch probably does not support partial rollout
+            return await self.generate_external(rollout_state, self.processor_utils_state, self.rollout_controller, self.judger)
         else:
-            raise NotImplementedError
+            # default: the simplest single-turn mode
+            rollout_state = await self.rollout_controller.generate.remote(rollout_state)
+
+            reward = 0.0
+            if self.judger is not None:
+                if asyncio.iscoroutinefunction(self.judger):
+                    reward = await self.judger(rollout_state)
+                else:
+                    reward = self.judger(rollout_state)
+            rollout_state.reward = reward
+            return rollout_state
+
+    async def generate_group(self, rollout_state: RolloutState) -> list[RolloutState]:
+        pending_tasks = []
+        for _ in range(self.prompt_repeat_k):
+            task = asyncio.create_task(self.generate(rollout_state))
+            pending_tasks.append(task)
+
+        trajectories = asyncio.gather(*pending_tasks)
+        return await trajectories
 
-    async def generate_batch(self, batch_size: int) -> list[Trajectory]:
+    async def generate_batch(self, batch_size: int) -> list[RolloutState]:
         data_concurrency = batch_size
         assert self.dataloader is not None, "Dataloader must be provided for batch generation."
 
         pending_tasks = []
         for _ in range(data_concurrency):
-            data_item = self.sample()
-            task = asyncio.create_task(self.generate(data_item))
+            rollout_state = self.sample()
+            task = asyncio.create_task(self.generate_group(rollout_state))
             pending_tasks.append(task)
 
         completed_sample_count = 0

From a9c28f8b8a018e44639201810c10a232882c00d7 Mon Sep 17 00:00:00 2001
From: "huanghaian@pjlab.org.cn"
Date: Sat, 31 Jan 2026 10:52:47 +0000
Subject: [PATCH 3/6] add proxy

---
 xtuner/v2/proxy_async_env_runner.py | 19 ++++++
 xtuner/v2/simple_env_runner.py      | 99 ++++++++++++++++++++---------
 2 files changed, 87 insertions(+), 31 deletions(-)
 create mode 100644 xtuner/v2/proxy_async_env_runner.py

diff --git a/xtuner/v2/proxy_async_env_runner.py b/xtuner/v2/proxy_async_env_runner.py
new file mode 100644
index 000000000..8dfde8a0c
--- /dev/null
+++ b/xtuner/v2/proxy_async_env_runner.py
@@ -0,0 +1,19 @@
+from .simple_env_runner import SimpleEnvRunner
+from .rollout_state import RolloutState
+
+# transparent to the user
+class AsyncProxyEnvRunner:
+
+    def __init__(self):
+        self.base_env_runner: SimpleEnvRunner = None
+
+    def set_base_env_runner(self, base_env_runner: SimpleEnvRunner):
+        self.base_env_runner = base_env_runner
+
+    async def async_generate_batch(self,
+                                   batch_size: int,
+                                   prompt_repeat_k: int,
+                                   staleness_threshold: float = 0.0,
+                                   enable_partial_rollout: bool = False,
+                                   ) -> list[RolloutState]:
+        raise NotImplementedError("Please implement async_generate_batch method for your custom generation strategy.")
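+
+# A concrete strategy is expected to subclass this proxy and implement
+# async_generate_batch; a trivial (hypothetical) fallback that just reuses the
+# synchronous path could look like:
+#
+#   class NaiveAsyncProxy(AsyncProxyEnvRunner):
+#       async def async_generate_batch(self, batch_size, prompt_repeat_k, **kwargs):
+#           return await self.base_env_runner.generate_batch(batch_size, prompt_repeat_k)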
diff --git a/xtuner/v2/simple_env_runner.py b/xtuner/v2/simple_env_runner.py
index 434faadf0..50a16fd51 100644
--- a/xtuner/v2/simple_env_runner.py
+++ b/xtuner/v2/simple_env_runner.py
@@ -1,36 +1,44 @@
+
 import asyncio
 from typing import Callable
 from xtuner.v1.datasets import DataloaderConfig
-
 from .utils import load_function
 from .rollout_controller import RolloutController
 from .rollout_state import ProcessorUtilState, RolloutState
 
-# TODO: this class is doing a bit too much; do we need a separate base env runner?
+
+# This class is responsible for defining all the interfaces, and it provides a
+# synchronous rollout that covers most needs: single-turn, multi-turn, and
+# agent scenarios.
+# For async we assume there may be two completely different implementations;
+# each can simply extend this class via inheritance.
 class SimpleEnvRunner:
     def __init__(self,
                  rollout_controller: RolloutController,
-                 processor_utils_state: ProcessorUtilState,
-                 dataloader_cfg: DataloaderConfig | None = None,  # None so this env runner can also run standalone
+                 processor_utils_state: ProcessorUtilState | None = None,
                  judger: Callable | None = None,  # None so this env runner can also run standalone; can be a plain callable or an actor worker
+                 dataloader_cfg: DataloaderConfig | None = None,  # None so this env runner can also run standalone
                  generate_external: Callable | str | None = None,
+                 # Ideally the user never notices this class: after customizing
+                 # their own logic on top of simple_env_runner, they pass in a
+                 # proxy like this one to get an async strategy, keeping the
+                 # two decoupled.
+                 async_proxy_runner=None,  # proxy runner for async scenarios
                  ):
         self.rollout_controller = rollout_controller
         self.judger = judger
         self.processor_utils_state = processor_utils_state
 
-        self.prompt_repeat_k = 1  # passed in from outside
-
+        self.dataloader = None
+        self.dataloader_iter = None
+        self.cur_epoch = 0
+        if dataloader_cfg is not None:
+            self.dataloader = dataloader_cfg.build()
+            self.dataloader_iter = iter(self.dataloader)
+
         self.generate_external = generate_external
         if isinstance(self.generate_external, str):
             self.generate_external = load_function(self.generate_external)
+
+        self.async_proxy_runner = async_proxy_runner
+        if self.async_proxy_runner is not None:
+            # circular reference; could this become a problem?
+            self.async_proxy_runner.set_base_env_runner(self)
 
-    def sample(self) -> dict:
+    def sample_from_dataset(self) -> RolloutState:
         try:
             data = next(self.dataloader_iter)[0]
         except StopIteration:
             self.cur_epoch += 1
             self.dataloader.set_epoch(self.cur_epoch)
             self.dataloader_iter = iter(self.dataloader)
             data = next(self.dataloader_iter)[0]
-        self.reduced_consumed_samples += 1
         return data
+
+    # generate a single sample
+    async def generate_sample(self, rollout_state: RolloutState) -> RolloutState:
+        # default: the simplest single-turn mode
+        rollout_state = await self.rollout_controller.generate.remote(rollout_state)
+
+        reward = 0.0
+        if self.judger is not None:
+            if asyncio.iscoroutinefunction(self.judger):
+                reward = await self.judger(rollout_state)
+            else:
+                reward = await self.run_judger_worker(rollout_state)
+        rollout_state.reward = reward
+        return rollout_state
+
+    async def run_judger_worker(self, rollout_state):
+        # There may be multiple judge workers, which turns this into a
+        # scheduling problem. Users can override this method to implement
+        # their own scheduling strategy.
+        reward = self.judger(rollout_state)
+        return reward
 
     async def generate(self, rollout_state: RolloutState) -> RolloutState:
         if self.generate_external is not None:
             # TODO: this branch probably does not support partial rollout
             return await self.generate_external(rollout_state, self.processor_utils_state, self.rollout_controller, self.judger)
         else:
-            # default: the simplest single-turn mode
-            rollout_state = await self.rollout_controller.generate.remote(rollout_state)
-
-            reward = 0.0
-            if self.judger is not None:
-                if asyncio.iscoroutinefunction(self.judger):
-                    reward = await self.judger(rollout_state)
-                else:
-                    reward = self.judger(rollout_state)
-            rollout_state.reward = reward
-            return rollout_state
+            return await self.generate_sample(rollout_state)
 
-    async def generate_group(self, rollout_state: RolloutState) -> list[RolloutState]:
+    # generate one group of samples
+    async def generate_group(self, rollout_state: RolloutState, prompt_repeat_k: int) -> list[RolloutState]:
         pending_tasks = []
-        for _ in range(self.prompt_repeat_k):
+        for _ in range(prompt_repeat_k):
             task = asyncio.create_task(self.generate(rollout_state))
             pending_tasks.append(task)
 
         trajectories = asyncio.gather(*pending_tasks)
         return await trajectories
 
-    async def generate_batch(self, batch_size: int) -> list[RolloutState]:
+    # Uninterruptible batch generation, for the synchronous setting
+    async def generate_batch(self,
+                             batch_size: int,
+                             prompt_repeat_k: int,
+                             ) -> list[RolloutState]:
         data_concurrency = batch_size
         assert self.dataloader is not None, "Dataloader must be provided for batch generation."
 
         pending_tasks = []
         for _ in range(data_concurrency):
-            rollout_state = self.sample()
-            task = asyncio.create_task(self.generate_group(rollout_state))
+            rollout_state = self.sample_from_dataset()
+            task = asyncio.create_task(self.generate_group(rollout_state, prompt_repeat_k))
             pending_tasks.append(task)
 
         completed_sample_count = 0
@@ -93,5 +115,20 @@ async def generate_batch(self,
             except Exception as e:
                 print(f"Error in generating trajectory: {e}")
 
-        # TODO: handle over-provisioned samples
         return batch_trajectories
+
+    # =====================================================================
+    # The interfaces below are all about async rollout.
+
+    # for the interruptible generation setting
+    async def async_generate_batch(self,
+                                   batch_size: int,
+                                   prompt_repeat_k: int,
+                                   staleness_threshold: float = 0.0,
+                                   enable_partial_rollout: bool = False,
+                                   ) -> list[RolloutState]:
+        return await self.async_proxy_runner.async_generate_batch(
+            batch_size,
+            prompt_repeat_k,
+            staleness_threshold,
+            enable_partial_rollout)

From 95f1a92251b4ce23fe8baf59fa093aff89261f33 Mon Sep 17 00:00:00 2001
From: "huanghaian@pjlab.org.cn"
Date: Sat, 31 Jan 2026 13:24:36 +0000
Subject: [PATCH 4/6] add proxy code

---
 examples/v2/toolcall_env_demo.py    |  5 ++-
 xtuner/v2/proxy_async_env_runner.py | 56 ++++++++++++++++++++++++++++-
 xtuner/v2/simple_env_runner.py      |  6 ++--
 3 files changed, 61 insertions(+), 6 deletions(-)

diff --git a/examples/v2/toolcall_env_demo.py b/examples/v2/toolcall_env_demo.py
index cc9696ee0..f1ccd33e7 100644
--- a/examples/v2/toolcall_env_demo.py
+++ b/examples/v2/toolcall_env_demo.py
@@ -29,12 +29,15 @@
 }
 
-# Once the generation function is user-defined, partial rollout becomes hard to support.
+# Once the generation function is user-defined, partial rollout becomes hard to support. Actually no:
+# partial rollout turns out to be easy to support, because the incoming rollout_state carries a state
+# field. From that state we can tell which stage we are at and what to do next (keep rolling out, call
+# a tool, and so on). The whole thing is a state machine and can be restored.
 async def gsm8k_with_tools_generate(rollout_state: RolloutState,
                                     processor_utils_state: ProcessorUtilState,
                                     rollout_controller: RolloutController,
                                     judger) -> RolloutState:
     # You can work with raw tokens here, or with messages + tools
+
+    # The incoming state tells us where we currently are and thus what to do
+    # next; if the state is not INIT, partial rollout must have been enabled.
     prompt = processor_utils_state.tokenizer.apply_chat_template(
         rollout_state.message,
diff --git a/xtuner/v2/proxy_async_env_runner.py b/xtuner/v2/proxy_async_env_runner.py
index 8dfde8a0c..4b6f870d9 100644
--- a/xtuner/v2/proxy_async_env_runner.py
+++ b/xtuner/v2/proxy_async_env_runner.py
@@ -1,3 +1,4 @@
+import asyncio
 from .simple_env_runner import SimpleEnvRunner
-from .rollout_state import RolloutState
+from .rollout_state import RolloutState, Status
@@ -6,14 +7,67 @@ class AsyncProxyEnvRunner:
 
     def __init__(self):
         self.base_env_runner: SimpleEnvRunner = None
+
+        # keep it simple for now
+        self.expired_buffer: list[RolloutState] = []
+        self.aborted_buffer: list[RolloutState] = []
 
     def set_base_env_runner(self, base_env_runner: SimpleEnvRunner):
         self.base_env_runner = base_env_runner
 
+    def sample_from_expired_buffer(self) -> RolloutState:
+        pass
+
+    def sample_from_aborted_buffer(self) -> RolloutState:
+        pass
+
+    # this one method should be able to cover every async feature
     async def async_generate_batch(self,
                                    batch_size: int,
                                    prompt_repeat_k: int,
+                                   # these may become constructor arguments instead of call-time parameters
                                    staleness_threshold: float = 0.0,
                                    enable_partial_rollout: bool = False,
                                    ) -> list[RolloutState]:
-        raise NotImplementedError("Please implement async_generate_batch method for your custom generation strategy.")
+        # Based on the internally managed state we can decide which pool the
+        # next sample should come from. A highly cohesive functional module.
+        data_concurrency = int((1 + staleness_threshold) * batch_size)
+
+        # only the partial_rollout scenario is considered here
+        pending_tasks = []
+        if enable_partial_rollout:
+            # sample from the aborted buffer first
+            for _ in range(len(self.aborted_buffer)):
+                rollout_state = self.sample_from_aborted_buffer()
+                task = asyncio.create_task(self.base_env_runner.generate_group(rollout_state, prompt_repeat_k))
+                pending_tasks.append(task)
+
+        data_concurrency -= len(pending_tasks)
+        for _ in range(data_concurrency):
+            # finally, sample from the dataset
+            rollout_state = self.base_env_runner.sample_from_dataset()
+            task = asyncio.create_task(self.base_env_runner.generate_group(rollout_state, prompt_repeat_k))
+            pending_tasks.append(task)
+
+        completed_sample_count = 0
+        batch_trajectories = []
+        while completed_sample_count < batch_size:
+            if not pending_tasks:
+                print("All tasks are done but not enough samples collected.")
+                break
+            done_tasks, pending_tasks = await asyncio.wait(pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED)
+            for task in done_tasks:
+                try:
+                    traj = await task
+                    if traj is not None:
+                        batch_trajectories.append(traj)
+                        completed_sample_count += 1
+                except Exception as e:
+                    print(f"Error in generating trajectory: {e}")
+
+        # Aborted samples go back into the buffer.
+        # We probably do not need an extra method such as save_partial to keep
+        # intermediate content, since everything should be recoverable from
+        # rollout_state. Even if an agent has a complex internal format, as
+        # long as the returned rollout_state carries that information it can
+        # be restored.
+        self.aborted_buffer.extend([ts for ts in batch_trajectories if ts.state == Status.ABORTED])
+
+        return batch_trajectories  # the returned data is guaranteed to be trainable
diff --git a/xtuner/v2/simple_env_runner.py b/xtuner/v2/simple_env_runner.py
index 50a16fd51..5bc469b2a 100644
--- a/xtuner/v2/simple_env_runner.py
+++ b/xtuner/v2/simple_env_runner.py
@@ -51,8 +51,9 @@ def sample_from_dataset(self) -> RolloutState:
     # generate a single sample
     async def generate_sample(self, rollout_state: RolloutState) -> RolloutState:
         # default: the simplest single-turn mode
+        # if a sample was interrupted, its state field records that
         rollout_state = await self.rollout_controller.generate.remote(rollout_state)
 
         reward = 0.0
@@ -70,7 +71,6 @@ async def run_judger_worker(self, rollout_state):
 
     async def generate(self, rollout_state: RolloutState) -> RolloutState:
         if self.generate_external is not None:
-            # TODO: this branch probably does not support partial rollout
             return await self.generate_external(rollout_state, self.processor_utils_state, self.rollout_controller, self.judger)
         else:
             return await self.generate_sample(rollout_state)
@@ -118,8 +118,6 @@ async def generate_batch(self,
         return batch_trajectories
 
     # =====================================================================
-    # The interfaces below are all about async rollout.
-
     # for the interruptible generation setting
     async def async_generate_batch(self,
                                    batch_size: int,

From edec4ed3b8e64928de12c63119310db2759c37cc Mon Sep 17 00:00:00 2001
From: duanyanhui <45005871+YanhuiDua@users.noreply.github.com>
Date: Mon, 2 Feb 2026 18:48:10 +0800
Subject: [PATCH 5/6] support async rl (#6)

* support async rl

* handle last_step_remain_completed_samples in buffer
---
 xtuner/v2/proxy_async_env_runner.py | 223 +++++++++++++++++++++++-----
 xtuner/v2/rollout_state.py          |  11 +-
 xtuner/v2/simple_env_runner.py      |  74 +++++----
 3 files changed, 238 insertions(+), 70 deletions(-)

diff --git a/xtuner/v2/proxy_async_env_runner.py b/xtuner/v2/proxy_async_env_runner.py
index 4b6f870d9..fda52aa97 100644
--- a/xtuner/v2/proxy_async_env_runner.py
+++ b/xtuner/v2/proxy_async_env_runner.py
@@ -1,57 +1,193 @@
 import asyncio
+from collections import defaultdict
+from typing import Dict, List
+
-from .simple_env_runner import SimpleEnvRunner
+from .simple_env_runner import DataSampler, SimpleEnvRunner
 from .rollout_state import RolloutState, Status
 
-# transparent to the user
-class AsyncProxyEnvRunner:
+class ExpiredBuffer:
+    def __init__(self):
+        self.buffer: List[List[RolloutState]] = []
+
+    def add(self, rollout_state: List[RolloutState]):
+        for rs in rollout_state:
+            # 1. reset generation fields
+            rs.response = ""
+            rs.response_ids = []
+            rs.logprobs = []
+            if rs.routed_experts is not None:
+                # note: the new routed_experts should preferably not be freed
+                # with ray._internal.free
+                del rs.routed_experts
+                rs.routed_experts = None
+
+            # 2. reset reward and bookkeeping fields
+            rs.reward = None
+            rs.staleness = 0
+
+            # 3. reset lifecycle state
+            rs.state = Status.INIT
+        self.buffer.append(rollout_state)
+
+    def pop(self) -> List[RolloutState]:
+        assert self.buffer, "ExpiredBuffer is empty!"
+        return self.buffer.pop(0)
+
+class AbortedBuffer:
+    def __init__(self):
+        self.buffer: Dict[int, List[List[RolloutState]]] = defaultdict(list)
+
+    def add(self, rollout_state: List[RolloutState]):
+        group_staleness = max([rs.staleness for rs in rollout_state])
+        self.buffer[group_staleness].append(rollout_state)
+
+    def pop(self, enable_partial_rollout) -> List[RolloutState]:
+        assert self.buffer, "AbortedBuffer is empty!"
+        highest_staleness = max(self.buffer.keys())
+        rollout_states = self.buffer[highest_staleness]
+        data = rollout_states.pop(0)
+        if not rollout_states:
+            # drop exhausted staleness keys so max() never hits an empty group
+            del self.buffer[highest_staleness]
+        if enable_partial_rollout:
+            for rs in data:
+                rs.tokens = rs.prompt_ids + rs.response_ids
+                rs.sample_params.max_tokens = rs.sample_params.max_tokens - len(rs.response_ids)
+        else:
+            for rs in data:
+                rs.response = ""
+                rs.response_ids = []
+                rs.logprobs = []
+                if rs.routed_experts is not None:
+                    del rs.routed_experts
+                    rs.routed_experts = None
+                rs.reward = None
+                rs.staleness = 0
+                rs.state = Status.INIT
+        return data
+
+    def update(self):
+        new_buffer = defaultdict(list)
+        for staleness, rollout_states in self.buffer.items():
+            new_staleness = staleness + 1
+            new_buffer[new_staleness].extend(rollout_states)
+        self.buffer = new_buffer
+
+class CompletedBuffer:
+    def __init__(self):
+        self.buffer: Dict[int, List[List[RolloutState]]] = defaultdict(list)
+
+    def add(self, rollout_state: List[RolloutState]):
+        group_staleness = max([rs.staleness for rs in rollout_state])
+        self.buffer[group_staleness].append(rollout_state)
+
+    def pop(self) -> List[RolloutState]:
+        highest_staleness = max(self.buffer.keys())
+        rollout_states = self.buffer[highest_staleness]
+        return rollout_states.pop(0)
+
+    def update(self):
+        new_buffer = defaultdict(list)
+        for staleness, rollout_states in self.buffer.items():
+            new_staleness = staleness + 1
+            new_buffer[new_staleness].extend(rollout_states)
+        self.buffer = new_buffer
+
+    @property
+    def length(self) -> int:
+        return sum(len(v) for v in self.buffer.values())
+
+class Buffer:
+    # The reason this is its own class: the expired buffer and the aborted
+    # buffer may need different priority policies. For example, we currently
+    # manage by version (staleness), dequeuing the stalest samples first;
+    # other schemes (e.g. by length?) may be added later.
+    # Keeping expired-sample management separate also makes the code easier
+    # to read. Each knob controls one independent thing:
+    #   enable_partial_rollout: whether the next rollout resumes from (is
+    #       concatenated onto) the interrupted one
+    #   tail_batch_candidate_step: how stale a sample must be before it
+    #       enters the expired buffer
+    #   tail_batch_trigger_size: size threshold at which sampling from the
+    #       expired buffer is triggered
+
+    def __init__(self,
+                 enable_partial_rollout: bool = False,
+                 tail_batch_candidate_step: int = 1,
+                 tail_batch_trigger_size: int = 10):
+        self.expired_buffer: ExpiredBuffer = ExpiredBuffer()
+        self.aborted_buffer: AbortedBuffer = AbortedBuffer()
+        self.completed_buffer: CompletedBuffer = CompletedBuffer()
+        self.enable_tail_batch = tail_batch_candidate_step > 0
+        self.enable_partial_rollout = enable_partial_rollout
+        self.tail_batch_candidate_step = tail_batch_candidate_step
+        self.tail_batch_trigger_size = tail_batch_trigger_size
+
+    def add(self, rollout_state: List[RolloutState]):
+        # Where should rollout_state version management live? E.g. after each
+        # weight update, bump the staleness of every sample in the Buffer in
+        # one pass instead of updating it here.
+        group_staleness = max([rs.staleness for rs in rollout_state])
+        group_states = [rs.state for rs in rollout_state]
+        if self.enable_tail_batch and group_staleness > self.tail_batch_candidate_step:
+            self.expired_buffer.add(rollout_state)
+        elif all(state == Status.COMPLETED for state in group_states):
+            self.completed_buffer.add(rollout_state)
+        else:
+            self.aborted_buffer.add(rollout_state)
+
+    def get_sample_func(self, data_sampler: DataSampler, prompt_repeat_k: int):
+        use_expired = self.enable_tail_batch and len(self.expired_buffer.buffer) > 0
+        self.update()
+
+        def _sample():
+            if use_expired and self.expired_buffer.buffer:
+                return self.expired_buffer.pop()
+            elif self.aborted_buffer.buffer:
+                return self.aborted_buffer.pop(self.enable_partial_rollout)
+            else:
+                return data_sampler.sample_from_dataset(prompt_repeat_k)
+
+        return _sample
+
+    def update(self):
+        if self.enable_partial_rollout:
+            self.completed_buffer.update()
+        else:
+            while self.completed_buffer.length > 0:
+                state = self.completed_buffer.pop()
+                self.aborted_buffer.add(state)
+            self.aborted_buffer.update()
+
+
+class AsyncProxyEnvRunner:
+    def __init__(
+        self,
+        staleness_threshold: float = 0.0,
+        enable_partial_rollout: bool = False,
+        tail_batch_trigger_size: int = 0,
+        tail_batch_candidate_step: int = 0,
+    ):
+        self.base_env_runner: SimpleEnvRunner = None
+        self.buffer = Buffer(enable_partial_rollout, tail_batch_candidate_step, tail_batch_trigger_size)
+        self.staleness_threshold = staleness_threshold
 
     def set_base_env_runner(self, base_env_runner: SimpleEnvRunner):
         self.base_env_runner = base_env_runner
-
-    def sample_from_expired_buffer(self) -> RolloutState:
-        pass
-
-    def sample_from_aborted_buffer(self) -> RolloutState:
-        pass
 
-    # this one method should be able to cover every async feature
     async def async_generate_batch(self,
+                                   data_sampler: DataSampler,
                                    batch_size: int,
                                    prompt_repeat_k: int,
-                                   # these may become constructor arguments instead of call-time parameters
-                                   staleness_threshold: float = 0.0,
-                                   enable_partial_rollout: bool = False,
-                                   ) -> list[RolloutState]:
+                                   ) -> List[List[RolloutState]]:
         # Based on the internally managed state we can decide which pool the
         # next sample should come from. A highly cohesive functional module.
-        data_concurrency = int((1 + staleness_threshold) * batch_size)
+        last_step_remain_completed_samples = self.buffer.completed_buffer.length
+        data_concurrency = int((1 + self.staleness_threshold) * (batch_size - last_step_remain_completed_samples))
+        completed_sample_count = 0
+        sample_func = self.buffer.get_sample_func(data_sampler, prompt_repeat_k)
 
-        # only the partial_rollout scenario is considered here
         pending_tasks = []
-        if enable_partial_rollout:
-            # sample from the aborted buffer first
-            for _ in range(len(self.aborted_buffer)):
-                rollout_state = self.sample_from_aborted_buffer()
-                task = asyncio.create_task(self.base_env_runner.generate_group(rollout_state, prompt_repeat_k))
-                pending_tasks.append(task)
-
-        data_concurrency -= len(pending_tasks)
+        for _ in range(last_step_remain_completed_samples):
+            traj = self.buffer.completed_buffer.pop()
+            yield traj
+
         for _ in range(data_concurrency):
-            # finally, sample from the dataset
-            rollout_state = self.base_env_runner.sample_from_dataset()
-            task = asyncio.create_task(self.base_env_runner.generate_group(rollout_state, prompt_repeat_k))
+            task = asyncio.create_task(self.base_env_runner.generate_group(sample_func))
+            # task = asyncio.create_task(self.generate_group(data_sampler.sample(), prompt_repeat_k))
             pending_tasks.append(task)
 
-        completed_sample_count = 0
-        batch_trajectories = []
-        while completed_sample_count < batch_size:
+        while completed_sample_count < data_concurrency:
             if not pending_tasks:
                 print("All tasks are done but not enough samples collected.")
                 break
             done_tasks, pending_tasks = await asyncio.wait(pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED)
             for task in done_tasks:
                 try:
                     traj = await task
                     if traj is not None:
-                        batch_trajectories.append(traj)
                         completed_sample_count += 1
+                        if completed_sample_count <= batch_size:
+                            yield traj
+                        else:
+                            self.buffer.add(traj)
                 except Exception as e:
                     print(f"Error in generating trajectory: {e}")
 
-        # Aborted samples go back into the buffer.
-        # We probably do not need an extra method such as save_partial to keep
-        # intermediate content, since everything should be recoverable from
-        # rollout_state. Even if an agent has a complex internal format, as
-        # long as the returned rollout_state carries that information it can
-        # be restored.
-        self.aborted_buffer.extend([ts for ts in batch_trajectories if ts.state == Status.ABORTED])
-
-        return batch_trajectories  # the returned data is guaranteed to be trainable
+        await self.base_env_runner.rollout_controller.pause.remote()
+        while len(pending_tasks) > 0:
+            done_tasks, pending_tasks = await asyncio.wait(pending_tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED)
+            for task in done_tasks:
+                try:
+                    abort_traj = await task
+                    self.buffer.add(abort_traj)
+                except Exception as e:
+                    print(f"Error while pausing task: {e}")
+            if len(pending_tasks) > 0:
+                await self.base_env_runner.rollout_controller.pause.remote()
+                await asyncio.sleep(1)
diff --git a/xtuner/v2/rollout_state.py b/xtuner/v2/rollout_state.py
index 8b816524e..f0890e9a5 100644
--- a/xtuner/v2/rollout_state.py
+++ b/xtuner/v2/rollout_state.py
@@ -1,6 +1,7 @@
 from dataclasses import dataclass, field
 from enum import Enum
+from typing import Any
 from transformers import AutoTokenizer, AutoProcessor, PreTrainedTokenizerBase
 from transformers.processing_utils import ProcessorMixin
 from xtuner.v1.ray.rollout.controller import SampleParams
@@ -21,22 +22,25 @@ class Status(Enum):
 class RolloutState:
     uid: int
     # required dataset outputs
     message: list = field(default_factory=list)
-    tokens: list[int] = field(default_factory=list)
+    tokens: list[int] = field(default_factory=list)  # the actual input of each call
 
     session_id: int | None = None
     prompt_ids: list[int] = field(default_factory=list)
     response: str = ""
-    response_ids: list[int] = field(default_factory=list)
+    response_ids: list[int] = field(default_factory=list)  # the actual output of each call; overwritten every time
     logprobs: list[float] = field(default_factory=list)
     routed_experts: list[int] | None = None
     reward: float | list[float] | list[dict] | None = None
-    loss_mask: list[int] | None = None
+    loss_mask: list[int] | None = None  # length of tokens + response_ids
     state: Status = Status.INIT
     sample_params: SampleParams | None = None
     tools: list | None = None
     tool_choice: str | None = None
+    mm_infer_info: dict[str, Any] | None = None
+    mm_train_info: dict[str, Any] | None = None
+    finish_reason: str | None = None
+    staleness: int = 0
 
 # TODO: what is this object actually for? Unused for now; keeping it would make
 # the objects in the inner loop inconsistent, and partial rollout gets awkward.
 @dataclass
 class Trajectory:
diff --git a/xtuner/v2/simple_env_runner.py b/xtuner/v2/simple_env_runner.py
index 5bc469b2a..2cb190b89 100644
--- a/xtuner/v2/simple_env_runner.py
+++ b/xtuner/v2/simple_env_runner.py
@@ -1,5 +1,6 @@
 
 import asyncio
+from copy import deepcopy
+from functools import partial
 from typing import Callable
 from xtuner.v1.datasets import DataloaderConfig
 from .utils import load_function
 from .rollout_controller import RolloutController
 from .rollout_state import ProcessorUtilState, RolloutState
 
+class DataSampler:
+    def __init__(self, dataloader_config: DataloaderConfig):
+        self.dataloader_config = dataloader_config
+        self.dataloader = dataloader_config.build()
+        self.dataloader_iter = iter(self.dataloader)
+        self.cur_epoch = 0
+
+    def sample_from_dataset(self, prompt_repeat_k: int) -> list[RolloutState]:
+        try:
+            data = next(self.dataloader_iter)[0]
+        except StopIteration:
+            self.cur_epoch += 1
+            self.dataloader.set_epoch(self.cur_epoch)
+            self.dataloader_iter = iter(self.dataloader)
+            data = next(self.dataloader_iter)[0]
+        # expand the data into a group according to prompt_repeat_k; deepcopy
+        # so the k rollouts in a group do not alias the same object
+        group_data = []
+        for _ in range(prompt_repeat_k):
+            group_data.append(deepcopy(data))
+        return group_data
+
+    def resume(self):
+        pass
+
+    def save(self):
+        pass
 
 # This class is responsible for defining all the interfaces, and it provides a
 # synchronous rollout that covers most needs: single-turn, multi-turn, and
 # agent scenarios.
@@ -40,7 +64,6 @@ class SimpleEnvRunner:
     def __init__(self,
                  rollout_controller: RolloutController,
                  processor_utils_state: ProcessorUtilState | None = None,
                  judger: Callable | None = None,  # None so this env runner can also run standalone; can be a plain callable or an actor worker
-                 dataloader_cfg: DataloaderConfig | None = None,  # None so this env runner can also run standalone
                  generate_external: Callable | str | None = None,
@@ -49,10 +73,6 @@ def __init__(self,
         self.rollout_controller = rollout_controller
         self.judger = judger
         self.processor_utils_state = processor_utils_state
 
-        self.dataloader = None
-        self.dataloader_iter = None
-        self.cur_epoch = 0
-        if dataloader_cfg is not None:
-            self.dataloader = dataloader_cfg.build()
-            self.dataloader_iter = iter(self.dataloader)
-
         self.generate_external = generate_external
         if isinstance(self.generate_external, str):
             self.generate_external = load_function(self.generate_external)
@@ -69,7 +89,7 @@ def __init__(self,
-    def sample_from_dataset(self) -> RolloutState:
-        try:
-            data = next(self.dataloader_iter)[0]
-        except StopIteration:
-            self.cur_epoch += 1
-            self.dataloader.set_epoch(self.cur_epoch)
-            self.dataloader_iter = iter(self.dataloader)
-            data = next(self.dataloader_iter)[0]
-        return data
-
     # generate a single sample
-    async def generate_sample(self, rollout_state: RolloutState) -> RolloutState:
+    async def generate_single_sample(self, rollout_state: RolloutState) -> RolloutState:
         # default: the simplest single-turn mode
         # if a sample was interrupted, its state field records that
         rollout_state = await self.rollout_controller.generate.remote(rollout_state)
 
         reward = 0.0
@@ -93,7 +93,7 @@ async def generate(self, rollout_state: RolloutState) -> RolloutState:
         if self.generate_external is not None:
             return await self.generate_external(rollout_state, self.processor_utils_state, self.rollout_controller, self.judger)
         else:
-            return await self.generate_sample(rollout_state)
+            return await self.generate_single_sample(rollout_state)
 
     # generate one group of samples
-    async def generate_group(self, rollout_state: RolloutState, prompt_repeat_k: int) -> list[RolloutState]:
+    async def generate_group(self, sample_func) -> list[RolloutState]:
         pending_tasks = []
-        for _ in range(prompt_repeat_k):
-            task = asyncio.create_task(self.generate(rollout_state))
+
+        group_rollout_state = sample_func()
+        for rs in group_rollout_state:
+            task = asyncio.create_task(self.generate(rs))
             pending_tasks.append(task)
 
-        trajectories = asyncio.gather(*pending_tasks)
-        return await trajectories
+        generated_states = asyncio.gather(*pending_tasks)
+
+        group_responses = await generated_states
+        return group_responses
 
     # Uninterruptible batch generation, for the synchronous setting
     async def generate_batch(self,
+                             data_sampler: DataSampler,
                              batch_size: int,
                              prompt_repeat_k: int,
-                             ) -> list[RolloutState]:
+                             ) -> List[List[RolloutState]]:
         data_concurrency = batch_size
-        assert self.dataloader is not None, "Dataloader must be provided for batch generation."
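+        # Sampling is delegated to the DataSampler (and, on the async path, to
+        # the buffers) through sample_func, so this loop no longer owns a dataloader.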
+        sample_func = partial(data_sampler.sample_from_dataset, prompt_repeat_k)
 
         pending_tasks = []
         for _ in range(data_concurrency):
-            rollout_state = self.sample_from_dataset()
-            task = asyncio.create_task(self.generate_group(rollout_state, prompt_repeat_k))
+            task = asyncio.create_task(self.generate_group(sample_func))
             pending_tasks.append(task)
 
         completed_sample_count = 0
-        batch_trajectories = []
-        while completed_sample_count < batch_size:
+        while completed_sample_count < data_concurrency:
             if not pending_tasks:
                 print("All tasks are done but not enough samples collected.")
                 break
             done_tasks, pending_tasks = await asyncio.wait(pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED)
             for task in done_tasks:
                 try:
                     traj = await task
                     if traj is not None:
-                        batch_trajectories.append(traj)
+                        yield traj
                         completed_sample_count += 1
                 except Exception as e:
                     print(f"Error in generating trajectory: {e}")
 
-        return batch_trajectories
-
     # =====================================================================
     # for the interruptible generation setting
     async def async_generate_batch(self,
+                                   data_sampler: DataSampler,
                                    batch_size: int,
                                    prompt_repeat_k: int,
-                                   staleness_threshold: float = 0.0,
-                                   enable_partial_rollout: bool = False,
                                    ) -> list[RolloutState]:
         return await self.async_proxy_runner.async_generate_batch(
+            data_sampler,
             batch_size,
-            prompt_repeat_k,
-            staleness_threshold,
-            enable_partial_rollout)
+            prompt_repeat_k)

From cf4e9db9d85c13a820239043fa8fc582a7cdd422 Mon Sep 17 00:00:00 2001
From: "huanghaian@pjlab.org.cn"
Date: Wed, 4 Feb 2026 10:54:09 +0000
Subject: [PATCH 6/6] add trainer

---
 examples/v2/toolcall_env_demo.py      |  2 +
 xtuner/v2/colocate_rl_trainer.py      | 64 +++++++++++++++++++++++++++
 xtuner/v2/disaggregated_rl_trainer.py | 57 ++++++++++++++++++++++++
 xtuner/v2/proxy_async_env_runner.py   |  5 ++-
 xtuner/v2/rollout_controller.py       |  8 +---
 xtuner/v2/rollout_state.py            | 12 ++---
 xtuner/v2/simple_env_runner.py        |  7 +--
 xtuner/v2/weight_controller.py        | 12 +++++
 8 files changed, 145 insertions(+), 22 deletions(-)
 create mode 100644 xtuner/v2/colocate_rl_trainer.py
 create mode 100644 xtuner/v2/disaggregated_rl_trainer.py
 create mode 100644 xtuner/v2/weight_controller.py

diff --git a/examples/v2/toolcall_env_demo.py b/examples/v2/toolcall_env_demo.py
index f1ccd33e7..e12c8c3a4 100644
--- a/examples/v2/toolcall_env_demo.py
+++ b/examples/v2/toolcall_env_demo.py
@@ -151,4 +151,6 @@
     ray.get(rollout_controller.shutdown.remote(), timeout=300)
 
+    from xtuner.v2.proxy_async_env_runner import AsyncProxyEnvRunner
+
diff --git a/xtuner/v2/colocate_rl_trainer.py b/xtuner/v2/colocate_rl_trainer.py
new file mode 100644
index 000000000..4daeea322
--- /dev/null
+++ b/xtuner/v2/colocate_rl_trainer.py
@@ -0,0 +1,64 @@
+
+
+# Colocated (training and rollout share the GPUs); whether rollout is async
+# should make no difference here.
+class ColocateRLTrainer:
+    def __init__(
+        self,
+        config,
+        env_runner,  # Only one env runner is allowed; in multi-task scenarios a composite env runner assembled by the user should be passed in, keeping this class simple and generic.
+        weight_controller,
+
+        rollout_controller,
+        training_controller  # all PPO training details live inside this class; the trainer never sees them
+    ):
+        self.config = config
+        self.env_runner = env_runner
+        self.weight_controller = weight_controller
+        self.training_controller = training_controller
+        self.rollout_controller = rollout_controller
+        self.replay_buffer = []  # a plain list is enough in this setting
+        self.batch_size: int = config.batch_size
+
+        self.env_runner.set_controller(
+            rollout_controller=self.rollout_controller,
+            training_controller=self.training_controller)
+        self.weight_controller.set_controllers(
+            rollout_controller=self.rollout_controller,
+            training_controller=self.training_controller)
+
+        self.training_steps_per_epoch = 1
+        self.total_steps = 100
+
+    def train_loop(self):
+        for rollout_id in range(self.total_steps):
+            self.train_step(rollout_id)
+
+    def train_step(self, rollout_id):
+        self.replay_buffer = []
+
+        # offload train controller to cpu
+        self.training_controller.offload_to_cpu()
+        # load rollout_controller
+        self.rollout_controller.load_to_device()
+
+        # Collect rollouts
+        for trajectory in self.env_runner.generate_batch(self.batch_size):
+            self.replay_buffer.append(trajectory)
+
+        # offload rollout_controller to cpu
+        self.rollout_controller.offload_to_cpu()
+        # load train controller
+        self.training_controller.load_to_device()
+
+        # Train the model
+        for _ in range(self.training_steps_per_epoch):
+            batch = self.replay_buffer
+            train_batch = self.convert_rollout_batch_to_train_batch(batch)
+            self.training_controller.fit(train_batch)
+
+        # IPC and NCCL should be two different implementations behind the same interface
+        self.weight_controller.update_weights()
+
+    def convert_rollout_batch_to_train_batch(self, batch):
+        # assume here that the rollout batch and the train batch are identical
+        return batch
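+
+# Hypothetical wiring sketch (illustrative names, not a final API):
+#
+#   trainer = ColocateRLTrainer(config, env_runner, weight_controller,
+#                               rollout_controller, training_controller)
+#   trainer.train_loop()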
diff --git a/xtuner/v2/disaggregated_rl_trainer.py b/xtuner/v2/disaggregated_rl_trainer.py
new file mode 100644
index 000000000..84e691072
--- /dev/null
+++ b/xtuner/v2/disaggregated_rl_trainer.py
@@ -0,0 +1,57 @@
+
+
+# Disaggregated (training and rollout on separate GPUs); in theory whether
+# rollout is async makes no difference here either.
+class DisaggregatedRLTrainer:
+    def __init__(
+        self,
+        config,
+        env_runner,  # Only one env runner is allowed; in multi-task scenarios a composite env runner assembled by the user should be passed in, keeping this class simple and generic.
+        update_weighter,
+
+        rollout_controller,
+        training_controller
+    ):
+        self.config = config
+        self.env_runner = env_runner
+        self.update_weighter = update_weighter
+        self.training_controller = training_controller
+        self.rollout_controller = rollout_controller
+        self.replay_buffer = []  # a plain list is enough in this setting
+        self.batch_size: int = config.batch_size
+
+        self.env_runner.set_controller(
+            rollout_controller=self.rollout_controller,
+            training_controller=self.training_controller)
+        self.update_weighter.set_controllers(
+            rollout_controller=self.rollout_controller,
+            training_controller=self.training_controller)
+
+        self.training_steps_per_epoch = 1
+        self.total_steps = 100
+        self.require_batches = config.batch_size // 4
+
+    def train_loop(self):
+        for rollout_id in range(self.total_steps):
+            self.train_step(rollout_id)
+
+    def train_step(self, rollout_id):
+        self.replay_buffer = []
+        # Collect rollouts
+        for trajectory in self.env_runner.generate_batch(self.batch_size):
+            self.replay_buffer.append(trajectory)
+
+            # start training once the required number of batches has accumulated
+            if len(self.replay_buffer) >= self.require_batches:
+                # Train the model
+                for _ in range(self.training_steps_per_epoch):
+                    # after sampling a batch, the already-trained data must be cleared from the replay buffer
+                    batch = self.replay_buffer[:self.require_batches]
+                    del self.replay_buffer[:self.require_batches]
+                    train_batch = self.convert_rollout_batch_to_train_batch(batch)
+                    self.training_controller.fit(train_batch)
+
+                self.update_weighter.update_weights()
+                self.save_ckpt()  # TODO: not implemented yet in this sketch
+
+    def convert_rollout_batch_to_train_batch(self, batch):
+        # assume here that the rollout batch and the train batch are identical
+        return batch
diff --git a/xtuner/v2/proxy_async_env_runner.py b/xtuner/v2/proxy_async_env_runner.py
index fda52aa97..6277e0c46 100644
--- a/xtuner/v2/proxy_async_env_runner.py
+++ b/xtuner/v2/proxy_async_env_runner.py
@@ -168,7 +171,7 @@ class AsyncProxyEnvRunner:
     async def async_generate_batch(self,
                                    data_sampler: DataSampler,
                                    batch_size: int,
                                    prompt_repeat_k: int,
-                                   ) -> List[List[RolloutState]]:
+                                   ):
         # Based on the internally managed state we can decide which pool the
         # next sample should come from. A highly cohesive functional module.
         last_step_remain_completed_samples = self.buffer.completed_buffer.length
diff --git a/xtuner/v2/rollout_controller.py b/xtuner/v2/rollout_controller.py
index b945bf8b5..52ff7fca0 100644
--- a/xtuner/v2/rollout_controller.py
+++ b/xtuner/v2/rollout_controller.py
@@ -1,11 +1,5 @@
 from xtuner.v1.ray.rollout import RolloutController as V1RolloutController
-from .rollout_state import RolloutState, Status
-
-reason_map = {
-    "length": Status.COMPLETED,
-    'aborted': Status.ABORTED,
-    "failed": Status.FAILED,
-}
+from .rollout_state import RolloutState
 
 # Interim solution
 class RolloutController(V1RolloutController):
diff --git a/xtuner/v2/rollout_state.py b/xtuner/v2/rollout_state.py
index f0890e9a5..4ff873f0f 100644
--- a/xtuner/v2/rollout_state.py
+++ b/xtuner/v2/rollout_state.py
@@ -19,7 +19,6 @@ class Status(Enum):
 @dataclass
 class RolloutState:
     uid: int
-    # required dataset outputs
     message: list = field(default_factory=list)
     tokens: list[int] = field(default_factory=list)  # the actual input of each call
@@ -40,12 +39,7 @@ class RolloutState:
     mm_train_info: dict[str, Any] | None = None
     finish_reason: str | None = None
     staleness: int = 0
-
-
-# TODO: what is this object actually for? Unused for now; keeping it would make
-# the objects in the inner loop inconsistent, and partial rollout gets awkward.
-@dataclass
-class Trajectory:
-    rollout_state: RolloutState | list[RolloutState]
-    env: str = 'default'
+    extra_fields: dict[str, Any] | None = None
 
 
 def load_tokenizer(name_or_path: str, **kwargs):
diff --git a/xtuner/v2/simple_env_runner.py b/xtuner/v2/simple_env_runner.py
index 2cb190b89..114fa5dc2 100644
--- a/xtuner/v2/simple_env_runner.py
+++ b/xtuner/v2/simple_env_runner.py
@@ -1,5 +1,4 @@
-
 import asyncio
 from copy import deepcopy
 from functools import partial
@@ -96,7 +95,6 @@ async def generate(self, rollout_state: RolloutState) -> RolloutState:
             return await self.generate_single_sample(rollout_state)
 
     # generate one group of samples
-
     async def generate_group(self, sample_func) -> list[RolloutState]:
         pending_tasks = []
@@ -115,7 +113,7 @@ async def generate_batch(self,
                              data_sampler: DataSampler,
                              batch_size: int,
                              prompt_repeat_k: int,
-                             ) -> List[List[RolloutState]]:
+                             ):
         data_concurrency = batch_size
         sample_func = partial(data_sampler.sample_from_dataset, prompt_repeat_k)
@@ -139,13 +137,12 @@ async def generate_batch(self,
             except Exception as e:
                 print(f"Error in generating trajectory: {e}")
 
-    # =====================================================================
     # for the interruptible generation setting
     async def async_generate_batch(self,
                                    data_sampler: DataSampler,
                                    batch_size: int,
                                    prompt_repeat_k: int,
-                                   ) -> list[RolloutState]:
-        return await self.async_proxy_runner.async_generate_batch(
-            data_sampler,
-            batch_size,
-            prompt_repeat_k)
+                                   ):
+        # the proxy's async_generate_batch is an async generator, so delegate
+        # with async-for instead of awaiting it
+        async for traj in self.async_proxy_runner.async_generate_batch(
+                data_sampler, batch_size, prompt_repeat_k):
+            yield traj
diff --git a/xtuner/v2/weight_controller.py b/xtuner/v2/weight_controller.py
new file mode 100644
index 000000000..16f010119
--- /dev/null
+++ b/xtuner/v2/weight_controller.py
@@ -0,0 +1,12 @@
+
+
+class WeightController:
+    def __init__(self):
+ pass + + def set_controllers(self, rollout_controller, training_controller): + self.rollout_controller = rollout_controller + self.training_controller = training_controller + + def update_weights(self): + pass
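+
+# Minimal usage sketch (hypothetical wiring; update_weights is still a stub):
+#
+#   controller = WeightController()
+#   controller.set_controllers(rollout_controller, training_controller)
+#   controller.update_weights()  # e.g. push new weights via IPC or NCCL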