import os
import torch

from xtuner.v1.config import (
    AdamWConfig,
    FSDPConfig,
    LRConfig,
)
from xtuner.v1.datasets import FTDPTokenizeFnConfig
from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
from xtuner.v1.loss.ce_loss import CELossConfig
from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config
from xtuner.v1.train import TrainerConfig

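# Checkpoint and dataset locations come from the environment so the same
# script runs unchanged across machines; os.environ[...] fails fast with a
# KeyError if either variable is unset.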
QWEN3_MOE_PATH = os.environ["QWEN3_MOE_PATH"]
ALPACA_PATH = os.environ["ALPACA_PATH"]

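# Qwen3-30B-A3B MoE trained with AdamW and a cosine schedule decaying to
# lr_min. FSDP shards parameters across ranks; ep_size is read from the
# model config so expert parallelism matches the architecture, tp_size=4
# adds tensor parallelism, and recompute_ratio=0.25 recomputes activations
# for a fraction of the layers, trading compute for memory headroom.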
moe_cfg = Qwen3MoE30BA3Config()
optim_cfg = AdamWConfig(lr=6e-05)
lr_cfg = LRConfig(lr_type="cosine", lr_min=1e-6)
fsdp_cfg = FSDPConfig(
    torch_compile=False,
    cpu_offload=False,
    ep_size=moe_cfg.ep_size,
    tp_size=4,
    recompute_ratio=0.25,
)

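# Each entry pairs a dataset with its tokenize function; Alpaca is used in
# full (sample_ratio=1.0) and individual samples are capped at 16384 tokens.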
dataset_config = [
    {
        "dataset": DatasetConfig(name="alpaca", anno_path=ALPACA_PATH, sample_ratio=1.0),
        "tokenize_fn": FTDPTokenizeFnConfig(max_length=16384),
    },
]

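# Pack multiple samples into fixed-length 16384-token sequences; note that
# pack_max_length matches the per-sample max_length above.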
dataloader_config = DataloaderConfig(pack_max_length=16384)

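# Chunked cross-entropy evaluates the loss over 1024-token slices instead of
# materializing logits for the whole packed sequence at once, which lowers
# peak memory.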
loss_cfg = CELossConfig(mode="chunk", chunk_size=1024)

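# The trainer config ties everything together: sp_size=4 enables sequence
# parallelism over four ranks, the checkpoint directory doubles as the
# tokenizer path, and dist_backend="npu:hccl" targets Ascend NPUs via HCCL.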
trainer = TrainerConfig(
    load_from=QWEN3_MOE_PATH,
    model_cfg=moe_cfg,
    optim_cfg=optim_cfg,
    fsdp_cfg=fsdp_cfg,
    sp_size=4,
    dataset_cfg=dataset_config,
    dataloader_cfg=dataloader_config,
    lr_cfg=lr_cfg,
    loss_cfg=loss_cfg,
    tokenizer_path=QWEN3_MOE_PATH,
    global_batch_size=32,
    total_epoch=1,
    work_dir=f"/mnt/hwfile/vc-intern-delivery/qa-llm-cicd/test_output/{os.environ['GITHUB_RUN_ID']}/npu-qwen3-sft-recompute/sft",
    seed=0,
    dist_backend="npu:hccl",
)
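# This module only builds the TrainerConfig; the actual run is launched by a
# separate driver that imports `trainer`. A minimal sketch of such a driver,
# assuming a hypothetical Trainer entry point (not confirmed by this script):
#
#     from xtuner.v1.train import Trainer  # assumed import
#     Trainer.from_config(trainer).fit()   # hypothetical API; check your version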