4 changes: 2 additions & 2 deletions .claude/agents/algorithm-expert.md
@@ -103,7 +103,7 @@ def reward_fn(

### 5. Loss Computation

Location: `areal/engine/ppo/actor.py`
Location: `areal/trainer/ppo/actor.py`

**PPO Loss:**

@@ -154,7 +154,7 @@ print(f"Clipping rate: {clipped.float().mean():.2%}")
| File | Purpose |
| -------------------------------- | -------------------------- |
| `areal/api/cli_args.py` | PPOActorConfig, NormConfig |
| `areal/engine/ppo/actor.py` | PPO loss computation |
| `areal/trainer/ppo/actor.py` | PPO loss computation |
| `areal/workflow/rlvr.py` | Single-turn workflow |
| `areal/reward/__init__.py` | Reward function registry |
| `docs/algorithms/grpo_series.md` | Algorithm documentation |
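The loss computation this hunk relocates to `areal/trainer/ppo/actor.py` is the standard clipped PPO objective, and the clipping-rate diagnostic shown in the hunk context corresponds to the fraction of tokens where the clip is active. A minimal sketch of that objective in plain PyTorch (illustrative only, not the areal implementation; tensor names and the mask convention are assumptions):

```python
import torch


def clipped_ppo_loss(
    logprobs: torch.Tensor,      # log-probs of sampled tokens under the current policy, shape [T]
    old_logprobs: torch.Tensor,  # log-probs under the behavior (rollout) policy, shape [T]
    advantages: torch.Tensor,    # per-token advantage estimates, shape [T]
    loss_mask: torch.Tensor,     # 1.0 on trainable response tokens, 0.0 elsewhere, shape [T]
    eps_clip: float = 0.2,
) -> torch.Tensor:
    # Importance ratio between the current and behavior policies.
    ratio = torch.exp(logprobs - old_logprobs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) * advantages
    # Pessimistic (clipped) surrogate, negated so it can be minimized.
    per_token_loss = -torch.min(surr1, surr2)
    denom = loss_mask.sum().clamp(min=1)
    loss = (per_token_loss * loss_mask).sum() / denom
    # Fraction of unmasked tokens where the clipped branch is active,
    # matching the "Clipping rate" diagnostic in the doc snippet above.
    clipped = (surr2 < surr1).float() * loss_mask
    print(f"Clipping rate: {(clipped.sum() / denom).item():.2%}")
    return loss
```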
22 changes: 11 additions & 11 deletions .claude/data/pr-review-change-types.md
@@ -20,16 +20,16 @@ ______________________________________________________________________

## HIGH Level (Recommend Opus)

| Change Type | File Path Pattern | Code Pattern |
| --------------------- | ----------------------------- | -------------------------------------------------------------------------------- |
| **DISTRIBUTED_COMM** | - | `all_reduce`, `all_gather`, `reduce_scatter`, `all_to_all`, `dist.` |
| **DTENSOR** | - | `DTensor`, `DeviceMesh`, `Shard(`, `Replicate(`, `Partial(`, `distribute_tensor` |
| **MOE_LAYER** | `moe/` | `expert`, `token_dispatch`, `grouped_mm`, `MoE` |
| **EP_ETP** | - | `ExpertParallel`, `TensorParallel`, `ExpertTensorParallel`, `ep_size`, `etp` |
| **TENSOR_PARALLEL** | - | `ColwiseParallel`, `RowwiseParallel`, `parallelize_module` |
| **SEQUENCE_PARALLEL** | - | `SequenceParallel`, `context_parallel`, `Ulysses`, `cp_size` |
| **ASYNC_CONCURRENT** | - | `async def`, `await`, `asyncio`, `threading.Lock`, `aiofiles` |
| **TRAINER_CORE** | `areal/experimental/trainer/` | `PPOTrainer`, `SFTTrainer`, `trainer.train` |
| Change Type | File Path Pattern | Code Pattern |
| --------------------- | ----------------- | -------------------------------------------------------------------------------- |
| **DISTRIBUTED_COMM** | - | `all_reduce`, `all_gather`, `reduce_scatter`, `all_to_all`, `dist.` |
| **DTENSOR** | - | `DTensor`, `DeviceMesh`, `Shard(`, `Replicate(`, `Partial(`, `distribute_tensor` |
| **MOE_LAYER** | `moe/` | `expert`, `token_dispatch`, `grouped_mm`, `MoE` |
| **EP_ETP** | - | `ExpertParallel`, `TensorParallel`, `ExpertTensorParallel`, `ep_size`, `etp` |
| **TENSOR_PARALLEL** | - | `ColwiseParallel`, `RowwiseParallel`, `parallelize_module` |
| **SEQUENCE_PARALLEL** | - | `SequenceParallel`, `context_parallel`, `Ulysses`, `cp_size` |
| **ASYNC_CONCURRENT** | - | `async def`, `await`, `asyncio`, `threading.Lock`, `aiofiles` |
| **TRAINER_CORE** | `areal/trainer/` | `PPOTrainer`, `SFTTrainer`, `trainer.train` |

## MEDIUM Level (Use Sonnet)

@@ -142,7 +142,7 @@ ______________________________________________________________________

**Trainer Core**:

- `areal/experimental/trainer/`
- `areal/trainer/`

**Training Engine Core** (excludes FSDP/Megatron which have their own categories):

2 changes: 1 addition & 1 deletion .claude/hooks/check-expert-update.sh
@@ -44,7 +44,7 @@ check_expert_update() {
fi

# Algorithm related (PPO, GRPO, workflows)
if [[ "$file" == *"areal/engine/ppo/"* ]] || \
if [[ "$file" == *"areal/trainer/ppo/"* ]] || \
[[ "$file" == *"areal/workflow/"* ]] || \
[[ "$file" == *"areal/reward/"* ]]; then
reminder_file="algorithm-expert.md"
6 changes: 3 additions & 3 deletions AGENTS.md
@@ -119,11 +119,11 @@ When unsure, leave a `TODO(agent)` comment and note the constraint in your respo

## Core concepts & extension points

1. **Trainer (`areal/experimental/trainer/`)** – High-level training orchestrator. Use
`PPOTrainer` for RL training or `SFTTrainer` for supervised fine-tuning. See
1. **Trainer (`areal/trainer/`)** – High-level training orchestrator. Use `PPOTrainer`
for RL training or `SFTTrainer` for supervised fine-tuning. See
`examples/math/gsm8k_rl.py` for a complete example:
```python
from areal.experimental.trainer import PPOTrainer
from areal import PPOTrainer
with PPOTrainer(config, train_dataset, valid_dataset) as trainer:
trainer.train(workflow="areal.workflow.rlvr.RLVRWorkflow", ...)
```
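The AGENTS.md example above covers `PPOTrainer` only; `SFTTrainer` is named but not demonstrated. A parallel sketch, assuming it exposes the same context-manager interface (the constructor and `train()` arguments here are assumptions, not confirmed by this diff; `config` and the datasets are as in the snippet above):

```python
from areal import SFTTrainer

# Hypothetical usage mirroring the PPOTrainer snippet above; the exact
# SFTTrainer signature is not shown in this diff.
with SFTTrainer(config, train_dataset, valid_dataset) as trainer:
    trainer.train()
```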
3 changes: 3 additions & 0 deletions areal/__init__.py
@@ -10,6 +10,7 @@
workflow_context,
current_platform,
)
from .trainer import PPOTrainer, SFTTrainer

__all__ = [
"TrainController",
@@ -18,4 +19,6 @@
"StalenessManager",
"workflow_context",
"current_platform",
"PPOTrainer",
"SFTTrainer",
]
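With this re-export in place, downstream code can import both trainers from the package root instead of the old `areal.experimental.trainer` path; the test entrypoints later in this diff switch to exactly that form:

```python
# New import path after this change:
from areal import PPOTrainer, SFTTrainer

# Previous path, removed elsewhere in this PR:
# from areal.experimental.trainer import PPOTrainer, SFTTrainer
```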
19 changes: 9 additions & 10 deletions areal/engine/fsdp_engine.py
@@ -121,9 +121,8 @@
from areal.utils.save_load import get_state_dict_from_repo_id_or_path

if TYPE_CHECKING:
from areal.api.cli_args import PPOActorConfig, PPOCriticConfig
from areal.api.scheduler_api import Scheduler
from areal.engine.ppo.actor import PPOActorConfig
from areal.engine.ppo.critic import PPOCriticConfig


@dataclasses.dataclass
@@ -1639,7 +1638,7 @@ class FSDPPPOActor(FSDPEngine):
"""PPO Actor implementation using FSDP backend."""

def __init__(self, config: PPOActorConfig):
from areal.engine.ppo.actor import PPOActor
from areal.trainer.ppo.actor import PPOActor

super().__init__(config)
self.actor = PPOActor(config, self)
@@ -1657,7 +1656,7 @@ def ppo_update(self, *args, **kwargs) -> None:

@classmethod
def as_controller(cls, config: PPOActorConfig, scheduler: Scheduler):
from areal.engine.ppo.actor import PPOActorController
from areal.trainer.ppo.actor import PPOActorController

return PPOActorController(train_engine=cls, config=config, scheduler=scheduler)

@@ -1666,7 +1665,7 @@ class FSDPPPOCritic(FSDPEngine):
"""PPO Critic implementation using FSDP backend."""

def __init__(self, config: PPOCriticConfig):
from areal.engine.ppo.critic import PPOCritic
from areal.trainer.ppo.critic import PPOCritic

super().__init__(config)
self.critic = PPOCritic(config, self)
@@ -1680,7 +1679,7 @@ def ppo_update(self, *args, **kwargs) -> None:

@classmethod
def as_controller(cls, config: PPOCriticConfig, scheduler: Scheduler):
from areal.engine.ppo.critic import PPOCriticController
from areal.trainer.ppo.critic import PPOCriticController

return PPOCriticController(train_engine=cls, config=config, scheduler=scheduler)

@@ -1689,7 +1688,7 @@ class FSDPLMEngine(FSDPEngine):
"""Language model engine for SFT using FSDP backend."""

def __init__(self, config: TrainEngineConfig):
from areal.engine.sft.lm_engine import LMEngine
from areal.trainer.sft.lm_engine import LMEngine

super().__init__(config)
self.lm_engine = LMEngine(self)
@@ -1702,7 +1701,7 @@ def evaluate_lm(self, data):

@classmethod
def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler):
from areal.engine.sft.lm_engine import LMController
from areal.trainer.sft.lm_engine import LMController

return LMController(train_engine=cls, config=config, scheduler=scheduler)

@@ -1713,7 +1712,7 @@ class FSDPRWEngine(FSDPEngine):
def __init__(self, config: TrainEngineConfig):
from copy import deepcopy

from areal.engine.rw.rw_engine import RWEngine
from areal.trainer.rw.rw_engine import RWEngine

super().__init__(config)
self.rw_engine = RWEngine(self)
@@ -1731,6 +1730,6 @@ def evaluate_rw(self, data):

@classmethod
def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler):
from areal.engine.rw.rw_engine import RWController
from areal.trainer.rw.rw_engine import RWController

return RWController(train_engine=cls, config=config, scheduler=scheduler)
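A pattern worth noting in this file (and repeated in the Megatron and Archon engines below): the concrete `PPOActor`, `PPOCritic`, `LMEngine`, and `RWEngine` classes are imported from `areal.trainer` lazily, inside the methods that construct them, rather than at module level, presumably so that importing `areal.engine` does not eagerly pull in (or circularly depend on) `areal.trainer`. A minimal sketch of the idiom with illustrative names:

```python
class ExampleEngine:
    """Illustrative only; not an actual areal class."""

    def __init__(self, config):
        # Deferred import: areal.trainer is loaded only when an engine
        # instance is constructed, mirroring the changes in this diff.
        from areal.trainer.ppo.actor import PPOActor

        self.actor = PPOActor(config, self)
```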
19 changes: 9 additions & 10 deletions areal/engine/megatron_engine.py
@@ -107,9 +107,8 @@
from areal.utils.seeding import get_seed

if TYPE_CHECKING:
from areal.api.cli_args import PPOActorConfig, PPOCriticConfig
from areal.api.scheduler_api import Scheduler
from areal.engine.ppo.actor import PPOActorConfig
from areal.engine.ppo.critic import PPOCriticConfig


class _MegatronModelList(list):
@@ -1564,7 +1563,7 @@ class MegatronPPOActor(MegatronEngine):
"""PPO Actor implementation using Megatron backend."""

def __init__(self, config: PPOActorConfig):
from areal.engine.ppo.actor import PPOActor
from areal.trainer.ppo.actor import PPOActor

super().__init__(config)
self.actor = PPOActor(config, self)
@@ -1582,7 +1581,7 @@ def ppo_update(self, *args, **kwargs) -> None:

@classmethod
def as_controller(cls, config: PPOActorConfig, scheduler: Scheduler):
from areal.engine.ppo.actor import PPOActorController
from areal.trainer.ppo.actor import PPOActorController

return PPOActorController(train_engine=cls, config=config, scheduler=scheduler)

@@ -1591,7 +1590,7 @@ class MegatronPPOCritic(MegatronEngine):
"""PPO Critic implementation using Megatron backend."""

def __init__(self, config: PPOCriticConfig):
from areal.engine.ppo.critic import PPOCritic
from areal.trainer.ppo.critic import PPOCritic

super().__init__(config)
self.critic = PPOCritic(config, self)
@@ -1605,7 +1604,7 @@ def ppo_update(self, *args, **kwargs) -> None:

@classmethod
def as_controller(cls, config: PPOCriticConfig, scheduler: Scheduler):
from areal.engine.ppo.critic import PPOCriticController
from areal.trainer.ppo.critic import PPOCriticController

return PPOCriticController(train_engine=cls, config=config, scheduler=scheduler)

@@ -1614,7 +1613,7 @@ class MegatronLMEngine(MegatronEngine):
"""Language model engine for SFT using Megatron backend."""

def __init__(self, config: TrainEngineConfig):
from areal.engine.sft.lm_engine import LMEngine
from areal.trainer.sft.lm_engine import LMEngine

super().__init__(config)
self.lm_engine = LMEngine(self)
@@ -1627,7 +1626,7 @@ def evaluate_lm(self, data):

@classmethod
def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler):
from areal.engine.sft.lm_engine import LMController
from areal.trainer.sft.lm_engine import LMController

return LMController(train_engine=cls, config=config, scheduler=scheduler)

@@ -1638,7 +1637,7 @@ class MegatronRWEngine(MegatronEngine):
def __init__(self, config: TrainEngineConfig):
from copy import deepcopy

from areal.engine.rw.rw_engine import RWEngine
from areal.trainer.rw.rw_engine import RWEngine

super().__init__(config)
self.rw_engine = RWEngine(self)
@@ -1656,6 +1655,6 @@ def evaluate_rw(self, data):

@classmethod
def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler):
from areal.engine.rw.rw_engine import RWController
from areal.trainer.rw.rw_engine import RWController

return RWController(train_engine=cls, config=config, scheduler=scheduler)
12 changes: 6 additions & 6 deletions areal/experimental/engine/archon_engine.py
@@ -1202,7 +1202,7 @@ class ArchonPPOActor(ArchonEngine):
"""PPO Actor implementation using Archon backend."""

def __init__(self, config):
from areal.engine.ppo.actor import PPOActor
from areal.trainer.ppo.actor import PPOActor

super().__init__(config)
self.actor = PPOActor(config, self)
@@ -1220,7 +1220,7 @@ def ppo_update(self, *args, **kwargs) -> None:

@classmethod
def as_controller(cls, config, scheduler: Scheduler):
from areal.engine.ppo.actor import PPOActorController
from areal.trainer.ppo.actor import PPOActorController

return PPOActorController(train_engine=cls, config=config, scheduler=scheduler)

@@ -1229,7 +1229,7 @@ class ArchonPPOCritic(ArchonEngine):
"""PPO Critic implementation using Archon backend."""

def __init__(self, config):
from areal.engine.ppo.critic import PPOCritic
from areal.trainer.ppo.critic import PPOCritic

super().__init__(config)
self.critic = PPOCritic(config, self)
@@ -1243,7 +1243,7 @@ def ppo_update(self, *args, **kwargs) -> None:

@classmethod
def as_controller(cls, config, scheduler: Scheduler):
from areal.engine.ppo.critic import PPOCriticController
from areal.trainer.ppo.critic import PPOCriticController

return PPOCriticController(train_engine=cls, config=config, scheduler=scheduler)

@@ -1252,7 +1252,7 @@ class ArchonLMEngine(ArchonEngine):
"""Archon-based LM Engine for SFT training."""

def __init__(self, config: TrainEngineConfig):
from areal.engine.sft.lm_engine import LMEngine
from areal.trainer.sft.lm_engine import LMEngine

super().__init__(config)
self.lm_engine = LMEngine(self)
@@ -1265,6 +1265,6 @@ def evaluate_lm(self, data):

@classmethod
def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler):
from areal.engine.sft.lm_engine import LMController
from areal.trainer.sft.lm_engine import LMController

return LMController(train_engine=cls, config=config, scheduler=scheduler)
4 changes: 0 additions & 4 deletions areal/experimental/trainer/__init__.py

This file was deleted.

2 changes: 1 addition & 1 deletion areal/models/fsdp/ulysses.py
@@ -266,7 +266,7 @@ def ulysses_prepare_inputs(
continue

if value.dim() >= 2 and value.shape[:2] == padded_input_ids.shape[:2]:
# Please refer to ppo_loss_fn() in areal/engine/ppo/critic.py
# Please refer to ppo_loss_fn() in areal/trainer/ppo/critic.py
if key in {"values", "returns", "loss_mask"}:
sliced_value = slice_input_tensor(value, dim=1, padding=True)
inputs[key] = sliced_value.squeeze(0)
2 changes: 1 addition & 1 deletion areal/tests/experimental/archon/test_grpo.py
@@ -12,14 +12,14 @@
import pytest
import torch

from areal.engine.ppo.actor import grpo_loss_fn
from areal.infra.platforms import current_platform
from areal.tests.experimental.archon.utils import (
ComparisonMetrics,
DualEngineFixture,
compare_tensors,
create_grpo_batch,
)
from areal.trainer.ppo.actor import grpo_loss_fn
from areal.utils.functional import gather_logprobs_entropy

# Skip if no CUDA available
2 changes: 1 addition & 1 deletion areal/tests/grpo/entrypoint.py
@@ -4,9 +4,9 @@

import torch.distributed as dist

from areal import PPOTrainer
from areal.api.cli_args import GRPOConfig, load_expr_config
from areal.dataset import get_custom_dataset
from areal.experimental.trainer import PPOTrainer
from areal.reward.gsm8k import gsm8k_reward_fn
from areal.utils import stats_tracker
from areal.utils.hf_utils import load_hf_tokenizer
2 changes: 1 addition & 1 deletion areal/tests/sft/entrypoint.py
@@ -4,9 +4,9 @@

import torch.distributed as dist

from areal import SFTTrainer
from areal.api.cli_args import SFTConfig, load_expr_config
from areal.dataset import get_custom_dataset
from areal.experimental.trainer import SFTTrainer
from areal.utils import stats_tracker
from areal.utils.hf_utils import load_hf_tokenizer
