diff --git a/examples/v2/single_turn_env_demo.py b/examples/v2/single_turn_env_demo.py
new file mode 100644
index 000000000..509a3d630
--- /dev/null
+++ b/examples/v2/single_turn_env_demo.py
@@ -0,0 +1,74 @@
+import ray
+import os
+from uuid import uuid4
+
+os.environ["XTUNER_USE_LMDEPLOY"] = "1"
+
+from xtuner.v1.ray.config.worker import RolloutConfig
+from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
+from xtuner.v2.rollout_controller import RolloutController
+from xtuner.v2.simple_env_runner import SimpleEnvRunner
+from xtuner.v2.rollout_state import ProcessorUtilState, RolloutState
+from xtuner.v1.ray.rollout.controller import SampleParams
+from xtuner.v1.ray.judger.gsm8k import compute_reward
+
+
+if __name__ == '__main__':
+    ray.init(num_cpus=80, ignore_reinit_error=True)
+
+    # model_path = '/mnt/shared-storage-user/llmrazor-share/model/intern-s1-mini-hha-fix_tokenizer'
+    model_path = '/mnt/shared-storage-user/llmrazor-share/model/Qwen3-8B'
+
+    resources = AcceleratorResourcesConfig(
+        accelerator="GPU",
+        num_workers=1,  # 1 or 8
+        num_cpus_per_worker=12,
+        cpu_memory_per_worker=16 * 1024 ** 3,  # 16 GB
+    )
+    pg = AutoAcceleratorWorkers.build_placement_group(resources)
+
+    # 2. rollout
+    rollout_config = RolloutConfig(
+        device=resources.accelerator,
+        model_path=model_path,
+        dtype="bfloat16",
+        tensor_parallel_size=1,
+        expert_parallel_size=1,
+        gpu_memory_utilization=0.75,
+    )
+
+    rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg)
+
+    processor_utils_state = ProcessorUtilState(hf_checkpoint=model_path, sample_params=SampleParams())
+    simple_env_runner = ray.remote(SimpleEnvRunner).remote(rollout_controller, processor_utils_state=processor_utils_state)
+
+    # prompt = [{"content": [{"image_url": {"image_wh": [404, 162],
+    #                                       "url": "images/test_2.jpg"},
+    #                         "type": "image_url"},
+    #                        {"text": "Find the area of the figure to the nearest tenth. "
+    #                                 "You FIRST think about the reasoning process as an internal monologue "
+    #                                 "and then provide the final answer. The reasoning process MUST BE "
+    #                                 "enclosed within tags. The final answer MUST BE put in \\boxed{}.",
+    #                         "type": "text"}],
+    #            "role": "user"}]
+    # extra_info = {'media_root': '/mnt/shared-storage-user/llmrazor-share/data/geometry3k/'}
+
+    prompt = [{"role": "user", "content": 'Calculate 13+24=', "type": "text"}]
+    data_item = {'prompt': prompt}
+
+    input_ids = processor_utils_state.tokenizer.apply_chat_template(
+        data_item["prompt"], return_dict=True, return_tensors="pt"
+    )["input_ids"][0].tolist()
+    rollout_state = RolloutState(uid=uuid4().int, tokens=input_ids)
+
+    # Generate a single sample
+    res1 = ray.get(simple_env_runner.generate.remote(rollout_state))
+    print("Response from infer:", res1)
+
+    # Generate a group of samples
+    res2 = ray.get(simple_env_runner.generate_group.remote(rollout_state))
+    print("Response from infer:", res2)
+
+    # Generate a batch of samples
+    res3 = ray.get(simple_env_runner.generate_batch.remote(batch_size=64))
+    print("Response from infer:", res3)
+
+    ray.get(rollout_controller.shutdown.remote(), timeout=300)
diff --git a/examples/v2/toolcall_env_demo.py b/examples/v2/toolcall_env_demo.py
new file mode 100644
index 000000000..e12c8c3a4
--- /dev/null
+++ b/examples/v2/toolcall_env_demo.py
@@ -0,0 +1,156 @@
+import ray
+import os
+from uuid import uuid4
+from copy import deepcopy
+
+os.environ["XTUNER_USE_LMDEPLOY"] = "1"
+
+from xtuner.v1.ray.config.worker import RolloutConfig
+from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
+from xtuner.v2.rollout_controller import RolloutController
+from xtuner.v2.simple_env_runner import SimpleEnvRunner
+from xtuner.v2.rollout_state import ProcessorUtilState, RolloutState, Status
+from xtuner.v1.ray.rollout.controller import SampleParams
+from xtuner.v1.ray.judger.gsm8k import compute_reward
+
+
+TOOL_CONFIGS = {
+    "max_turns": 16,
+    "max_tool_calls": 16,
+    "tool_concurrency": 32,  # Aggressive: 32 concurrent processes
+    # Python interpreter settings
+    "python_timeout": 120,  # 2 minutes for complex calculations
+    "python_memory_limit": "4GB",  # 4GB per Python process
+    "python_cpu_limit": 1,
+    # Memory management settings
+    "max_memory_usage": 12288,  # 12GB total (75% of 16GB)
+    "cleanup_threshold": 6144,  # 6GB
+    "aggressive_cleanup_threshold": 3072,  # 3GB
+    "force_cleanup_threshold": 9216,  # 9GB
+}
+
+
+# At first glance, partial rollout looks hard to support once the generate
+# function is user-defined. Actually it turns out to be easy, because the
+# input rollout_state carries a state field: that state tells us which flow
+# to take next (keep rolling out, call a tool, and so on). The whole thing is
+# a state machine, so it can be resumed.
+async def gsm8k_with_tools_generate(rollout_state: RolloutState,
+                                    processor_utils_state: ProcessorUtilState,
+                                    rollout_controller: RolloutController,
+                                    judger) -> RolloutState:
+    # You can work directly with raw tokens, or with messages + tools.
+
+    # The incoming state tells us where we are and what to do next; any state
+    # other than INIT implies that partial rollout is enabled.
+    prompt = processor_utils_state.tokenizer.apply_chat_template(
+        rollout_state.messages,
+        tools=rollout_state.tools,
+        tokenize=False,
+        add_generation_prompt=True,
+    )
+    prompt_tokens_ids = processor_utils_state.tokenizer(prompt, add_special_tokens=False)["input_ids"]
+
+    response = ""
+    response_token_ids = []
+    loss_masks = []
+    tool_call_count = 0  # Track actual tool call rounds
+    rollout_log_probs = []
+    init_rollout_state = deepcopy(rollout_state)
+    for turn in range(TOOL_CONFIGS["max_turns"]):
+        current_token_ids = prompt_tokens_ids + response_token_ids
+        rollout_state.tokens = current_token_ids
+        init_rollout_state.tokens = rollout_state.tokens
+
+        # TODO: mind how much data is shipped between ray actors; in principle
+        # nothing the callee does not need should be sent.
+        # The object sent to rollout_controller only needs the appended
+        # tokens; it does not look at anything else.
+
+        # Compared with sending a POST request directly, this is a bit more
+        # cumbersome and has a steeper learning curve, but the management
+        # layer saves you from being handed n URLs and routing them yourself.
+        # TODO: switch to a URL POST style?
+        rollout_state = await rollout_controller.generate.remote(rollout_state)
+
+        init_rollout_state.state = rollout_state.state
+
+        response_token_ids += rollout_state.response_ids
+        rollout_log_probs += rollout_state.logprobs
+        loss_masks += [1] * len(rollout_state.response_ids)
+
+        if rollout_state.state == Status.ABORTED:
+            # Move rollout_state's contents into init_rollout_state, which is
+            # the copy kept across the whole generation.
+            init_rollout_state.state = rollout_state.state
+            init_rollout_state.response_ids = response_token_ids
+            init_rollout_state.logprobs = rollout_log_probs
+            init_rollout_state.loss_mask = loss_masks
+            return init_rollout_state
+
+        # Execute tools
+        next_obs, done = await execute_predictions(rollout_state.response)
+        if done:
+            break
+        tool_call_count += 1
+
+        obs_tokens_ids = processor_utils_state.tokenizer(next_obs, add_special_tokens=False)["input_ids"]
+        response += next_obs
+        response_token_ids += obs_tokens_ids
+        rollout_log_probs += [0.0] * len(obs_tokens_ids)
+        loss_masks += [0] * len(obs_tokens_ids)
+
+        if tool_call_count >= TOOL_CONFIGS["max_tool_calls"]:
+            break
+
+    # Move rollout_state's contents into init_rollout_state, which is the
+    # copy kept across the whole generation.
+    init_rollout_state.response_ids = response_token_ids
+    init_rollout_state.logprobs = rollout_log_probs
+    init_rollout_state.loss_mask = loss_masks
+    return init_rollout_state
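+
+
+# `execute_predictions` is not defined anywhere in this PR. The stub below is
+# a hypothetical sketch added for illustration only; it documents the contract
+# the loop above assumes: it receives the decoded model response and returns
+# (next_obs, done), where next_obs is tool output text appended to the
+# context and done means a final answer was produced instead of a tool call.
+async def execute_predictions(response: str) -> tuple[str, bool]:
+    # "<tool_call>" is an assumed, model-specific marker, not something this
+    # PR defines.
+    if "<tool_call>" not in response:
+        return "", True
+    # A real implementation would parse the call and run it in a sandboxed
+    # interpreter, honoring TOOL_CONFIGS["python_timeout"],
+    # TOOL_CONFIGS["python_memory_limit"], etc.
+    return "<tool_response>...</tool_response>", False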
+
+
+if __name__ == '__main__':
+
+    ray.init(num_cpus=80, ignore_reinit_error=True)
+
+    # model_path = '/mnt/shared-storage-user/llmrazor-share/model/intern-s1-mini-hha-fix_tokenizer'
+    model_path = '/mnt/shared-storage-user/llmrazor-share/model/Qwen3-8B'
+
+    resources = AcceleratorResourcesConfig(
+        accelerator="GPU",
+        num_workers=1,  # 1 or 8
+        num_cpus_per_worker=12,
+        cpu_memory_per_worker=16 * 1024 ** 3,  # 16 GB
+    )
+    pg = AutoAcceleratorWorkers.build_placement_group(resources)
+
+    # 2. rollout
+    rollout_config = RolloutConfig(
+        device=resources.accelerator,
+        model_path=model_path,
+        dtype="bfloat16",
+        tensor_parallel_size=1,
+        expert_parallel_size=1,
+        gpu_memory_utilization=0.75,
+    )
+
+    rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg)
+
+    processor_utils_state = ProcessorUtilState(hf_checkpoint=model_path, sample_params=SampleParams())
+    simple_env_runner = ray.remote(SimpleEnvRunner).remote(rollout_controller,
+                                                           processor_utils_state=processor_utils_state,
+                                                           generate_external=gsm8k_with_tools_generate)
+
+    prompt = [{"role": "user", "content": 'Calculate 13+24=', "type": "text"}]
+    data_item = {'prompt': prompt}
+
+    input_ids = processor_utils_state.tokenizer.apply_chat_template(
+        data_item["prompt"], return_dict=True, return_tensors="pt"
+    )["input_ids"][0].tolist()
+    rollout_state = RolloutState(uid=uuid4().int, tokens=input_ids)
+
+    # Generate a single sample
+    res1 = ray.get(simple_env_runner.generate.remote(rollout_state))
+    print("Response from infer:", res1)
+
+    # Generate a group of samples
+    res2 = ray.get(simple_env_runner.generate_group.remote(rollout_state))
+    print("Response from infer:", res2)
+
+    # Generate a batch of samples
+    res3 = ray.get(simple_env_runner.generate_batch.remote(batch_size=64))
+    print("Response from infer:", res3)
+
+    ray.get(rollout_controller.shutdown.remote(), timeout=300)
diff --git a/xtuner/v2/colocate_rl_trainer.py b/xtuner/v2/colocate_rl_trainer.py
new file mode 100644
index 000000000..4daeea322
--- /dev/null
+++ b/xtuner/v2/colocate_rl_trainer.py
@@ -0,0 +1,64 @@
+# Co-located: training and rollout share the same GPUs. Whether generation is
+# async or not should make no difference here.
+class ColocateRLTrainer:
+    def __init__(
+        self,
+        config,
+        # Only a single env runner is allowed. In multi-task scenarios the
+        # user should pass a composite env runner they compose themselves;
+        # that is what keeps this class simple and generic.
+        env_runner,
+        weight_controller,
+        rollout_controller,
+        # All PPO training details live inside this class; the trainer is not
+        # aware of any of them.
+        training_controller,
+    ):
+        self.config = config
+        self.env_runner = env_runner
+        self.weight_controller = weight_controller
+        self.training_controller = training_controller
+        self.rollout_controller = rollout_controller
+        self.replay_buffer = []  # in this setting a plain list is enough
+        self.batch_size: int = config.batch_size
+
+        self.env_runner.set_controller(
+            rollout_controller=self.rollout_controller,
+            training_controller=self.training_controller)
+        self.weight_controller.set_controllers(
+            rollout_controller=self.rollout_controller,
+            training_controller=self.training_controller)
+
+        self.training_steps_per_epoch = 1
+        self.total_steps = 100
+
+    def train_loop(self):
+        for rollout_id in range(self.total_steps):
+            self.train_step(rollout_id)
+
+    def train_step(self, rollout_id):
+        self.replay_buffer = []
+
+        # offload train controller to cpu
+        self.training_controller.offload_to_cpu()
+        # load rollout_controller
+        self.rollout_controller.load_to_device()
+
+        # Collect rollouts
+        for trajectory in self.env_runner.generate_batch(self.batch_size):
+            self.replay_buffer.append(trajectory)
+
+        # offload rollout_controller to cpu
+        self.rollout_controller.offload_to_cpu()
+        # load train controller
+        self.training_controller.load_to_device()
+
+        # Train the model
+        for _ in range(self.training_steps_per_epoch):
+            batch = self.replay_buffer
+            train_batch = self.convert_rollout_batch_to_train_batch(batch)
+            self.training_controller.fit(train_batch)
+
+        # IPC and NCCL should be two different implementations behind the
+        # same interface.
+        self.weight_controller.update_weights()
+
+    def convert_rollout_batch_to_train_batch(self, batch):
+        # Assume here that the rollout batch and the train batch are identical.
+        return batch
diff --git a/xtuner/v2/disaggregated_rl_trainer.py b/xtuner/v2/disaggregated_rl_trainer.py
new file mode 100644
index 000000000..84e691072
--- /dev/null
+++ b/xtuner/v2/disaggregated_rl_trainer.py
@@ -0,0 +1,57 @@
+# Disaggregated: training and rollout run on separate GPUs. In theory whether
+# generation is async or not makes no difference here.
+class DisaggregatedRLTrainer:
+    def __init__(
+        self,
+        config,
+        # Only a single env runner is allowed. In multi-task scenarios the
+        # user should pass a composite env runner they compose themselves;
+        # that is what keeps this class simple and generic.
+        env_runner,
+        update_weighter,
+        rollout_controller,
+        training_controller,
+    ):
+        self.config = config
+        self.env_runner = env_runner
+        self.update_weighter = update_weighter
+        self.training_controller = training_controller
+        self.rollout_controller = rollout_controller
+        self.replay_buffer = []  # in this setting a plain list is enough
+        self.batch_size: int = config.batch_size
+
+        self.env_runner.set_controller(
+            rollout_controller=self.rollout_controller,
+            training_controller=self.training_controller)
+        self.update_weighter.set_controllers(
+            rollout_controller=self.rollout_controller,
+            training_controller=self.training_controller)
+
+        self.training_steps_per_epoch = 1
+        self.total_steps = 100
+        self.require_batches = config.batch_size // 4
+
+    def train_loop(self):
+        for rollout_id in range(self.total_steps):
+            self.train_step(rollout_id)
+
+    def train_step(self, rollout_id):
+        self.replay_buffer = []
+        # Collect rollouts
+        for trajectory in self.env_runner.generate_batch(self.batch_size):
+            self.replay_buffer.append(trajectory)
+
+            # Start training once enough batches have been collected
+            if len(self.replay_buffer) >= self.require_batches:
+                # Train the model
+                for _ in range(self.training_steps_per_epoch):
+                    # After a batch is sampled from the replay buffer, the
+                    # trained data is removed from it.
+                    batch = self.replay_buffer[:self.require_batches]
+                    del self.replay_buffer[:self.require_batches]
+                    train_batch = self.convert_rollout_batch_to_train_batch(batch)
+                    self.training_controller.fit(train_batch)
+
+                self.update_weighter.update_weights()
+                self.save_ckpt()
+
+    def convert_rollout_batch_to_train_batch(self, batch):
+        # Assume here that the rollout batch and the train batch are identical.
+        return batch
diff --git a/xtuner/v2/proxy_async_env_runner.py b/xtuner/v2/proxy_async_env_runner.py
new file mode 100644
index 000000000..6277e0c46
--- /dev/null
+++ b/xtuner/v2/proxy_async_env_runner.py
@@ -0,0 +1,221 @@
+import asyncio
+from collections import defaultdict
+from typing import Dict, List
+
+from .rollout_state import RolloutState, Status
+from .simple_env_runner import DataSampler, SimpleEnvRunner
+
+
+# Invisible to the user.
+class ExpiredBuffer:
+    def __init__(self):
+        self.buffer: List[List[RolloutState]] = []
+
+    def add(self, rollout_state: List[RolloutState]):
+        for rs in rollout_state:
+            # 1. Clear generation outputs
+            rs.response = ""
+            rs.response_ids = []
+            rs.logprobs = []
+            if rs.routed_experts is not None:
+                # Note: new routed_experts had better not be freed via
+                # ray._internal.free
+                del rs.routed_experts
+                rs.routed_experts = None
+
+            # 2. Reset reward and statistics fields
+            rs.reward = None
+            rs.staleness = 0
+
+            # 3. Reset lifecycle state
+            rs.state = Status.INIT
+        self.buffer.append(rollout_state)
+
+    def pop(self) -> list[RolloutState]:
+        assert self.buffer, "ExpiredBuffer is empty!"
+        return self.buffer.pop(0)
+
+    def __len__(self) -> int:
+        return len(self.buffer)
+
+
+class AbortedBuffer:
+    def __init__(self):
+        self.buffer: Dict[int, List[List[RolloutState]]] = defaultdict(list)
+
+    def add(self, rollout_state: List[RolloutState]):
+        group_staleness = max(rs.staleness for rs in rollout_state)
+        self.buffer[group_staleness].append(rollout_state)
+
+    def pop(self, enable_partial_rollout) -> list[RolloutState]:
+        assert self.buffer, "AbortedBuffer is empty!"
+        highest_staleness = max(self.buffer.keys())
+        rollout_states = self.buffer[highest_staleness]
+        data = rollout_states.pop(0)
+        if not rollout_states:
+            del self.buffer[highest_staleness]
+        if enable_partial_rollout:
+            for rs in data:
+                rs.tokens = rs.prompt_ids + rs.response_ids
+                rs.sample_params.max_tokens = rs.sample_params.max_tokens - len(rs.response_ids)
+        else:
+            for rs in data:
+                rs.response = ""
+                rs.response_ids = []
+                rs.logprobs = []
+                if rs.routed_experts is not None:
+                    del rs.routed_experts
+                    rs.routed_experts = None
+                rs.reward = None
+                rs.staleness = 0
+                rs.state = Status.INIT
+        return data
+
+    def update(self):
+        new_buffer = defaultdict(list)
+        for staleness, rollout_states in self.buffer.items():
+            new_buffer[staleness + 1].extend(rollout_states)
+        self.buffer = new_buffer
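+
+
+def _aborted_buffer_priority_demo():
+    # Illustrative only (hypothetical helper, not used by the runner): shows
+    # that AbortedBuffer always pops the stalest group first, regardless of
+    # insertion order, and that a non-partial pop resets the group's state.
+    aborted = AbortedBuffer()
+    aborted.add([RolloutState(uid=2, staleness=1, state=Status.ABORTED)])
+    aborted.add([RolloutState(uid=1, staleness=3, state=Status.ABORTED)])
+    group = aborted.pop(enable_partial_rollout=False)
+    assert group[0].uid == 1
+    assert group[0].state == Status.INIT and group[0].staleness == 0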
+
+
+class CompletedBuffer:
+    def __init__(self):
+        self.buffer: Dict[int, List[List[RolloutState]]] = defaultdict(list)
+
+    def add(self, rollout_state: List[RolloutState]):
+        group_staleness = max(rs.staleness for rs in rollout_state)
+        self.buffer[group_staleness].append(rollout_state)
+
+    def pop(self) -> list[RolloutState]:
+        assert self.buffer, "CompletedBuffer is empty!"
+        highest_staleness = max(self.buffer.keys())
+        rollout_states = self.buffer[highest_staleness]
+        data = rollout_states.pop(0)
+        if not rollout_states:
+            del self.buffer[highest_staleness]
+        return data
+
+    def update(self):
+        new_buffer = defaultdict(list)
+        for staleness, rollout_states in self.buffer.items():
+            new_buffer[staleness + 1].extend(rollout_states)
+        self.buffer = new_buffer
+
+    @property
+    def length(self) -> int:
+        return sum(len(v) for v in self.buffer.values())
+
+
+class Buffer:
+    # The point of making this an independent class: the expired buffer and
+    # the aborted buffer may want different priority policies. Right now
+    # priority is managed by version: the staler a sample, the earlier it is
+    # dequeued; other policies (e.g. by length?) may come later.
+    # Keeping the management of expired samples separate also makes the code
+    # a bit more readable.
+    # Each knob should likewise control one independent thing:
+    #   enable_partial_rollout: whether the next rollout resumes from (is
+    #       concatenated onto) the previous partial output
+    #   tail_batch_candidate_step: how stale a sample must be before it
+    #       enters the expired buffer
+    #   tail_batch_trigger_size: size threshold of the expired buffer that
+    #       triggers sampling from it
+
+    def __init__(self,
+                 enable_partial_rollout: bool = False,
+                 tail_batch_candidate_step: int = 1,
+                 tail_batch_trigger_size: int = 10):
+        self.expired_buffer: ExpiredBuffer = ExpiredBuffer()
+        self.aborted_buffer: AbortedBuffer = AbortedBuffer()
+        self.completed_buffer: CompletedBuffer = CompletedBuffer()
+        self.enable_tail_batch = tail_batch_candidate_step > 0
+        self.enable_partial_rollout = enable_partial_rollout
+        self.tail_batch_candidate_step = tail_batch_candidate_step
+        self.tail_batch_trigger_size = tail_batch_trigger_size
+
+    def add(self, rollout_state: List[RolloutState]):
+        # Where should version bookkeeping for rollout_state live? E.g. after
+        # a weight update, bump the version of every sample already in the
+        # Buffer once, instead of updating versions here.
+        group_staleness = max(rs.staleness for rs in rollout_state)
+        group_states = [rs.state for rs in rollout_state]
+        if self.enable_tail_batch and group_staleness > self.tail_batch_candidate_step:
+            self.expired_buffer.add(rollout_state)
+        elif all(state == Status.COMPLETED for state in group_states):
+            self.completed_buffer.add(rollout_state)
+        else:
+            self.aborted_buffer.add(rollout_state)
+
+    def get_sample_func(self, data_sampler: DataSampler):
+        use_expired = self.enable_tail_batch and len(self.expired_buffer) >= self.tail_batch_trigger_size
+        self.update()
+
+        def _sample(prompt_repeat_k: int):
+            if use_expired and self.expired_buffer:
+                return self.expired_buffer.pop()
+            elif self.aborted_buffer.buffer:
+                return self.aborted_buffer.pop(self.enable_partial_rollout)
+            else:
+                return data_sampler.sample_from_dataset(prompt_repeat_k)
+
+        return _sample
+
+    def update(self):
+        if self.enable_partial_rollout:
+            self.completed_buffer.update()
+        else:
+            while self.completed_buffer.length > 0:
+                state = self.completed_buffer.pop()
+                self.aborted_buffer.add(state)
+        self.aborted_buffer.update()
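+
+
+def _buffer_routing_demo():
+    # Illustrative only (hypothetical helper): shows how Buffer.add routes a
+    # group by staleness and completion state, matching the priority policy
+    # described in the comments above.
+    buf = Buffer(enable_partial_rollout=True, tail_batch_candidate_step=1,
+                 tail_batch_trigger_size=1)
+    buf.add([RolloutState(uid=1, state=Status.COMPLETED, staleness=0)])  # -> completed
+    buf.add([RolloutState(uid=2, state=Status.ABORTED, staleness=3)])    # -> expired (too stale)
+    buf.add([RolloutState(uid=3, state=Status.ABORTED, staleness=0)])    # -> aborted
+    assert buf.completed_buffer.length == 1
+    assert len(buf.expired_buffer) == 1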
+
+
+class AsyncProxyEnvRunner:
+    def __init__(
+        self,
+        staleness_threshold: float = 0.0,
+        enable_partial_rollout: bool = False,
+        tail_batch_trigger_size: int = 0,
+        tail_batch_candidate_step: int = 0,
+    ):
+        self.base_env_runner: SimpleEnvRunner | None = None
+        self.buffer = Buffer(enable_partial_rollout, tail_batch_candidate_step, tail_batch_trigger_size)
+        self.staleness_threshold = staleness_threshold
+
+    def set_base_env_runner(self, base_env_runner: SimpleEnvRunner):
+        self.base_env_runner = base_env_runner
+
+    # This one method should be able to implement all of the async behavior.
+    async def async_generate_batch(self,
+                                   data_sampler: DataSampler,
+                                   batch_size: int,
+                                   prompt_repeat_k: int,
+                                   ):
+        # Based on the internally managed state, we can decide which pool the
+        # next sample should come from. A highly cohesive module.
+        last_step_remain_completed_samples = self.buffer.completed_buffer.length
+        data_concurrency = int((1 + self.staleness_threshold) * (batch_size - last_step_remain_completed_samples))
+        completed_sample_count = 0
+
+        # Hand over the finished groups left from the previous step first,
+        # before Buffer.update() (called inside get_sample_func) recycles them.
+        for _ in range(last_step_remain_completed_samples):
+            traj = self.buffer.completed_buffer.pop()
+            yield traj
+
+        sample_func = self.buffer.get_sample_func(data_sampler)
+
+        pending_tasks = []
+        for _ in range(data_concurrency):
+            task = asyncio.create_task(self.base_env_runner.generate_group(sample_func, prompt_repeat_k))
+            # task = asyncio.create_task(self.generate_group(data_sampler.sample(), prompt_repeat_k))
+            pending_tasks.append(task)
+
+        # Only the remainder of the batch is yielded here; anything beyond it
+        # (over-provisioned by staleness_threshold) goes to the buffer.
+        need = batch_size - last_step_remain_completed_samples
+        while completed_sample_count < need:
+            if not pending_tasks:
+                print("All tasks are done but not enough samples collected.")
+                break
+            done_tasks, pending_tasks = await asyncio.wait(pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED)
+            for task in done_tasks:
+                try:
+                    traj = await task
+                    if traj is not None:
+                        completed_sample_count += 1
+                        if completed_sample_count <= need:
+                            yield traj
+                        else:
+                            self.buffer.add(traj)
+                except Exception as e:
+                    print(f"Error in generating trajectory: {e}")
+
+        await self.base_env_runner.rollout_controller.pause.remote()
+        while len(pending_tasks) > 0:
+            done_tasks, pending_tasks = await asyncio.wait(pending_tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED)
+            for task in done_tasks:
+                try:
+                    abort_traj = await task
+                    self.buffer.add(abort_traj)
+                except Exception as e:
+                    print(f"Error while pausing task: {e}")
+            if len(pending_tasks) > 0:
+                await self.base_env_runner.rollout_controller.pause.remote()
+                await asyncio.sleep(1)
diff --git a/xtuner/v2/rollout_controller.py b/xtuner/v2/rollout_controller.py
new file mode 100644
index 000000000..52ff7fca0
--- /dev/null
+++ b/xtuner/v2/rollout_controller.py
@@ -0,0 +1,25 @@
+from xtuner.v1.ray.rollout import RolloutController as V1RolloutController
+
+from .rollout_state import RolloutState
+
+
+# Temporary solution.
+class RolloutController(V1RolloutController):
+
+    async def generate(self, rollout_state: RolloutState):
+        # A thin wrapper around the v1 rollout API.
+        input_ids = rollout_state.tokens
+        sample_params = rollout_state.sample_params
+        session_id = rollout_state.session_id
+
+        response = await super().rollout(
+            input_ids=input_ids,
+            sample_params=sample_params,
+            session_id=session_id,
+        )
+
+        rollout_state.response = response.response
+        rollout_state.response_ids = response.response_ids
+        rollout_state.logprobs = response.logprobs
+        rollout_state.state = response.state
+
+        return rollout_state
diff --git a/xtuner/v2/rollout_state.py b/xtuner/v2/rollout_state.py
new file mode 100644
index 000000000..4ff873f0f
--- /dev/null
+++ b/xtuner/v2/rollout_state.py
@@ -0,0 +1,70 @@
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any
+
+from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase
+from transformers.processing_utils import ProcessorMixin
+
+from xtuner.v1.ray.rollout.controller import SampleParams
+
+
+class Status(Enum):
+    INIT = "init"
+    COMPLETED = "completed"
+    ABORTED = "aborted"
+    FAILED = "failed"
+    ARCHIVED = "archived"
+    EXPIRED = "expired"
+    SKIPPED = "skipped"
+
+
+@dataclass
+class RolloutState:
+    uid: int
+    messages: list = field(default_factory=list)
+    tokens: list[int] = field(default_factory=list)  # the actual input of each rollout call
+    session_id: int | None = None
+    prompt_ids: list[int] = field(default_factory=list)
+    response: str = ""
+    response_ids: list[int] = field(default_factory=list)  # the actual output of each call; overwritten every time
+    logprobs: list[float] = field(default_factory=list)
+    routed_experts: list[int] | None = None
+    reward: float | list[float] | list[dict] | None = None
+    loss_mask: list[int] | None = None  # length of tokens + response_ids
+    state: Status = Status.INIT
+    sample_params: SampleParams | None = None
+    tools: list | None = None
+    tool_choice: str | None = None
+    mm_infer_info: dict[str, Any] = field(default_factory=dict)
+    mm_train_info: dict[str, Any] = field(default_factory=dict)
+    finish_reason: str | None = None
+    staleness: int = 0
+    extra_fields: dict[str, Any] | None = None
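+
+
+# A hedged sketch (hypothetical helper, not part of the PR) of the resumable
+# state-machine contract implied above: a non-INIT state plus the accumulated
+# response_ids is all a runner needs to continue a partial rollout. This
+# mirrors AbortedBuffer.pop in proxy_async_env_runner.py.
+def _resume_tokens(rs: RolloutState) -> list[int]:
+    assert rs.state is not Status.INIT, "a fresh sample has nothing to resume"
+    # On resumption, the next model input is the prompt plus everything
+    # generated so far.
+    return rs.prompt_ids + rs.response_ids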
+
+
+def load_tokenizer(name_or_path: str, **kwargs):
+    return AutoTokenizer.from_pretrained(name_or_path, **kwargs)
+
+
+def load_processor(name_or_path: str, **kwargs):
+    try:
+        proc = AutoProcessor.from_pretrained(name_or_path, **kwargs)
+    except (OSError, ValueError):
+        proc = None
+
+    # If HF returned a tokenizer (or anything that is not a processor), discard it.
+    if isinstance(proc, PreTrainedTokenizerBase) or not isinstance(proc, ProcessorMixin):
+        proc = None
+
+    return proc
+
+
+# TODO: rename.
+# There has to be a class that travels through the pipeline together with
+# RolloutState; otherwise extensibility cannot be satisfied.
+class ProcessorUtilState:
+    def __init__(self, hf_checkpoint, sample_params: SampleParams | None = None) -> None:
+        # persistent state for the generation process
+        self.hf_checkpoint = hf_checkpoint
+        self.tokenizer = load_tokenizer(hf_checkpoint, trust_remote_code=True)
+        self.processor = load_processor(hf_checkpoint, trust_remote_code=True)
+        self.sample_params = sample_params if sample_params is not None else SampleParams()
diff --git a/xtuner/v2/simple_env_runner.py b/xtuner/v2/simple_env_runner.py
new file mode 100644
index 000000000..114fa5dc2
--- /dev/null
+++ b/xtuner/v2/simple_env_runner.py
@@ -0,0 +1,149 @@
+import asyncio
+from copy import deepcopy
+from typing import Callable
+
+from xtuner.v1.datasets import DataloaderConfig
+
+from .rollout_controller import RolloutController
+from .rollout_state import ProcessorUtilState, RolloutState
+from .utils import load_function
+
+
+class DataSampler:
+    def __init__(self, dataloader_config: DataloaderConfig):
+        self.dataloader_config = dataloader_config
+        # TODO: build the dataloader from dataloader_config
+        self.dataloader = None
+        self.dataloader_iter = None
+        self.cur_epoch = 0
+
+    def sample_from_dataset(self, prompt_repeat_k: int) -> list[RolloutState]:
+        try:
+            data = next(self.dataloader_iter)[0]
+        except StopIteration:
+            self.cur_epoch += 1
+            self.dataloader.set_epoch(self.cur_epoch)
+            self.dataloader_iter = iter(self.dataloader)
+            data = next(self.dataloader_iter)[0]
+        # Expand the data according to prompt_repeat_k
+        group_data = []
+        for _ in range(prompt_repeat_k):
+            group_data.append(deepcopy(data))
+        return group_data
+
+    def resume(self):
+        pass
+
+    def save(self):
+        pass
+
+
+# This class is responsible for defining all the interfaces, and it also
+# provides a synchronous rollout implementation that covers most needs:
+# single-turn, multi-turn, and agent scenarios.
+# For async behavior we assume there may be two completely different
+# implementations; each of them simply extends this class.
+class SimpleEnvRunner:
+    def __init__(self,
+                 rollout_controller: RolloutController,
+                 processor_utils_state: ProcessorUtilState | None = None,
+                 # None so that this env runner can run standalone; may be a
+                 # plain callable or an actor worker
+                 judger: Callable | None = None,
+                 generate_external: Callable | str | None = None,
+                 # Ideally the user is completely unaware of this class: after
+                 # customizing their own logic on top of simple_env_runner,
+                 # passing in a proxy class like this one enables an async
+                 # strategy, achieving decoupling.
+                 async_proxy_runner=None,  # proxy runner for async scenarios
+                 ):
+        self.rollout_controller = rollout_controller
+        self.judger = judger
+        self.processor_utils_state = processor_utils_state
+
+        self.generate_external = generate_external
+        if isinstance(self.generate_external, str):
+            self.generate_external = load_function(self.generate_external)
+
+        self.async_proxy_runner = async_proxy_runner
+        if self.async_proxy_runner is not None:
+            # Circular reference -- could this be a problem?
+            self.async_proxy_runner.set_base_env_runner(self)
+
+    # Generate a single sample
+    async def generate_single_sample(self, rollout_state: RolloutState) -> RolloutState:
+        # Default to the simplest single-turn mode.
+        # If the sample was interrupted before, its state field says so.
+        rollout_state = await self.rollout_controller.generate.remote(rollout_state)
+
+        reward = 0.0
+        if self.judger is not None:
+            if asyncio.iscoroutinefunction(self.judger):
+                reward = await self.judger(rollout_state)
+            else:
+                reward = await self.run_judger_worker(rollout_state)
+        rollout_state.reward = reward
+        return rollout_state
+
+    async def run_judger_worker(self, rollout_state):
+        # There may be several judger workers, in which case scheduling
+        # becomes an issue. Users can override this method to implement their
+        # own scheduling policy.
+        reward = self.judger(rollout_state)
+        return reward
+    async def generate(self, rollout_state: RolloutState) -> RolloutState:
+        if self.generate_external is not None:
+            return await self.generate_external(rollout_state, self.processor_utils_state, self.rollout_controller, self.judger)
+        else:
+            return await self.generate_single_sample(rollout_state)
+
+    # Generate a group of samples
+    async def generate_group(self, sample_func, prompt_repeat_k: int) -> list[RolloutState]:
+        pending_tasks = []
+
+        group_rollout_state = sample_func(prompt_repeat_k)
+        for rs in group_rollout_state:
+            task = asyncio.create_task(self.generate(rs))
+            pending_tasks.append(task)
+
+        group_responses = await asyncio.gather(*pending_tasks)
+        return group_responses
+
+    # Uninterruptible batch generation, for the synchronous setting
+    async def generate_batch(self,
+                             data_sampler: DataSampler,
+                             batch_size: int,
+                             prompt_repeat_k: int,
+                             ):
+        data_concurrency = batch_size
+        sample_func = data_sampler.sample_from_dataset
+
+        pending_tasks = []
+        for _ in range(data_concurrency):
+            task = asyncio.create_task(self.generate_group(sample_func, prompt_repeat_k))
+            pending_tasks.append(task)
+
+        completed_sample_count = 0
+        while completed_sample_count < data_concurrency:
+            if not pending_tasks:
+                print("All tasks are done but not enough samples collected.")
+                break
+            done_tasks, pending_tasks = await asyncio.wait(pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED)
+            for task in done_tasks:
+                try:
+                    traj = await task
+                    if traj is not None:
+                        yield traj
+                        completed_sample_count += 1
+                except Exception as e:
+                    print(f"Error in generating trajectory: {e}")
+
+    # For the interruptible generation setting
+    async def async_generate_batch(self,
+                                   data_sampler: DataSampler,
+                                   batch_size: int,
+                                   prompt_repeat_k: int,
+                                   ):
+        async for traj in self.async_proxy_runner.async_generate_batch(
+                data_sampler, batch_size, prompt_repeat_k):
+            yield traj
diff --git a/xtuner/v2/utils.py b/xtuner/v2/utils.py
new file mode 100644
index 000000000..05fde3426
--- /dev/null
+++ b/xtuner/v2/utils.py
@@ -0,0 +1,11 @@
+import importlib
+
+
+def load_function(path):
+    """Load a function from a module.
+
+    :param path: The path to the function, e.g. "module.submodule.function".
+    :return: The function object.
+    """
+    module_path, _, attr = path.rpartition(".")
+    module = importlib.import_module(module_path)
+    return getattr(module, attr)
diff --git a/xtuner/v2/weight_controller.py b/xtuner/v2/weight_controller.py
new file mode 100644
index 000000000..16f010119
--- /dev/null
+++ b/xtuner/v2/weight_controller.py
@@ -0,0 +1,12 @@
+class WeightController:
+    def __init__(self):
+        pass
+
+    def set_controllers(self, rollout_controller, training_controller):
+        self.rollout_controller = rollout_controller
+        self.training_controller = training_controller
+
+    def update_weights(self):
+        pass
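+
+
+# Minimal illustrative subclasses (hypothetical, not part of the PR's
+# design): the trainers' comments only require that an IPC-based and an
+# NCCL-based implementation hide behind the same update_weights() interface,
+# e.g. for the colocated and the disaggregated trainer respectively.
+class IPCWeightController(WeightController):
+    def update_weights(self):
+        # e.g. hand CUDA IPC handles of the freshly trained weights to the
+        # co-located rollout workers (implementation-specific).
+        ...
+
+
+class NCCLWeightController(WeightController):
+    def update_weights(self):
+        # e.g. broadcast the updated weights from the training ranks to the
+        # disaggregated rollout ranks over a dedicated NCCL group.
+        ...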