74 changes: 74 additions & 0 deletions examples/v2/single_turn_env_demo.py
@@ -0,0 +1,74 @@
import ray
import os
from uuid import uuid4

os.environ["XTUNER_USE_LMDEPLOY"] = "1"

from xtuner.v1.ray.config.worker import RolloutConfig
from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
from xtuner.v2.rollout_controller import RolloutController
from xtuner.v2.simple_env_runner import SimpleEnvRunner
from xtuner.v2.rollout_state import ProcessorUtilState, RolloutState
from xtuner.v1.ray.rollout.controller import SampleParams
from xtuner.v1.ray.judger.gsm8k import compute_reward


if __name__ == '__main__':
ray.init(num_cpus=80, ignore_reinit_error=True)

# model_path = '/mnt/shared-storage-user/llmrazor-share/model/intern-s1-mini-hha-fix_tokenizer'
    model_path = '/mnt/shared-storage-user/llmrazor-share/model/Qwen3-8B'

resources = AcceleratorResourcesConfig(
accelerator="GPU",
num_workers=1, # 1 or 8
num_cpus_per_worker=12,
cpu_memory_per_worker=16 * 1024 ** 3, # 16 GB
)
pg = AutoAcceleratorWorkers.build_placement_group(resources)

# 2. rollout
rollout_config = RolloutConfig(
device=resources.accelerator,
model_path=model_path,
dtype="bfloat16",
tensor_parallel_size=1,
expert_parallel_size=1,
gpu_memory_utilization=0.75
)

rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg)

processor_utils_state = ProcessorUtilState(hf_checkpoint=model_path, sample_params=SampleParams())
simple_env_runner = ray.remote(SimpleEnvRunner).remote(rollout_controller, processor_utils_state=processor_utils_state)

# prompt = [{"content": [{"image_url": {"image_wh": [404, 162],
# "url": "images/test_2.jpg"},
# "type": "image_url"},
# {
# "text": "<IMG_CONTEXT>Find the area of the figure to the nearest tenth. You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \\boxed{}.",
# "type": "text"}],
# "role": "user"}]
# extra_info = {'media_root': '/mnt/shared-storage-user/llmrazor-share/data/geometry3k/'}

prompt = [{"role": "user", "content": 'Calculate 13+24=', "type": "text"}]
data_item = {'prompt': prompt}

    input_ids = processor_utils_state.tokenizer.apply_chat_template(
        data_item["prompt"], add_generation_prompt=True, return_tensors="pt", return_dict=True
    )["input_ids"][0].tolist()
rollout_state = RolloutState(uid=uuid4().int, tokens=input_ids)

    # Generate a single response
res1 = ray.get(simple_env_runner.generate.remote(rollout_state))
print("Response from infer:", res1)

    # Generate a group of responses for the same prompt
res1 = ray.get(simple_env_runner.generate_group.remote(rollout_state))
print("Response from infer:", res1)

    # Generate a batch of responses
res1 = ray.get(simple_env_runner.generate_batch.remote(batch_size=64))
print("Response from infer:", res1)

ray.get(rollout_controller.shutdown.remote(), timeout=300)


156 changes: 156 additions & 0 deletions examples/v2/toolcall_env_demo.py
@@ -0,0 +1,156 @@
import ray
import os
from uuid import uuid4
from copy import deepcopy

os.environ["XTUNER_USE_LMDEPLOY"] = "1"

from xtuner.v1.ray.config.worker import RolloutConfig
from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
from xtuner.v2.rollout_controller import RolloutController
from xtuner.v2.simple_env_runner import SimpleEnvRunner
from xtuner.v2.rollout_state import ProcessorUtilState, RolloutState
from xtuner.v1.ray.rollout.controller import SampleParams
from xtuner.v1.ray.judger.gsm8k import compute_reward


TOOL_CONFIGS = {
"max_turns": 16,
"max_tool_calls": 16,
"tool_concurrency": 32, # Aggressive: 32 concurrent processes
# Python interpreter settings
"python_timeout": 120, # 2 minutes for complex calculations
"python_memory_limit": "4GB", # 4GB per Python process
"python_cpu_limit": 1,
# Memory management settings
"max_memory_usage": 12288, # 12GB total (75% of 16GB)
"cleanup_threshold": 6144, # 6GB
"aggressive_cleanup_threshold": 3072, # 3GB
"force_cleanup_threshold": 9216, # 9GB
}
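
# NOTE: execute_predictions is referenced in the loop below but is not defined in this demo.
# The stub here is a hypothetical placeholder (an assumption, not part of the original code)
# so the tool-call loop reads end to end: a real implementation would parse tool calls out of
# the model response, run them (for example the sandboxed Python interpreter configured in
# TOOL_CONFIGS above), and return the observation text plus a done flag.
async def execute_predictions(response: str) -> tuple[str, bool]:
    # Placeholder behaviour: no tool call detected, so there is no observation and we are done.
    return "", True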

# Once the generate function is user-defined, partial rollout looks hard to support at first
# glance. In fact it is easy, because the incoming rollout_state carries a state field: from
# that state we can tell which step to take next (continue the rollout, call a tool, and so on).
# Since this is a state machine, an interrupted rollout can be restored and resumed.
async def gsm8k_with_tools_generate(rollout_state: RolloutState,
processor_utils_state: ProcessorUtilState,
rollout_controller:RolloutController,
judger) -> RolloutState:
    # You can work directly with tokens here, or with messages + tools.

    # The incoming state tells us where this rollout currently is, and therefore what to do
    # next. As long as the input state is not init, partial rollout has definitely been enabled.
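    # A minimal sketch (assumption) of that check; an INIT member on RolloutState.State is
    # presumed here, while this demo itself only references ABORTED:
    # if rollout_state.state != RolloutState.State.INIT:
    #     ...  # resumed partial rollout: continue from the tokens accumulated so far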
    prompt = processor_utils_state.tokenizer.apply_chat_template(
rollout_state.messages,
tools=rollout_state.tools if hasattr(rollout_state, "tools") else None,
tokenize=False,
add_generation_prompt=True,
)
    prompt_tokens_ids = processor_utils_state.tokenizer(prompt, add_special_tokens=False)["input_ids"]

response_token_ids = []
loss_masks = []
tool_call_count = 0 # Track actual tool call rounds
    rollout_log_probs = []
init_rollout_state = deepcopy(rollout_state)
for turn in range(TOOL_CONFIGS["max_turns"]):
current_token_ids = prompt_tokens_ids + response_token_ids
rollout_state.tokens = current_token_ids
init_rollout_state.tokens = rollout_state.tokens

        # TODO: mind how much data is shipped between Ray actors; in principle, anything the
        # receiver does not need should not be passed along.
        # The object sent to rollout_controller only needs the appended tokens; it ignores
        # everything else.

        # Compared with sending a POST request directly this is a bit more cumbersome and has a
        # higher learning curve, but the extra management layer removes the hassle of keeping n
        # URLs around and doing the routing yourself.
        # TODO: should this be switched to a URL/POST style interface?
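        # A hedged sketch (assumption) of a slimmer payload: build a RolloutState carrying only
        # the fields the rollout engine consumes, instead of shipping the full per-trajectory
        # bookkeeping through the Ray object store:
        # slim_state = RolloutState(uid=rollout_state.uid, tokens=current_token_ids)
        # rollout_state = await rollout_controller.rollout.remote(slim_state)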
rollout_state = await rollout_controller.rollout.remote(rollout_state)

init_rollout_state.state = rollout_state.state

response_token_ids += rollout_state.response_ids
rollout_log_probs += rollout_state.logprobs
loss_masks += [1] * len(rollout_state.response_ids)

if rollout_state.state == RolloutState.State.ABORTED:
            # Transfer the rollout_state contents into init_rollout_state, the copy kept for
            # the whole trajectory
init_rollout_state.state = rollout_state.state
init_rollout_state.response_ids = response_token_ids
init_rollout_state.logprobs = rollout_log_probs
init_rollout_state.loss_mask = loss_masks
return init_rollout_state

        # Execute the tool calls parsed from the model response
        next_obs, done = await execute_predictions(rollout_state.response)
        if done:
            break
        tool_call_count += 1  # one tool round was executed this turn

        obs_tokens_ids = processor_utils_state.tokenizer(next_obs, add_special_tokens=False)["input_ids"]
        response_token_ids += obs_tokens_ids
        rollout_log_probs += [0.0] * len(obs_tokens_ids)
        loss_masks += [0] * len(obs_tokens_ids)

if tool_call_count >= TOOL_CONFIGS["max_tool_calls"]:
break

    # Transfer the accumulated results into init_rollout_state, the copy kept for the whole
    # trajectory, and return it
init_rollout_state.response_ids = response_token_ids
init_rollout_state.logprobs = rollout_log_probs
init_rollout_state.loss_mask = loss_masks
return init_rollout_state


if __name__ == '__main__':

ray.init(num_cpus=80, ignore_reinit_error=True)

# model_path = '/mnt/shared-storage-user/llmrazor-share/model/intern-s1-mini-hha-fix_tokenizer'
    model_path = '/mnt/shared-storage-user/llmrazor-share/model/Qwen3-8B'

resources = AcceleratorResourcesConfig(
accelerator="GPU",
num_workers=1, # 1 or 8
num_cpus_per_worker=12,
cpu_memory_per_worker=16 * 1024 ** 3, # 16 GB
)
pg = AutoAcceleratorWorkers.build_placement_group(resources)

# 2. rollout
rollout_config = RolloutConfig(
device=resources.accelerator,
model_path=model_path,
dtype="bfloat16",
tensor_parallel_size=1,
expert_parallel_size=1,
gpu_memory_utilization=0.75
)

rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg)

processor_utils_state = ProcessorUtilState(hf_checkpoint=model_path, sample_params=SampleParams())
simple_env_runner = ray.remote(SimpleEnvRunner).remote(rollout_controller,
processor_utils_state=processor_utils_state,
generate_external=gsm8k_with_tools_generate)

prompt = [{"role": "user", "content": 'Calculate 13+24=', "type": "text"}]
data_item = {'prompt': prompt}

    input_ids = processor_utils_state.tokenizer.apply_chat_template(
        data_item["prompt"], add_generation_prompt=True, return_tensors="pt", return_dict=True
    )["input_ids"][0].tolist()
rollout_state = RolloutState(uid=uuid4().int, tokens=input_ids)

    # Generate a single response
res1 = ray.get(simple_env_runner.generate.remote(rollout_state))
print("Response from infer:", res1)

    # Generate a group of responses for the same prompt
res1 = ray.get(simple_env_runner.generate_group.remote(rollout_state))
print("Response from infer:", res1)

    # Generate a batch of responses
res1 = ray.get(simple_env_runner.generate_batch.remote(batch_size=64))
print("Response from infer:", res1)

ray.get(rollout_controller.shutdown.remote(), timeout=300)

from xtuner.v2.proxy_async_env_runner import ProxyAsyncEnvRunner


64 changes: 64 additions & 0 deletions xtuner/v2/colocate_rl_trainer.py
@@ -0,0 +1,64 @@


# Colocated (training and rollout share the same accelerators); whether execution is sync or
# async should make no difference here.
class ColocateRLTrainer:
def __init__(
self,
config,
        env_runner,  # Multiple env runners are not allowed; in multi-task scenarios the caller should compose them into one composite env runner. That keeps this class simple and generic.
weight_controller,

rollout_controller,
        training_controller  # All training details of the PPO algorithm live inside this class; this trainer is not aware of them.
):
self.config = config
self.env_runner = env_runner
self.weight_controller = weight_controller
self.training_controller = training_controller
self.rollout_controller = rollout_controller
        self.replay_buffer = []  # In this setting a plain list is actually sufficient
        self.batch_size: int = config.batch_size

self.env_runner.set_controller(
rollout_controller=self.rollout_controller,
training_controller=self.training_controller)
self.weight_controller.set_controllers(
rollout_controller=self.rollout_controller,
training_controller=self.training_controller)

        self.training_steps_per_epoch = 1
self.total_steps = 100

def train_loop(self):
for rollout_id in range(self.total_steps):
self.train_step(rollout_id)

def train_step(self, rollout_id):
        self.replay_buffer = []

# offload train controller to cpu
self.training_controller.offload_to_cpu()
# load rollout_controller
self.rollout_controller.load_to_device()

# Collect rollouts
for trajectory in self.env_runner.generate_batch(self.batch_size):
            self.replay_buffer.append(trajectory)

# offload rollout_controller to cpu
self.rollout_controller.offload_to_cpu()
# load train controller
self.training_controller.load_to_device()

# Train the model
for _ in range(self.training_steps_per_epoch):
batch = self.replay_buffer
            train_batch = self.convert_rollout_batch_to_train_batch(batch)
self.training_controller.fit(train_batch)

        # IPC and NCCL should be two different implementations behind one consistent interface
        # (a sketch of that interface follows at the end of this file)
self.weight_controller.update_weights()

def convert_rollout_batch_to_train_batch(self, batch):
        # Here we assume the rollout batch and the train batch share the same format
return batch
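

# A minimal sketch (an assumption, not part of the original code) of the weight-controller
# interface the comment in train_step refers to: an IPC backend (colocated, weights handed over
# on-device) and an NCCL backend (weights broadcast between processes) would be two separate
# implementations behind the same two calls this trainer uses.
class WeightControllerInterface:
    def set_controllers(self, rollout_controller, training_controller) -> None:
        """Remember both controllers so weights can be read from one and pushed to the other."""
        raise NotImplementedError

    def update_weights(self) -> None:
        """Push the latest trainer weights into the rollout engine."""
        raise NotImplementedError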
57 changes: 57 additions & 0 deletions xtuner/v2/disaggregated_rl_trainer.py
@@ -0,0 +1,57 @@


# Disaggregated (training and rollout on separate accelerators); in theory sync vs. async
# should make no difference here.
class DisaggregatedRLTrainer:
def __init__(
self,
config,
        env_runner,  # Multiple env runners are not allowed; in multi-task scenarios the caller should compose them into one composite env runner. That keeps this class simple and generic.
update_weighter,

rollout_controller,
training_controller
):
self.config = config
self.env_runner = env_runner
self.update_weighter = update_weighter
self.training_controller = training_controller
self.rollout_controller = rollout_controller
        self.replay_buffer = []  # In this setting a plain list is actually sufficient
        self.batch_size: int = config.batch_size

self.env_runner.set_controller(
rollout_controller=self.rollout_controller,
training_controller=self.training_controller)
self.update_weighter.set_controllers(
rollout_controller=self.rollout_controller,
training_controller=self.training_controller)

        self.training_steps_per_epoch = 1
self.total_steps = 100
self.require_batches = config.batch_size // 4

def train_loop(self):
for rollout_id in range(self.total_steps):
self.train_step(rollout_id)

def train_step(self, rollout_id):
        self.replay_buffer = []
# Collect rollouts
for trajectory in self.env_runner.generate_batch(self.batch_size):
            self.replay_buffer.append(trajectory)

            # Start training once the required number of trajectories has been collected
            if len(self.replay_buffer) >= self.require_batches:
# Train the model
for _ in range(self.training_steps_per_epoch):
                    # After drawing a batch, remove the consumed trajectories from the replay
                    # buffer (see the small buffer sketch at the end of this file)
                    batch = self.replay_buffer[: self.require_batches]
                    del self.replay_buffer[: self.require_batches]
                    train_batch = self.convert_rollout_batch_to_train_batch(batch)
self.training_controller.fit(train_batch)

self.sync_weights()
self.save_ckpt()

def convert_rollout_batch_to_train_batch(self, batch):
        # Here we assume the rollout batch and the train batch share the same format
return batch
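

# A minimal sketch (an assumption, not part of the original code) of the pop-and-clear
# behaviour described in train_step: take the oldest n trajectories and drop them from the
# buffer, so the plain list used above could later be swapped for a thin wrapper like this.
class ListReplayBuffer(list):
    def pop_n(self, n: int) -> list:
        batch = list(self[:n])
        del self[:n]
        return batch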