
Commit 4f4169a

refactor(trainer): move trainer modules from experimental to areal/trainer
Move trainer-related modules to establish a cleaner architecture:

- Move PPOTrainer and SFTTrainer from areal/experimental/trainer/ to areal/trainer/
- Move PPO actor/critic from areal/engine/ppo/ to areal/trainer/ppo/
- Move SFT lm_engine from areal/engine/sft/ to areal/trainer/sft/
- Move RW engine from areal/engine/rw/ to areal/trainer/rw/
- Export PPOTrainer and SFTTrainer from the top-level areal package

This refactoring separates training-algorithm concerns (trainer/) from backend infrastructure (engine/), making the codebase more modular. The trainers can now be imported directly via `from areal import PPOTrainer`. Also updates all imports across examples, tests, docs, and internal modules.
1 parent 07a80de commit 4f4169a
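
The user-visible change is the import path. A minimal sketch of before and after; the commented-out usage follows the AGENTS.md example further down, and `config`, `train_dataset`, and `valid_dataset` are placeholders, not objects this commit defines:

```python
# After this commit, trainers are exported from the top-level package:
from areal import PPOTrainer, SFTTrainer

# The old experimental path is gone and now raises ImportError:
#   from areal.experimental.trainer import PPOTrainer

# Usage itself is unchanged (see the AGENTS.md hunk below):
# with PPOTrainer(config, train_dataset, valid_dataset) as trainer:
#     trainer.train(workflow="areal.workflow.rlvr.RLVRWorkflow")
```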

Note: large commits have some content hidden by default, so only a subset of the 44 changed files appears below.

44 files changed: +110 additions, −299 deletions

.claude/agents/algorithm-expert.md

Lines changed: 2 additions & 2 deletions
@@ -103,7 +103,7 @@ def reward_fn(
 
 ### 5. Loss Computation
 
-Location: `areal/engine/ppo/actor.py`
+Location: `areal/trainer/ppo/actor.py`
 
 **PPO Loss:**
 
@@ -154,7 +154,7 @@ print(f"Clipping rate: {clipped.float().mean():.2%}")
 | File                             | Purpose                    |
 | -------------------------------- | -------------------------- |
 | `areal/api/cli_args.py`          | PPOActorConfig, NormConfig |
-| `areal/engine/ppo/actor.py`      | PPO loss computation       |
+| `areal/trainer/ppo/actor.py`     | PPO loss computation       |
 | `areal/workflow/rlvr.py`         | Single-turn workflow       |
 | `areal/reward/__init__.py`       | Reward function registry   |
 | `docs/algorithms/grpo_series.md` | Algorithm documentation    |

.claude/data/pr-review-change-types.md

Lines changed: 11 additions & 11 deletions
@@ -20,16 +20,16 @@ ______________________________________________________________________
 
 ## HIGH Level (Recommend Opus)
 
-| Change Type           | File Path Pattern             | Code Pattern                                                                      |
-| --------------------- | ----------------------------- | --------------------------------------------------------------------------------- |
-| **DISTRIBUTED_COMM**  | -                             | `all_reduce`, `all_gather`, `reduce_scatter`, `all_to_all`, `dist.`               |
-| **DTENSOR**           | -                             | `DTensor`, `DeviceMesh`, `Shard(`, `Replicate(`, `Partial(`, `distribute_tensor` |
-| **MOE_LAYER**         | `moe/`                        | `expert`, `token_dispatch`, `grouped_mm`, `MoE`                                   |
-| **EP_ETP**            | -                             | `ExpertParallel`, `TensorParallel`, `ExpertTensorParallel`, `ep_size`, `etp`      |
-| **TENSOR_PARALLEL**   | -                             | `ColwiseParallel`, `RowwiseParallel`, `parallelize_module`                        |
-| **SEQUENCE_PARALLEL** | -                             | `SequenceParallel`, `context_parallel`, `Ulysses`, `cp_size`                      |
-| **ASYNC_CONCURRENT**  | -                             | `async def`, `await`, `asyncio`, `threading.Lock`, `aiofiles`                     |
-| **TRAINER_CORE**      | `areal/experimental/trainer/` | `PPOTrainer`, `SFTTrainer`, `trainer.train`                                       |
+| Change Type           | File Path Pattern | Code Pattern                                                                      |
+| --------------------- | ----------------- | --------------------------------------------------------------------------------- |
+| **DISTRIBUTED_COMM**  | -                 | `all_reduce`, `all_gather`, `reduce_scatter`, `all_to_all`, `dist.`               |
+| **DTENSOR**           | -                 | `DTensor`, `DeviceMesh`, `Shard(`, `Replicate(`, `Partial(`, `distribute_tensor` |
+| **MOE_LAYER**         | `moe/`            | `expert`, `token_dispatch`, `grouped_mm`, `MoE`                                   |
+| **EP_ETP**            | -                 | `ExpertParallel`, `TensorParallel`, `ExpertTensorParallel`, `ep_size`, `etp`      |
+| **TENSOR_PARALLEL**   | -                 | `ColwiseParallel`, `RowwiseParallel`, `parallelize_module`                        |
+| **SEQUENCE_PARALLEL** | -                 | `SequenceParallel`, `context_parallel`, `Ulysses`, `cp_size`                      |
+| **ASYNC_CONCURRENT**  | -                 | `async def`, `await`, `asyncio`, `threading.Lock`, `aiofiles`                     |
+| **TRAINER_CORE**      | `areal/trainer/`  | `PPOTrainer`, `SFTTrainer`, `trainer.train`                                       |
 
 ## MEDIUM Level (Use Sonnet)
 
@@ -142,7 +142,7 @@ ______________________________________________________________________
 
 **Trainer Core**:
 
-- `areal/experimental/trainer/`
+- `areal/trainer/`
 
 **Training Engine Core** (excludes FSDP/Megatron which have their own categories):

.claude/hooks/check-expert-update.sh

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ check_expert_update() {
     fi
 
     # Algorithm related (PPO, GRPO, workflows)
-    if [[ "$file" == *"areal/engine/ppo/"* ]] || \
+    if [[ "$file" == *"areal/trainer/ppo/"* ]] || \
        [[ "$file" == *"areal/workflow/"* ]] || \
        [[ "$file" == *"areal/reward/"* ]]; then
         reminder_file="algorithm-expert.md"

AGENTS.md

Lines changed: 3 additions & 3 deletions
@@ -119,11 +119,11 @@ When unsure, leave a `TODO(agent)` comment and note the constraint in your respo
 
 ## Core concepts & extension points
 
-1. **Trainer (`areal/experimental/trainer/`)** – High-level training orchestrator. Use
-   `PPOTrainer` for RL training or `SFTTrainer` for supervised fine-tuning. See
+1. **Trainer (`areal/trainer/`)** – High-level training orchestrator. Use `PPOTrainer`
+   for RL training or `SFTTrainer` for supervised fine-tuning. See
    `examples/math/gsm8k_rl.py` for a complete example:
    ```python
-   from areal.experimental.trainer import PPOTrainer
+   from areal import PPOTrainer
    with PPOTrainer(config, train_dataset, valid_dataset) as trainer:
        trainer.train(workflow="areal.workflow.rlvr.RLVRWorkflow", ...)
    ```

areal/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,7 @@
     workflow_context,
     current_platform,
 )
+from .trainer import PPOTrainer, SFTTrainer
 
 __all__ = [
     "TrainController",
@@ -18,4 +19,6 @@
     "StalenessManager",
     "workflow_context",
     "current_platform",
+    "PPOTrainer",
+    "SFTTrainer",
 ]
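
Since `areal/__init__.py` now does `from .trainer import PPOTrainer, SFTTrainer`, the top-level names are aliases of the `areal.trainer` attributes. A small sanity-check sketch, assuming an environment with this commit installed:

```python
# Both import paths should resolve to the same class objects.
import areal
import areal.trainer

assert areal.PPOTrainer is areal.trainer.PPOTrainer
assert areal.SFTTrainer is areal.trainer.SFTTrainer
print("top-level exports resolve to areal.trainer")
```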

areal/engine/fsdp_engine.py

Lines changed: 9 additions & 10 deletions
@@ -121,9 +121,8 @@
 from areal.utils.save_load import get_state_dict_from_repo_id_or_path
 
 if TYPE_CHECKING:
+    from areal.api.cli_args import PPOActorConfig, PPOCriticConfig
     from areal.api.scheduler_api import Scheduler
-    from areal.engine.ppo.actor import PPOActorConfig
-    from areal.engine.ppo.critic import PPOCriticConfig
 
 
 @dataclasses.dataclass
@@ -1639,7 +1638,7 @@ class FSDPPPOActor(FSDPEngine):
     """PPO Actor implementation using FSDP backend."""
 
     def __init__(self, config: PPOActorConfig):
-        from areal.engine.ppo.actor import PPOActor
+        from areal.trainer.ppo.actor import PPOActor
 
         super().__init__(config)
         self.actor = PPOActor(config, self)
@@ -1657,7 +1656,7 @@ def ppo_update(self, *args, **kwargs) -> None:
 
     @classmethod
     def as_controller(cls, config: PPOActorConfig, scheduler: Scheduler):
-        from areal.engine.ppo.actor import PPOActorController
+        from areal.trainer.ppo.actor import PPOActorController
 
         return PPOActorController(train_engine=cls, config=config, scheduler=scheduler)
 
@@ -1666,7 +1665,7 @@ class FSDPPPOCritic(FSDPEngine):
     """PPO Critic implementation using FSDP backend."""
 
     def __init__(self, config: PPOCriticConfig):
-        from areal.engine.ppo.critic import PPOCritic
+        from areal.trainer.ppo.critic import PPOCritic
 
         super().__init__(config)
         self.critic = PPOCritic(config, self)
@@ -1680,7 +1679,7 @@ def ppo_update(self, *args, **kwargs) -> None:
 
     @classmethod
     def as_controller(cls, config: PPOCriticConfig, scheduler: Scheduler):
-        from areal.engine.ppo.critic import PPOCriticController
+        from areal.trainer.ppo.critic import PPOCriticController
 
         return PPOCriticController(train_engine=cls, config=config, scheduler=scheduler)
 
@@ -1689,7 +1688,7 @@ class FSDPLMEngine(FSDPEngine):
     """Language model engine for SFT using FSDP backend."""
 
     def __init__(self, config: TrainEngineConfig):
-        from areal.engine.sft.lm_engine import LMEngine
+        from areal.trainer.sft.lm_engine import LMEngine
 
         super().__init__(config)
         self.lm_engine = LMEngine(self)
@@ -1702,7 +1701,7 @@ def evaluate_lm(self, data):
 
     @classmethod
     def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler):
-        from areal.engine.sft.lm_engine import LMController
+        from areal.trainer.sft.lm_engine import LMController
 
         return LMController(train_engine=cls, config=config, scheduler=scheduler)
 
@@ -1713,7 +1712,7 @@ class FSDPRWEngine(FSDPEngine):
     def __init__(self, config: TrainEngineConfig):
         from copy import deepcopy
 
-        from areal.engine.rw.rw_engine import RWEngine
+        from areal.trainer.rw.rw_engine import RWEngine
 
         super().__init__(config)
         self.rw_engine = RWEngine(self)
@@ -1731,6 +1730,6 @@ def evaluate_rw(self, data):
 
     @classmethod
     def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler):
-        from areal.engine.rw.rw_engine import RWController
+        from areal.trainer.rw.rw_engine import RWController
 
         return RWController(train_engine=cls, config=config, scheduler=scheduler)
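
Every hunk in this file (and in the analogous megatron_engine.py and archon_engine.py hunks below) applies the same two-part pattern: config types move into the `TYPE_CHECKING` block, and the trainer classes are imported lazily inside the method that needs them, so `areal/engine/` never imports `areal/trainer/` at module load time, presumably to avoid an import cycle between the two packages. A generic sketch of the pattern, with illustrative module names rather than AReaL's real ones:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Resolved only by static type checkers, never executed at runtime,
    # so this import cannot participate in a circular-import cycle.
    from mypkg.config import ActorConfig


class BackendEngine:
    def __init__(self, config: ActorConfig):
        # Deferred import: mypkg.trainer may import this module in turn,
        # so resolving it at call time rather than at module import time
        # breaks the cycle.
        from mypkg.trainer import Actor

        self.actor = Actor(config, self)
```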

areal/engine/megatron_engine.py

Lines changed: 9 additions & 10 deletions
@@ -107,9 +107,8 @@
 from areal.utils.seeding import get_seed
 
 if TYPE_CHECKING:
+    from areal.api.cli_args import PPOActorConfig, PPOCriticConfig
     from areal.api.scheduler_api import Scheduler
-    from areal.engine.ppo.actor import PPOActorConfig
-    from areal.engine.ppo.critic import PPOCriticConfig
 
 
 class _MegatronModelList(list):
@@ -1564,7 +1563,7 @@ class MegatronPPOActor(MegatronEngine):
     """PPO Actor implementation using Megatron backend."""
 
     def __init__(self, config: PPOActorConfig):
-        from areal.engine.ppo.actor import PPOActor
+        from areal.trainer.ppo.actor import PPOActor
 
         super().__init__(config)
         self.actor = PPOActor(config, self)
@@ -1582,7 +1581,7 @@ def ppo_update(self, *args, **kwargs) -> None:
 
     @classmethod
     def as_controller(cls, config: PPOActorConfig, scheduler: Scheduler):
-        from areal.engine.ppo.actor import PPOActorController
+        from areal.trainer.ppo.actor import PPOActorController
 
         return PPOActorController(train_engine=cls, config=config, scheduler=scheduler)
 
@@ -1591,7 +1590,7 @@ class MegatronPPOCritic(MegatronEngine):
     """PPO Critic implementation using Megatron backend."""
 
     def __init__(self, config: PPOCriticConfig):
-        from areal.engine.ppo.critic import PPOCritic
+        from areal.trainer.ppo.critic import PPOCritic
 
         super().__init__(config)
         self.critic = PPOCritic(config, self)
@@ -1605,7 +1604,7 @@ def ppo_update(self, *args, **kwargs) -> None:
 
     @classmethod
     def as_controller(cls, config: PPOCriticConfig, scheduler: Scheduler):
-        from areal.engine.ppo.critic import PPOCriticController
+        from areal.trainer.ppo.critic import PPOCriticController
 
         return PPOCriticController(train_engine=cls, config=config, scheduler=scheduler)
 
@@ -1614,7 +1613,7 @@ class MegatronLMEngine(MegatronEngine):
     """Language model engine for SFT using Megatron backend."""
 
     def __init__(self, config: TrainEngineConfig):
-        from areal.engine.sft.lm_engine import LMEngine
+        from areal.trainer.sft.lm_engine import LMEngine
 
         super().__init__(config)
         self.lm_engine = LMEngine(self)
@@ -1627,7 +1626,7 @@ def evaluate_lm(self, data):
 
     @classmethod
     def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler):
-        from areal.engine.sft.lm_engine import LMController
+        from areal.trainer.sft.lm_engine import LMController
 
         return LMController(train_engine=cls, config=config, scheduler=scheduler)
 
@@ -1638,7 +1637,7 @@ class MegatronRWEngine(MegatronEngine):
     def __init__(self, config: TrainEngineConfig):
         from copy import deepcopy
 
-        from areal.engine.rw.rw_engine import RWEngine
+        from areal.trainer.rw.rw_engine import RWEngine
 
         super().__init__(config)
         self.rw_engine = RWEngine(self)
@@ -1656,6 +1655,6 @@ def evaluate_rw(self, data):
 
     @classmethod
     def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler):
-        from areal.engine.rw.rw_engine import RWController
+        from areal.trainer.rw.rw_engine import RWController
 
         return RWController(train_engine=cls, config=config, scheduler=scheduler)

areal/experimental/engine/archon_engine.py

Lines changed: 6 additions & 6 deletions
@@ -1202,7 +1202,7 @@ class ArchonPPOActor(ArchonEngine):
     """PPO Actor implementation using Archon backend."""
 
     def __init__(self, config):
-        from areal.engine.ppo.actor import PPOActor
+        from areal.trainer.ppo.actor import PPOActor
 
         super().__init__(config)
         self.actor = PPOActor(config, self)
@@ -1220,7 +1220,7 @@ def ppo_update(self, *args, **kwargs) -> None:
 
     @classmethod
     def as_controller(cls, config, scheduler: Scheduler):
-        from areal.engine.ppo.actor import PPOActorController
+        from areal.trainer.ppo.actor import PPOActorController
 
         return PPOActorController(train_engine=cls, config=config, scheduler=scheduler)
 
@@ -1229,7 +1229,7 @@ class ArchonPPOCritic(ArchonEngine):
     """PPO Critic implementation using Archon backend."""
 
     def __init__(self, config):
-        from areal.engine.ppo.critic import PPOCritic
+        from areal.trainer.ppo.critic import PPOCritic
 
         super().__init__(config)
         self.critic = PPOCritic(config, self)
@@ -1243,7 +1243,7 @@ def ppo_update(self, *args, **kwargs) -> None:
 
     @classmethod
     def as_controller(cls, config, scheduler: Scheduler):
-        from areal.engine.ppo.critic import PPOCriticController
+        from areal.trainer.ppo.critic import PPOCriticController
 
         return PPOCriticController(train_engine=cls, config=config, scheduler=scheduler)
 
@@ -1252,7 +1252,7 @@ class ArchonLMEngine(ArchonEngine):
     """Archon-based LM Engine for SFT training."""
 
     def __init__(self, config: TrainEngineConfig):
-        from areal.engine.sft.lm_engine import LMEngine
+        from areal.trainer.sft.lm_engine import LMEngine
 
         super().__init__(config)
         self.lm_engine = LMEngine(self)
@@ -1265,6 +1265,6 @@ def evaluate_lm(self, data):
 
     @classmethod
     def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler):
-        from areal.engine.sft.lm_engine import LMController
+        from areal.trainer.sft.lm_engine import LMController
 
         return LMController(train_engine=cls, config=config, scheduler=scheduler)

areal/experimental/trainer/__init__.py

Lines changed: 0 additions & 4 deletions
This file was deleted.

areal/models/fsdp/ulysses.py

Lines changed: 1 addition & 1 deletion
@@ -266,7 +266,7 @@ def ulysses_prepare_inputs(
             continue
 
         if value.dim() >= 2 and value.shape[:2] == padded_input_ids.shape[:2]:
-            # Please refer to ppo_loss_fn() in areal/engine/ppo/critic.py
+            # Please refer to ppo_loss_fn() in areal/trainer/ppo/critic.py
             if key in {"values", "returns", "loss_mask"}:
                 sliced_value = slice_input_tensor(value, dim=1, padding=True)
                 inputs[key] = sliced_value.squeeze(0)
