diff --git a/examples/llama3/conf/train/1b_lora.yaml b/examples/llama3/conf/train/1b_lora.yaml new file mode 100644 index 0000000000..6194bee97a --- /dev/null +++ b/examples/llama3/conf/train/1b_lora.yaml @@ -0,0 +1,99 @@ +system: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + context_parallel_size: 2 + disable_bias_linear: True + use_flash_attn: True + sequence_parallel: False + use_distributed_optimizer: True + precision: + bf16: True + attention_softmax_in_fp32: true + accumulate_allreduce_grads_in_fp32: false + num_layers_at_start_in_bf16: 0 + num_layers_at_end_in_bf16: 0 + logging: + log_interval: 1 + tensorboard_log_interval: 1 + wandb_project: ${experiment.exp_name} + wandb_exp_name: ${experiment.exp_name} + checkpoint: + load: xxx + load_adapter: false + ckpt_format: nemo_zarr + auto_detect_ckpt_format: False + save_interval: 1 + finetune: True + + +model: + peft_type: lora + lora_target_modules: ["linear_proj", "linear_qkv"] + lora_dim: 16 + lora_alpha: 32 + lora_dropout: 0.1 + lora_dropout_position: pre + lora_in_init_method: kaiming + lora_out_init_method: zero + + no_gradient_accumulation_fusion: True + use_mcore_models: True + transformer_impl: transformer_engine + num_layers: 16 + hidden_size: 2048 + ffn_hidden_size: 8192 + num_attention_heads: 32 + seq_length: 8192 + group_query_attention: True + num_query_groups: 8 + max_position_embeddings: 8192 + + norm_epsilon: 1e-5 + use_rotary_position_embeddings: True + no_position_embedding: True + swiglu: True + normalization: RMSNorm + position_embedding_type: rope + use_rope_scaling: True + rope_scaling_factor: 32.0 + rotary_base: 500000 + untie_embeddings_and_output_weights: False + init_method_std: 0.02 + attention_dropout: 0.0 + hidden_dropout: 0.0 + clip_grad: 0.3 + + train_iters: 10000 + eval_iters: 100 + eval_interval: 1000 + micro_batch_size: 1 + global_batch_size: 2 + + no_load_optim: True + no_load_rng: True + optimizer: + weight_decay: 1e-4 + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1e-08 + main_grads_dtype: bf16 + main_params_dtype: fp16 + use_distributed_optimizer: True + use_precision_aware_optimizer: True + lr_scheduler: + lr: 0.0004 + min_lr: 0 + lr_decay_style: cosine + seed: 1234 + + +data: + data_path: xxx + dataloader_type: external + split: 1 + num_workers: 1 + tokenizer: + tokenizer_type: Llama3TokenizerFS + tokenizer_path: xxx + vocab_size: 128256 + make_vocab_size_divisible_by: 64 diff --git a/examples/llama3/conf/train/70b_lora.yaml b/examples/llama3/conf/train/70b_lora.yaml new file mode 100644 index 0000000000..2148422bcc --- /dev/null +++ b/examples/llama3/conf/train/70b_lora.yaml @@ -0,0 +1,93 @@ +system: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + context_parallel_size: 1 + disable_bias_linear: True + use_flash_attn: True + sequence_parallel: True + use_distributed_optimizer: True + precision: + bf16: True + attention_softmax_in_fp32: true + accumulate_allreduce_grads_in_fp32: false + num_layers_at_start_in_bf16: 0 + num_layers_at_end_in_bf16: 0 + logging: + log_interval: 1 + tensorboard_log_interval: 1 + wandb_project: ${experiment.exp_name} + wandb_exp_name: ${experiment.exp_name} + checkpoint: + load: xxx + load_adapter: false + ckpt_format: nemo_zarr + save_interval: 1 + auto_detect_ckpt_format: False + save_interval: 20 + finetune: True + +model: + peft_type: lora + lora_target_modules: ["linear_proj", "linear_qkv"] + lora_dim: 16 + lora_alpha: 32 + lora_dropout: 0.1 + lora_dropout_position: pre + lora_in_init_method: kaiming + 
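+  # Illustrative note, assuming the standard LoRA formulation (added annotation, not from the original config):
+  # with rank r = lora_dim and scaling s = lora_alpha / r, each wrapped linear computes roughly
+  # y = W0 * x + s * B(A(dropout(x))); A ("in") is Kaiming-initialized and B ("out") starts at zero,
+  # so training begins from the frozen base output. "pre" means dropout is applied to the adapter input.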
lora_out_init_method: zero + + no_gradient_accumulation_fusion: True + use_mcore_models: True + transformer_impl: transformer_engine + num_layers: 80 + hidden_size: 8192 + ffn_hidden_size: 28672 + num_attention_heads: 64 + seq_length: 8192 + group_query_attention: True + num_query_groups: 8 + max_position_embeddings: 8192 + + norm_epsilon: 1e-5 + use_rotary_position_embeddings: True + no_position_embedding: True + swiglu: True + normalization: RMSNorm + position_embedding_type: rope + rotary_base: 500000 + untie_embeddings_and_output_weights: True + init_method_std: 0.02 + attention_dropout: 0.0 + hidden_dropout: 0.0 + clip_grad: 0.3 + + train_samples: 10000 + eval_iters: 100 + eval_interval: 1000 + micro_batch_size: 1 + global_batch_size: 2 + + optimizer: + weight_decay: 1e-4 + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1e-08 + main_grads_dtype: bf16 + main_params_dtype: fp16 + use_distributed_optimizer: True + use_precision_aware_optimizer: True + lr_scheduler: + lr: 0.0004 + min_lr: 0 + lr_decay_style: cosine + seed: 1234 + +data: + data_path: xxx + dataloader_type: external + split: 1 + tokenizer: + tokenizer_type: Llama3TokenizerFS + tokenizer_path: xxx + vocab_size: 128256 + make_vocab_size_divisible_by: 64 diff --git a/examples/llama3/conf/train_nemo_llama.yaml b/examples/llama3/conf/train_nemo_llama.yaml new file mode 100644 index 0000000000..1ab89e9c62 --- /dev/null +++ b/examples/llama3/conf/train_nemo_llama.yaml @@ -0,0 +1,29 @@ +defaults: + - train: 1b_lora + - _self_ + +experiment: + exp_name: llama3 + exp_dir: ./outputs_llama3_1b_lora + task: + type: train + backend: megatron + entrypoint: ./flagscale/train/train_nemo_llama.py + runner: + backend: torchrun + nnodes: 1 + nproc_per_node: 2 + hostfile: null + envs: + CUDA_VISIBLE_DEVICES: 0,1 + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_APPLY_QK_LAYER_SCALING: 0 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + MKL_SERVICE_FORCE_INTEL: 1 + CUBLAS_WORKSPACE_CONFIG: :4096:8 + NCCL_ALGO: Ring +action: run + +hydra: + run: + dir: ${experiment.exp_dir}/hydra diff --git a/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch b/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch index 9ce6fcab00..5020f3ae1a 100644 --- a/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch +++ b/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch @@ -1,5 +1,5 @@ diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py -index 1120c7529..0a33c32e9 100644 +index 1120c7529..a2c5b17c9 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -67,6 +67,7 @@ def add_megatron_arguments(parser: argparse.ArgumentParser): @@ -22,6 +22,15 @@ index 1120c7529..0a33c32e9 100644 # Custom arguments. if extra_args_provider is not None: +@@ -162,7 +168,7 @@ def validate_model_config_args_from_heterogeneous_config(args): + ) + + n_kv_heads_in_group = [ +- config["attention"]["n_heads_in_group"] for config in hf_config_dict.block_configs ++ config["attention"]["n_heads_in_group"] for config in hf_config_dict.block_configs + if config["attention"]["n_heads_in_group"] is not None + ] + assert all(num == n_kv_heads_in_group[0] for num in n_kv_heads_in_group), "num query head must be consistent across all layers" @@ -368,63 +374,68 @@ def validate_args(args, defaults={}): "legacy model format only supports the 'torch' checkpoint format." 
update_use_dist_ckpt(args) @@ -35,9 +44,6 @@ index 1120c7529..0a33c32e9 100644 - if args.attention_backend == AttnBackend.local: assert args.spec[0] == 'local' , '--attention-backend local is only supported with --spec local' -+ -+ if not args.enable_hetero: -+ total_model_size = args.tensor_model_parallel_size * args.pipeline_model_parallel_size * args.context_parallel_size - # Pipeline model parallel size. - args.transformer_pipeline_model_parallel_size = args.pipeline_model_parallel_size @@ -49,7 +55,9 @@ index 1120c7529..0a33c32e9 100644 - if args.perform_rl_step: - assert not (args.rl_remove_kv_cache_during_training and args.rl_offload_kv_cache_during_training), \ - "Cannot use both remove-kv-cache-during-training and offload-kv-cache-during-training" -- ++ if not args.enable_hetero: ++ total_model_size = args.tensor_model_parallel_size * args.pipeline_model_parallel_size * args.context_parallel_size + - assert not (args.rl_partial_rollouts and args.rl_remove_kv_cache_during_training), \ - "Cannot use both partial-rollouts and remove-kv-cache-during-training" - @@ -161,18 +169,38 @@ index 1120c7529..0a33c32e9 100644 if args.num_virtual_stages_per_pipeline_rank is None: assert args.decoder_first_pipeline_num_layers is None and args.decoder_last_pipeline_num_layers is None, \ 'please use --num-virtual-stages-per-pipeline-rank to specify virtual pipeline parallel degree when enable uneven pipeline parallelism' -@@ -571,8 +583,9 @@ def validate_args(args, defaults={}): +@@ -571,9 +583,10 @@ def validate_args(args, defaults={}): if args.account_for_loss_in_pipeline_split: num_layers += 1 - assert num_layers % args.transformer_pipeline_model_parallel_size == 0, \ - 'Number of layers should be divisible by the pipeline-model-parallel size' +- + if args.enable_hetero is False: + assert num_layers % args.transformer_pipeline_model_parallel_size == 0, \ + 'Number of layers should be divisible by the pipeline-model-parallel size' - ++ if args.virtual_pipeline_model_parallel_size is not None: if args.overlap_p2p_comm: + assert args.pipeline_model_parallel_size > 1, \ +@@ -633,7 +646,7 @@ def validate_args(args, defaults={}): + args.rank, + ) + if args.fp4_param and not is_te_min_version("2.7.0.dev0"): +- raise ValueError("--fp4-param requires Transformer Engine >= 2.7.0.dev0.") ++ raise ValueError("--fp4-param requires Transformer Engine >= 2.7.0.dev0.") + + if args.overlap_param_gather_with_optimizer_step: + assert args.use_distributed_optimizer, \ +@@ -666,7 +679,7 @@ def validate_args(args, defaults={}): + # FP4 param requires FP4 mode + if args.fp4_param and not args.fp4: + raise ValueError("--fp4-param-gather must be used together with --fp4-format.") +- ++ + # FP4 requires TE >= 2.7.0.dev0 + if args.fp4 and not is_te_min_version("2.7.0.dev0"): + raise ValueError("--fp4-format requires Transformer Engine >= 2.7.0.dev0 for NVFP4BlockScaling support.") @@ -796,12 +809,22 @@ def validate_args(args, defaults={}): # Checks. if args.ffn_hidden_size is None: @@ -202,7 +230,7 @@ index 1120c7529..0a33c32e9 100644 else: args.ffn_hidden_size = 4 * args.hidden_size -@@ -1175,6 +1198,141 @@ def validate_args(args, defaults={}): +@@ -1175,6 +1198,143 @@ def validate_args(args, defaults={}): args.recompute_granularity != 'full' ), 'recompute_granularity must not be full when CUDA Graphs are enabled.' 
@@ -273,7 +301,8 @@ index 1120c7529..0a33c32e9 100644 + 'PEFT is incompatible with moe_shared_expert_overlap' + assert args.num_experts is None, "PEFT is not tested with MoE currently" + assert args.recompute_method is None and args.recompute_granularity is None and args.recompute_num_layers is None, "PEFT will raise comfilcts with recompute currently" -+ assert args.ckpt_format == 'torch', "PEFT is only tested with torch format checkpoint" ++ assert args.ckpt_format == 'torch' or args.ckpt_format == 'nemo_zarr', "PEFT is tested with torch format checkpoint and nemo_zarr format checkpoint" ++ + + # DualPipeV related + if args.use_dualpipev: @@ -339,12 +368,31 @@ index 1120c7529..0a33c32e9 100644 + 'PEFT is only supported with transformer_engine implementation' + assert args.num_experts is None, "PEFT is not tested with MoE currently" + assert args.recompute_method is None and args.recompute_granularity is None and args.recompute_num_layers is None, "PEFT will raise comfilcts with recompute currently" -+ assert args.ckpt_format == 'torch', "PEFT is only tested with torch format checkpoint" ++ assert args.ckpt_format == 'torch' or args.ckpt_format == "nemo_zarr", "PEFT is tested with torch format checkpoint and nemo_zarr format checkpoint" ++ + # Print arguments. _print_args("arguments", args) -@@ -1585,6 +1743,8 @@ def _add_network_size_args(parser): +@@ -1207,7 +1367,7 @@ def core_transformer_config_from_args(args, config_class=None): + + if args.multi_latent_attention: + config_class = MLATransformerConfig +- ++ + if args.heterogeneous_layers_config_path is not None: + assert not args.multi_latent_attention, "Multi latent attention with heterogeneous layers is not supported." + config_class = HeterogeneousTransformerConfig +@@ -1320,7 +1480,7 @@ def _add_transformer_engine_args(parser): + help='Number of layers at start to construct in bf16 when --first-last-layers-bf16 is enabled.') + group.add_argument('--num-layers-at-end-in-bf16', type=int, default=1, + help='Number of layers at end to construct in bf16 when --first-last-layers-bf16 is enabled.') +- ++ + # FP4 related arguments + group.add_argument('--fp4-format', default=None, + choices=['e2m1'], +@@ -1585,6 +1745,8 @@ def _add_network_size_args(parser): help='Which normalization technique to use.') group.add_argument('--norm-epsilon', type=float, default=1e-5, help='Epsilon for layer norm and RMS norm.') @@ -353,7 +401,7 @@ index 1120c7529..0a33c32e9 100644 group.add_argument('--apply-layernorm-1p', action='store_true', help='Adjust LayerNorm weights such that they are centered ' 'around zero. This improves numerical stability.') -@@ -1608,6 +1768,10 @@ def _add_network_size_args(parser): +@@ -1608,6 +1770,10 @@ def _add_network_size_args(parser): group.add_argument('--glu-linear-offset', type=float, default=0.0, help='Offset term in the GLU activation function: activation_func(x[0]) * (x[1] + offset). ' 'Only used when gated_linear_unit is True') @@ -364,7 +412,13 @@ index 1120c7529..0a33c32e9 100644 group.add_argument('--onnx-safe', type=bool, required=False, help='Use workarounds for known problems with ' 'Torch ONNX exporter') -@@ -1820,6 +1984,14 @@ def _add_logging_args(parser): +@@ -1815,11 +1981,19 @@ def _add_logging_args(parser): + help='The wandb entity name. It is useful when ' + 'there are multiple sub-projects in a project. 
' + 'https://community.wandb.ai/t/how-do-i-decide-which-account-private-or-team-to-upload-the-run-to/5704 ' +- 'Ignore wandb by default.') ++ 'Ignore wandb by default.') + group.add_argument('--wandb-exp-name', type=str, default='', help='The wandb experiment name.') group.add_argument('--wandb-save-dir', type=str, default='', help='Path to save the wandb results locally.') @@ -379,7 +433,7 @@ index 1120c7529..0a33c32e9 100644 group.add_argument('--logging-level', type=int, default=None, help='Set default logging level') return parser -@@ -1854,6 +2026,15 @@ def _add_regularization_args(parser): +@@ -1854,6 +2028,15 @@ def _add_regularization_args(parser): 'numerical stability') group.add_argument('--sgd-momentum', type=float, default=0.9, help='Momentum factor for sgd') @@ -395,7 +449,16 @@ index 1120c7529..0a33c32e9 100644 return parser -@@ -2001,6 +2182,25 @@ def _add_training_args(parser): +@@ -1863,7 +2046,7 @@ def _add_rl_args(parser): + help="Use the RL training step.") + group.add_argument('--rl-prompts-per-eval', type=int, default=32, + help='Number of prompts to evaluate for for each RL task.' +- 'This evaluation can be very expensive when using environments' ++ 'This evaluation can be very expensive when using environments' + 'that evaluate pass@k so we default to a lower number.') + # TODO(rkirby): allow for "complete" evaluation when --rl-prompts-per-eval is set to -1 + group.add_argument('--grpo-prompts-per-step', type=int, default=32, +@@ -2001,6 +2184,25 @@ def _add_training_args(parser): '"shared_experts": recompute the shared experts in the MoE layer.' '"moe_act", "layernorm", and "mla_up_proj" use output-discarding checkpointing, ' '"core_attn", "mlp", "moe", and "shared_experts" use normal checkpointing.') @@ -421,7 +484,7 @@ index 1120c7529..0a33c32e9 100644 group.add_argument('--no-clone-scatter-output-in-embedding', action='store_false', help='If not set, clone the output of the scatter in embedding layer to GC original tensor.', dest='clone_scatter_output_in_embedding') -@@ -2087,6 +2287,10 @@ def _add_training_args(parser): +@@ -2087,6 +2289,10 @@ def _add_training_args(parser): help='Total number of samples to train over all ' 'training runs. 
Note that either train-iters or ' 'train-samples should be provided.') @@ -432,7 +495,7 @@ index 1120c7529..0a33c32e9 100644 group.add_argument('--log-interval', type=int, default=100, help='Report loss and timing interval.') group.add_argument('--exit-interval', type=int, default=None, -@@ -2140,7 +2344,7 @@ def _add_training_args(parser): +@@ -2140,7 +2346,7 @@ def _add_training_args(parser): help='Enable bias only in the QKV linear layers', dest='add_qkv_bias') group.add_argument('--optimizer', type=str, default='adam', @@ -441,7 +504,7 @@ index 1120c7529..0a33c32e9 100644 help='Optimizer function') group.add_argument('--optimizer-cpu-offload', action='store_true', help='Offload optimizer state to CPU') -@@ -2210,6 +2414,10 @@ def _add_training_args(parser): +@@ -2210,6 +2416,10 @@ def _add_training_args(parser): help='The communicator group names to use high priority streams.') group.add_argument('--use-te-activation-func', action='store_true', help='Use activation function kernel from Transformer Engine in MLP module.') @@ -452,7 +515,7 @@ index 1120c7529..0a33c32e9 100644 return parser -@@ -2268,11 +2476,26 @@ def _add_learning_rate_args(parser): +@@ -2268,11 +2478,26 @@ def _add_learning_rate_args(parser): 'and initial warmup, the learning rate at each ' 'iteration would be different.') group.add_argument('--lr-decay-style', type=str, default='linear', @@ -480,7 +543,7 @@ index 1120c7529..0a33c32e9 100644 group.add_argument('--lr-decay-iters', type=int, default=None, help='number of iterations to decay learning rate over,' ' If None defaults to `--train-iters`') -@@ -2331,6 +2554,8 @@ def _add_checkpointing_args(parser): +@@ -2331,6 +2556,8 @@ def _add_checkpointing_args(parser): group.add_argument('--save-retain-interval', type=int, default=None, help='Number of iterations between retained checkpoints (other' 'checkpoints _except the last checkpoint_ are automatically deleted).') @@ -489,7 +552,7 @@ index 1120c7529..0a33c32e9 100644 group.add_argument('--no-save-optim', action='store_true', default=None, help='Do not save current optimizer.') group.add_argument('--no-save-rng', action='store_true', default=None, -@@ -2380,6 +2605,8 @@ def _add_checkpointing_args(parser): +@@ -2380,6 +2607,8 @@ def _add_checkpointing_args(parser): group.add_argument('--no-use-tokenizer-model-from-checkpoint-args', action='store_false', dest='use_tokenizer_model_from_checkpoint_args', help='If set, do not use tokenizer model path from checkpoint') @@ -498,7 +561,25 @@ index 1120c7529..0a33c32e9 100644 group.add_argument('--exit-on-missing-checkpoint', action='store_true', help="If '--load' is set, but checkpoint is not found " "(e.g., path typo), then exit instead of random " -@@ -2541,7 +2768,7 @@ def _add_distributed_args(parser): +@@ -2398,7 +2627,7 @@ def _add_checkpointing_args(parser): + dest='dist_ckpt_format_deprecated', + help='Deprecated: see --ckpt-format.') + group.add_argument('--ckpt-format', default='torch_dist', +- choices=['torch', 'torch_dist', 'zarr', 'torch_dcp', 'fsdp_dtensor'], ++ choices=['torch', 'torch_dist', 'zarr', 'torch_dcp', 'nemo_zarr', 'fsdp_dtensor'], + help='Checkpoint format to use. torch is the format used by torch.save/load.' + ' torch_dist is a megatron built-in distributed checkpointing format.' + ' torch_dcp is the torch.distributed.checkpoint format.' +@@ -2455,6 +2684,8 @@ def _add_checkpointing_args(parser): + group.add_argument('--load-model-opt-format', action='store_true', + help='Load a checkpoint for TensorRT model optimizer (nvidia-modelopt).' 
+ 'This function can also be used to load NeMo .nemo sharded checkpoints.') ++ group.add_argument('--load-adapter', action='store_true', default=False, ++ help='Load a checkpoint for Lora.') + return parser + + +@@ -2541,7 +2772,7 @@ def _add_distributed_args(parser): default=False, help='if set, overlap pipeline parallel communication in warmup and flush', dest='overlap_p2p_comm_warmup_flush') group.add_argument('--distributed-backend', default='nccl', @@ -507,7 +588,7 @@ index 1120c7529..0a33c32e9 100644 help='Which backend to use for distributed training.') group.add_argument('--distributed-timeout-minutes', type=int, default=10, help='Timeout minutes for torch.distributed.') -@@ -2592,6 +2819,11 @@ def _add_distributed_args(parser): +@@ -2592,6 +2823,11 @@ def _add_distributed_args(parser): 'complete it instead. Also turns on ' '--use-cpu-initialization flag. This is for ' 'external DDP manager.' ) @@ -519,18 +600,35 @@ index 1120c7529..0a33c32e9 100644 group.add_argument('--account-for-embedding-in-pipeline-split', action='store_true', default=False, help='If set, *input* embedding layer will be treated as a standard transformer' 'layer in the context of partition and placement for pipeline parallelism.') -@@ -2636,6 +2868,10 @@ def _add_distributed_args(parser): +@@ -2600,14 +2836,14 @@ def _add_distributed_args(parser): + 'layer in the context of partition and placement for pipeline parallelism.') + group.add_argument('--use-distributed-optimizer', action='store_true', + help='Use distributed optimizer.') +- group.add_argument('--use-nccl-ub', action='store_true', dest='nccl_ub', ++ group.add_argument('--use-nccl-ub', action='store_true', dest='nccl_ub', + help='Use the userbuffer registration for DP/FSDP communication buffers.' + 'This option will reduce GPU SM usage for the DP/FSDP communication,' + 'which is improving the performance of the overlapped computation.') + group.add_argument('--disable-symmetric-registration', action='store_true', dest='disable_symmetric_registration', + default=False, help='Disable symmetric (window) registration for NCCL userbuffer registration.' 
+ 'This option will force to use conventional (local) userbuffer registration when use-nccl-ub is set.') +- group.add_argument('--use-sharp', action='store_true', ++ group.add_argument('--use-sharp', action='store_true', + help='Required to enable SHARP communication.') + group.add_argument('--sharp-enabled-group', type=str, default=None, + choices=['dp', 'dp_replica'], +@@ -2636,6 +2872,10 @@ def _add_distributed_args(parser): help='If set, keep the fp8 transpose cache when using Megatron FSDP.') group.add_argument('--enable-full-sharding-in-hsdp', action='store_true', help='If set, enable full sharding in megatron-fsdp Hybrid Sharded Data Parallel (HSDP) mode.') + group.add_argument('--use-partial-reduce-for-shared-embedding', action='store_true', + help='Use partial reduce for shared word embedding.') -+ group.add_argument('--no-shared-fs', action='store_true', ++ group.add_argument('--no-shared-fs', action='store_true', + help='Indicate whether not running on a shared file system.') group.add_argument('--num-distributed-optimizer-instances', type=int, default=1, help='Number of Distributed Optimizer copies across Data Parallel domain.') group.add_argument('--use-torch-fsdp2', action='store_true', -@@ -2690,6 +2926,9 @@ def _add_validation_args(parser): +@@ -2690,6 +2930,9 @@ def _add_validation_args(parser): group.add_argument('--eval-interval', type=int, default=1000, help='Interval between running evaluation on ' 'validation set.') @@ -540,7 +638,7 @@ index 1120c7529..0a33c32e9 100644 group.add_argument("--test-mode", action="store_true", help='Run all real-time test alongside the experiment.') group.add_argument('--skip-train', action='store_true', default=False, help='If set, bypass the training loop, ' -@@ -2708,6 +2947,8 @@ def _add_tokenizer_args(parser): +@@ -2708,6 +2951,8 @@ def _add_tokenizer_args(parser): 'automatically calculated from vocab-size.') group.add_argument('--vocab-file', type=str, default=None, help='Path to the vocab file.') @@ -549,15 +647,15 @@ index 1120c7529..0a33c32e9 100644 group.add_argument('--merge-file', type=str, default=None, help='Path to the BPE merge file.') group.add_argument('--vocab-extra-ids', type=int, default=0, -@@ -2726,8 +2967,17 @@ def _add_tokenizer_args(parser): +@@ -2726,8 +2971,17 @@ def _add_tokenizer_args(parser): 'MultimodalTokenizer', 'NullTokenizer', 'NullMultimodalTokenizer', - 'SFTTokenizer'], + 'SFTTokenizer', + 'AquilaTokenizerFS', -+ 'HFTokenizerFS', -+ 'HFTokenizersTokenizerFS', ++ 'HFTokenizerFS', ++ 'HFTokenizersTokenizerFS', + 'Llama3TokenizerFS', + 'QwenTokenizerFS', + 'Qwen2TokenizerFS', @@ -568,7 +666,7 @@ index 1120c7529..0a33c32e9 100644 group.add_argument('--tokenizer-model', type=str, default=None, help='Sentencepiece tokenizer model.') group.add_argument('--tokenizer-metadata', type=str, default=None, -@@ -2768,6 +3018,11 @@ def _add_data_args(parser): +@@ -2768,6 +3022,11 @@ def _add_data_args(parser): group.add_argument('--valid-data-path', nargs='*', default=None, help='The weight and prefix list for an independent validation dataset. ' 'Follows the same pattern rules as --data-path.') @@ -580,7 +678,7 @@ index 1120c7529..0a33c32e9 100644 group.add_argument('--test-data-path', nargs='*', default=None, help='The weight and prefix list for an independent test dataset. 
' 'Follows the same pattern rules as --data-path.') -@@ -2816,11 +3071,18 @@ def _add_data_args(parser): +@@ -2816,11 +3075,18 @@ def _add_data_args(parser): 'end-of-document token.') group.add_argument('--eod-mask-loss', action='store_true', help='Mask loss for the end of document tokens.') @@ -599,7 +697,7 @@ index 1120c7529..0a33c32e9 100644 group.add_argument('--object-storage-cache-path', type=str, default=None, help='Path to cache index files when using s3 or msc dataloader') group.add_argument('--mid-level-dataset-surplus', type=float, default=0.005, -@@ -2897,6 +3159,19 @@ def _add_biencoder_args(parser): +@@ -2897,6 +3163,19 @@ def _add_biencoder_args(parser): return parser @@ -619,7 +717,7 @@ index 1120c7529..0a33c32e9 100644 def _add_vision_args(parser): group = parser.add_argument_group(title="vision") -@@ -2967,6 +3242,8 @@ def _add_vision_args(parser): +@@ -2967,6 +3246,8 @@ def _add_vision_args(parser): help='Whether to layer normalize the q and k attention embeddings.') group.add_argument('--qk-l2-norm', action='store_true', help='Use llama 4 qk l2 norm') @@ -628,8 +726,24 @@ index 1120c7529..0a33c32e9 100644 return parser -@@ -3275,3 +3552,98 @@ def _add_sft_args(parser): - group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned", +@@ -3120,9 +3401,9 @@ def _add_mla_args(parser): + + def _add_heterogeneous_args(parser): + """ +- Heterogeneous models refer to transformer architectures where individual layers can differ ++ Heterogeneous models refer to transformer architectures where individual layers can differ + in configuration. Specifically: +- - Attention or MLP layers can be replaced with either a linear layer or a no-op ++ - Attention or MLP layers can be replaced with either a linear layer or a no-op + - MLP intermediate dimensions can vary between layers + We use the format of the HuggingFace config files in llama nemotron models to define the architecture. + For example, https://huggingface.co/nvidia/Llama-3_3-Nemotron-Super-49B-v1/resolve/main/config.json +@@ -3272,6 +3553,101 @@ def _add_kitchen_quantization_arguments(parser: argparse.ArgumentParser): + def _add_sft_args(parser): + group = parser.add_argument_group(title='sft') + group.add_argument('--sft', action="store_true", help='Megatron SFT training') +- group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned", ++ group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned", help='SFT prompt format.') return parser + @@ -638,11 +752,11 @@ index 1120c7529..0a33c32e9 100644 +def _add_hetero_args(parser): + group = parser.add_argument_group(title="heterogeneous training") + -+ group.add_argument('--enable-hetero', action="store_true", ++ group.add_argument('--enable-hetero', action="store_true", + help='the mode of heterogeneous training') -+ group.add_argument('--hetero-device-types', nargs='*', type=str, default=None, ++ group.add_argument('--hetero-device-types', nargs='*', type=str, default=None, + help='the list of device types: device_type_0 device_type_1 ...') -+ group.add_argument('--hetero-current-device-type', type=str, default=None, ++ group.add_argument('--hetero-current-device-type', type=str, default=None, + help='the current device type') + group.add_argument('--hetero-pipeline-layer-split', nargs='*', type=int, default=None, + help='Incompatible with --num-layers-per-virtual-pipeline-stage for now.' 
@@ -653,7 +767,7 @@ index 1120c7529..0a33c32e9 100644 + group.add_argument('--expert-tensor-parallel-size-per-process-mesh', nargs='*', type=int, default=None, + help='The number of tensor parallel experts for each process-mesh. The number of the list should be equal to the number of process-meshes.') + group.add_argument('--hetero-use-cpu-communication', action='store_true', help='Use CPU for communication for heterogeneous communication.') -+ ++ + return parser + + @@ -668,7 +782,7 @@ index 1120c7529..0a33c32e9 100644 + +def _add_auto_skip_spiky_loss(parser): + group = parser.add_argument_group(title='auto skip spiky loss') -+ ++ + group.add_argument('--auto-skip-spiky-loss', action='store_true', + help='Automatically skip spiky loss iterations.') + group.add_argument('--spiky-loss-threshold', type=float, default=0.2, diff --git a/flagscale/backends/Megatron-LM/megatron/training/checkpointing.py.patch b/flagscale/backends/Megatron-LM/megatron/training/checkpointing.py.patch index 8e1c68997d..5a1f0d4a32 100644 --- a/flagscale/backends/Megatron-LM/megatron/training/checkpointing.py.patch +++ b/flagscale/backends/Megatron-LM/megatron/training/checkpointing.py.patch @@ -1,13 +1,21 @@ diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py -index 104fa6882..722859bf6 100644 +index 104fa6882..40cda37d0 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py -@@ -286,12 +286,15 @@ def read_metadata(tracker_filename): +@@ -40,6 +40,7 @@ from . import wandb_utils + from . import ft_integration + + from megatron.core.msc_utils import MultiStorageClientFeature, open_file ++from megatron.training.global_vars import get_wandb_writer + + try: + from megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import preprocess_state_dict_for_uneven_dtensor +@@ -286,12 +287,15 @@ def read_metadata(tracker_filename): print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format( tracker_filename)) sys.exit() - assert iteration > 0 or release, 'error parsing metadata file {}'.format( -+ # TODO: we use iteration 0 to load checkpoint from other framework. ++ # TODO: we use iteration 0 to load checkpoint from other framework. + # We should remove this after we have a better way to load checkpoint from other framework. + assert iteration >= 0 or release, 'error parsing metadata file {}'.format( tracker_filename) @@ -20,9 +28,56 @@ index 104fa6882..722859bf6 100644 torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX) max_iter = iters_cuda[0].item() -@@ -692,6 +695,28 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati +@@ -369,6 +373,13 @@ def _build_sharded_state_dict_metadata(args: Namespace) -> dict: + impossible to enforce a linearly increasing versioning for this whole space. 
+ """ + metadata = {} ++ if args.ckpt_format == "nemo_zarr": ++ if args.use_distributed_optimizer: ++ if args.ckpt_fully_parallel_save: ++ metadata['distrib_optim_sharding_type'] = 'fully_sharded_model_space' ++ else: ++ metadata['distrib_optim_sharding_type'] = 'dp_zero_gather_scatter' ++ return metadata + + if args.use_distributed_optimizer and args.ckpt_format == "fsdp_dtensor": + metadata['distrib_optim_sharding_type'] = 'fsdp_dtensor' +@@ -469,7 +480,11 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati + return_base_dir = (ckpt_type != CheckpointType.LEGACY) + checkpoint_name = get_checkpoint_name(save_dir, iteration, release=False, pipeline_parallel=pipeline_parallel, + tensor_rank=tensor_rank, pipeline_rank=pipeline_rank, expert_parallel=expert_parallel, expert_rank=expert_rank, return_base_dir=return_base_dir) ++ if ckpt_format == "nemo_zarr" and os.path.exists(checkpoint_name) \ ++ and (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0): ++ shutil.rmtree(checkpoint_name) + ++ torch.distributed.barrier() + # Save dataloader state if the dataloader supports it (currently only Megatron Energon). + maybe_save_dataloader_state(train_data_iterator, iteration, getattr(args, "dataloader_save", None)) + +@@ -519,7 +534,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati + ) + + state_dict['num_floating_point_operations_so_far'] = num_floating_point_operations_so_far +- if ckpt_type == CheckpointType.GLOBAL and ckpt_format == "torch_dist": ++ if ckpt_type == CheckpointType.GLOBAL and (ckpt_format == "torch_dist" or ckpt_format == "nemo_zarr"): + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: + # TODO Handle non-empty directories (e.g., after a crash during saving). + ensure_directory_exists(checkpoint_name, check_parent=False) +@@ -529,7 +544,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati + validate_sharding_integrity = not args.ckpt_assume_constant_structure + else: + validate_sharding_integrity = True +- save_strategy = get_default_save_sharded_strategy(args.ckpt_format) ++ save_strategy = get_default_save_sharded_strategy(args.ckpt_format if ckpt_format != "nemo_zarr" else "zarr") + if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist': + save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure + if checkpointing_context is not None and 'load_strategy' in checkpointing_context: +@@ -690,8 +705,35 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati + + # Additional callback for wandb (last rank) if not torch.distributed.is_initialized() \ - or is_last_rank(): +- or is_last_rank(): ++ or is_last_rank() and get_wandb_writer(): def wandb_finalize_fn(): + ######### FlagScale Begin ######### + #NOTE(lizhiyu): The tracker file is created by rank 0 but wandb_finalize_fn is called on the last rank. 
@@ -33,23 +88,28 @@ index 104fa6882..722859bf6 100644 + timeout_seconds = 600 # 10 minutes + wait_interval_seconds = 5 + max_retries = timeout_seconds // wait_interval_seconds ++ i = 0 ++ for i in range(max_retries): ++ try: ++ if isfile(tracker_file): ++ with open(tracker_file, 'r') as f: ++ content = f.read().strip() ++ if content == str(iteration): ++ break # Success ++ except FileNotFoundError: ++ continue ++ finally: ++ print(f'WandB finalization waiting for the tracker file {tracker_file} to update...') ++ pytime.sleep(wait_interval_seconds) + -+ for _ in range(max_retries): -+ if isfile(tracker_file): -+ with open(tracker_file, 'r') as f: -+ content = f.read().strip() -+ if content == str(iteration): -+ break # Success -+ print(f'WandB finalization waiting for the tracker file {tracker_file} to update...') -+ pytime.sleep(wait_interval_seconds) -+ else: ++ if i == max_retries: + # This block executes if the loop completes without a `break`. + raise RuntimeError(f"Timed out waiting for tracker file {tracker_file} to be updated for iteration {iteration} after {timeout_seconds} seconds.") + ######### FlagScale End ######### wandb_utils.on_save_checkpoint_success(checkpoint_name, get_checkpoint_tracker_filename(save_dir), save_dir, iteration) if args.async_save: assert async_save_request is not None -@@ -774,9 +799,7 @@ def maybe_save_dataloader_state(train_iterator, iteration, dataloader_save_path) +@@ -774,9 +816,7 @@ def maybe_save_dataloader_state(train_iterator, iteration, dataloader_save_path) torch.distributed.barrier(group=mpu.get_data_parallel_group()) @@ -60,7 +120,25 @@ index 104fa6882..722859bf6 100644 torch.distributed.barrier(group=mpu.get_data_parallel_group()) dataloader_save_dict = {} -@@ -1239,6 +1262,10 @@ def load_args_from_checkpoint( +@@ -801,7 +841,7 @@ def generate_state_dict(args, model, optimizer, opt_param_scheduler, + if len(model) > 1: + key = f"model{i}" + +- if args.ckpt_format == "torch_dist": ++ if args.ckpt_format == "torch_dist" or args.ckpt_format == "nemo_zarr": + model_sd = model[i].sharded_state_dict(**(model_sd_kwargs or {})) + else: # torch, torch_dcp, fsdp_dtensor + model_sd = model[i].state_dict_for_save_checkpoint() +@@ -813,7 +853,7 @@ def generate_state_dict(args, model, optimizer, opt_param_scheduler, + if optimizer is not None and not optimizer.is_stub_optimizer: + optimizer_sd = None + +- if args.ckpt_format == "torch_dist": ++ if args.ckpt_format == "torch_dist" or args.ckpt_format == "nemo_zarr": + optimizer_sd = optimizer.sharded_state_dict(state_dict, **(optim_sd_kwargs or {})) + elif args.ckpt_format == "fsdp_dtensor": + if optim_sd_kwargs is None: +@@ -1239,6 +1279,10 @@ def load_args_from_checkpoint( checkpoint_args, 'add_bias_linear', not getattr(checkpoint_args, 'disable_bias_linear') ) @@ -71,7 +149,7 @@ index 104fa6882..722859bf6 100644 def _set_arg(arg_name, old_arg_name=None, force=False): if not force and getattr(args, arg_name, None) is not None: return -@@ -1274,6 +1301,8 @@ def load_args_from_checkpoint( +@@ -1274,6 +1318,8 @@ def load_args_from_checkpoint( _set_arg('add_qkv_bias', force=True) _set_arg('squared_relu', force=True) _set_arg('swiglu', force=True) @@ -80,11 +158,58 @@ index 104fa6882..722859bf6 100644 _set_arg('untie_embeddings_and_output_weights', force=True) _set_arg('apply_layernorm_1p', force=True) _set_arg('normalization', force=True) -@@ -1432,6 +1461,14 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', - mismatch_msg = "(TP, PP) mismatch after resume ({} vs {} from 
checkpoint)".format( +@@ -1360,7 +1406,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', + strict=strict, + load_arg=load_arg + ) +- ++ + # Since load_modelopt_checkpoint doesn't return iteration count, we need to get it + if torch.distributed.is_initialized(): + tracker_filename = get_checkpoint_tracker_filename(load_dir) +@@ -1372,7 +1418,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', + iteration = 0 + else: + iteration = 0 +- ++ + # We don't have a reliable way to get num_floating_point_operations_so_far from ModelOpt format + return iteration, 0 + +@@ -1390,6 +1436,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', + model = unwrap_model(ddp_model) + + ckpt_format = args.ckpt_format ++ convert_nemo_zarr = False + if args.auto_detect_ckpt_format or ckpt_format == "torch_dist": + state_dict, checkpoint_name, release, ckpt_type = _load_base_checkpoint( + load_dir, +@@ -1411,11 +1458,20 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', + pass # Not loaded. + else: + raise NotImplementedError(f"checkpoint format {ckpt_format} not supported") ++ elif args.ckpt_format == "nemo_zarr": ++ tracker_filename = get_checkpoint_tracker_filename(load_dir) ++ if isfile(tracker_filename): ++ state_dict, checkpoint_name, release, ckpt_type = _load_base_checkpoint(load_dir, args, rank0=True, checkpointing_context=checkpointing_context) ++ else: ++ checkpoint_name = os.path.join(load_dir, "weights") ++ state_dict = dist_checkpointing.load_common_state_dict(checkpoint_name) ++ convert_nemo_zarr = True ++ release = False + + load_kwargs = {} + ignore_rng_state = False + ignore_rerun_state = True +- if ckpt_format == "torch_dist": ++ if ckpt_format == "torch_dist" or (ckpt_format == "nemo_zarr" and (not convert_nemo_zarr)): + ckpt_tp_pp = ( + state_dict['args'].tensor_model_parallel_size, + state_dict['args'].pipeline_model_parallel_size, +@@ -1433,6 +1489,14 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', run_tp_pp, ckpt_tp_pp ) -+ + + ########## FlagScale Begin ########## + #Add support for changing parallel strategy from tp/pp to ep for ChainedOptimizer when using dist checkpointing + convert_to_ep = ( @@ -92,10 +217,11 @@ index 104fa6882..722859bf6 100644 + getattr(state_dict['args'], 'expert_model_parallel_size', 1) == 1 + ) + ########## FlagScale End ########## - ++ # Determine if RNG state will be loaded if (ckpt_tp_pp == run_tp_pp and not release and not args.finetune and not args.no_load_rng -@@ -1468,6 +1505,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', + and not getattr(state_dict['args'], 'no_save_rng', False)): +@@ -1468,6 +1532,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', ckpt_tp_pp != run_tp_pp and sharded_sd_metadata['distrib_optim_sharding_type'] not in DistributedOptimizer.checkpoint_fully_reshardable_formats @@ -103,7 +229,7 @@ index 104fa6882..722859bf6 100644 ): raise RuntimeError(f"{mismatch_msg}: not supported for DistributedOptimizer with sharding type" f" {sharded_sd_metadata['distrib_optim_sharding_type']}." 
-@@ -1481,7 +1519,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', +@@ -1481,7 +1546,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', gen_sd_optim = None gen_sd_opt_param_scheduler = None @@ -112,7 +238,65 @@ index 104fa6882..722859bf6 100644 model_sd_kwargs = dict(metadata=sharded_sd_metadata) # Determine if rerun state will be loaded -@@ -1829,3 +1867,4 @@ def load_biencoder_checkpoint(model, only_query_model=False, +@@ -1543,6 +1608,42 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', + "num_floating_point_operations_so_far": 0, + } + load_kwargs["sharded_state_dict"] = sharded_state_dict ++ elif args.ckpt_format == "nemo_zarr": ++ convert_to_ep = ( ++ getattr(args, 'expert_model_parallel_size', 1) != 1 and ++ getattr(state_dict['args'], 'expert_model_parallel_size', 1) == 1 ++ ) ++ sharded_sd_metadata = dist_checkpointing.load_content_metadata(preloaded_state_dict=state_dict) ++ optim_sd_kwargs = dict(metadata=sharded_sd_metadata, is_loading=True, convert_to_ep=convert_to_ep) ##add tp/pp to ep ++ model_sd_kwargs = dict(metadata=sharded_sd_metadata) ++ gen_sd_rerun_state = None ++ if has_nvidia_modelopt: ++ # if ckpt_type == CheckpointType.LOCAL: ++ # print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.') ++ # elif ckpt_type == CheckpointType.GLOBAL: ++ # restore_modelopt_state(model, state_dict) ++ # else: ++ # restore_sharded_modelopt_state(model, checkpoint_name) ++ restore_modelopt_state(model, state_dict) ++ ++ gen_sd_optim = None ++ gen_sd_opt_param_scheduler = None ++ if not args.no_load_optim and "optimizer" in state_dict.keys(): ++ gen_sd_optim = optimizer ++ gen_sd_opt_param_scheduler = opt_param_scheduler ++ load_kwargs["sharded_state_dict"] = generate_state_dict( ++ args, model, gen_sd_optim, gen_sd_opt_param_scheduler, rng_state=get_rng_state(args.ckpt_format), ++ optim_sd_kwargs=optim_sd_kwargs, model_sd_kwargs=model_sd_kwargs, ++ rerun_state=gen_sd_rerun_state ++ ) ++ if not isfile(tracker_filename): ++ load_kwargs["sharded_state_dict"]["state_dict"] = {"module." + k : v for k, v in load_kwargs["sharded_state_dict"]["model"].items()} ++ load_kwargs["sharded_state_dict"].pop("model") ++ for v in load_kwargs["sharded_state_dict"]["state_dict"].values(): ++ v.key = "module." + v.key ++ if not args.load_adapter: ++ load_kwargs["sharded_state_dict"]["state_dict"] = {k:v for k, v in load_kwargs["sharded_state_dict"]["state_dict"].items() if not (".adapter." 
in k or k.endswith(".adapters"))} ++ + elif args.ckpt_format == "fsdp_dtensor": + reader = FileSystemReader(get_load_checkpoint_path_by_args(args)) + try: +@@ -1578,7 +1679,13 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', + iteration=1, + ) + +- state_dict, checkpoint_name, release, ckpt_type = _load_base_checkpoint( ++ if args.ckpt_format == "nemo_zarr" and convert_nemo_zarr: ++ ckpt_type = CheckpointType.GLOBAL ++ state_dict = dist_checkpointing.load(load_kwargs["sharded_state_dict"], checkpoint_name, None, strict=args.dist_ckpt_strictness) ++ state_dict["model"] = state_dict.pop("state_dict") ++ state_dict['model'] = {k.replace("module.", "", 1) : v for k, v in state_dict['model'].items()} ++ else: ++ state_dict, checkpoint_name, release, ckpt_type = _load_base_checkpoint( + load_dir, args, rank0=False, checkpointing_context=checkpointing_context, + **load_kwargs + ) +@@ -1829,3 +1936,4 @@ def load_biencoder_checkpoint(model, only_query_model=False, print(' successfully loaded {}'.format(checkpoint_name)) return model diff --git a/flagscale/train/peft/lora.py b/flagscale/train/peft/lora.py index 0cb9e771f9..c6e0adfff9 100644 --- a/flagscale/train/peft/lora.py +++ b/flagscale/train/peft/lora.py @@ -133,8 +133,8 @@ def load_state_dict_hook_remap_main_model_params( old_keys.append(prefix + f"bias{gemm_id}") new_keys.append(prefix + f"to_wrap.bias{gemm_id}") else: - old_keys = [prefix + "weight", prefix + "bias"] - new_keys = [prefix + "to_wrap.weight", prefix + "to_wrap.bias"] + old_keys = [prefix + "weight", prefix + "bias", prefix + "layer_norm_weight", prefix + "layer_norm_bias", prefix + "_extra_state"] + new_keys = [prefix + "to_wrap.weight", prefix + "to_wrap.bias", prefix + "to_wrap.layer_norm_weight", prefix + "to_wrap.layer_norm_bias", prefix + "to_wrap._extra_state"] for old_key, new_key in zip(old_keys, new_keys): if old_key in state_dict.keys(): diff --git a/flagscale/train/peft/peft.py b/flagscale/train/peft/peft.py index 1fa424ac42..effd925365 100644 --- a/flagscale/train/peft/peft.py +++ b/flagscale/train/peft/peft.py @@ -149,3 +149,15 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): destination=destination, prefix=f'{prefix}adapter.', keep_vars=keep_vars ) return destination + + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): + sharded_state_dict = {} + # Get state dict of the main module + base_sharded_state_dict = self.to_wrap.sharded_state_dict(prefix=prefix, sharded_offsets=sharded_offsets, metadata=metadata) + # Store adapter state dict under the "adapter" prefix in the destination dict + adapter_sharded_state_dict = self.adapter.sharded_state_dict( + prefix=f'{prefix}adapter.', sharded_offsets=sharded_offsets, metadata=metadata + ) + sharded_state_dict.update(base_sharded_state_dict) + sharded_state_dict.update(adapter_sharded_state_dict) + return sharded_state_dict \ No newline at end of file diff --git a/flagscale/train/train.py b/flagscale/train/train.py index a2ae60d654..f628c0cda3 100644 --- a/flagscale/train/train.py +++ b/flagscale/train/train.py @@ -714,8 +714,15 @@ def reorder_inner_param_groups(optimizer_state_dict): if "param_groups" not in inner_optimizer: return param_groups = inner_optimizer["param_groups"] - key_fn = lambda pg: [pg[key] for key in param_group_identifier_keys] - param_groups.sort(key=key_fn) + def get_sort_key(pg): + result = [] + for key in param_group_identifier_keys: + if key not in pg: + result.append(None) + else: + result.append(pg[key]) + return 
tuple(result) + param_groups.sort(key=get_sort_key) inner_optimizer["param_groups"] = param_groups optimizer_state_dict = preprocessed_common_state_dict['optimizer'] diff --git a/flagscale/train/train_nemo_llama.py b/flagscale/train/train_nemo_llama.py new file mode 100644 index 0000000000..27b1193907 --- /dev/null +++ b/flagscale/train/train_nemo_llama.py @@ -0,0 +1,557 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +"""Pretrain and SFT GPT.""" + +import datetime +import os +import torch + +from functools import partial +from typing import List, Optional, Tuple, Union + +import torch + +from megatron.core import parallel_state +from megatron.training import get_args +from megatron.training import inprocess_restart +from megatron.training import print_rank_0 +from megatron.training import get_timers +from megatron.training import get_tokenizer +from megatron.core import mpu +from megatron.core.enums import ModelType +from megatron.core.models.gpt import GPTModel +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_decoder_block_spec, + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, + get_gpt_mtp_block_spec, +) +from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import ( + get_gpt_heterogeneous_layer_spec, +) +from megatron.core.rerun_state_machine import get_rerun_state_machine +from megatron.core.transformer.spec_utils import import_module +from megatron.core.utils import StragglerDetector +from megatron.training import get_args, get_timers, get_tokenizer, print_rank_0 +from megatron.training.arguments import core_transformer_config_from_args +from megatron.training.utils import ( + get_batch_on_this_cp_rank, +) +from megatron.training.yaml_arguments import core_transformer_config_from_yaml + +import megatron.legacy.model # isort: skip + +# NOTE: Loading `megatron.legacy.model` earlier fails due to circular import + +try: + from megatron.post_training.arguments import add_modelopt_args, modelopt_args_enabled + from megatron.post_training.loss_func import loss_func as loss_func_modelopt + from megatron.post_training.model_provider import model_provider as model_provider_modelopt + + has_nvidia_modelopt = True +except ImportError: + has_nvidia_modelopt = False + +from flagscale.train.extra_valid import extra_valid_datasets_provider +from flagscale.train.train import pretrain +from flagscale.train.global_vars import get_parallel_context +from nemo.collections.llm.gpt.data.core import create_sft_dataset +from megatron.legacy.data.data_samplers import MegatronPretrainingSampler + +import os + + +stimer = StragglerDetector() + + +def model_provider( + pre_process=True, post_process=True, vp_stage: Optional[int] = None, is_dualpipev_first_chunk: Optional[bool] = False, +) -> Union[GPTModel, megatron.legacy.model.GPTModel]: + """Builds the model. + + If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. + + Args: + pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True. + post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True. 
+ + + Returns: + Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model + """ + args = get_args() + + if has_nvidia_modelopt and modelopt_args_enabled(args): # [ModelOpt] + return model_provider_modelopt(pre_process, post_process) + + use_te = args.transformer_impl == "transformer_engine" + + if args.record_memory_history: + torch.cuda.memory._record_memory_history( + True, + # keep 100,000 alloc/free events from before the snapshot + trace_alloc_max_entries=100000, + # record stack information for the trace events + trace_alloc_record_context=True, + ) + + def oom_observer(device, alloc, device_alloc, device_free): + # snapshot right after an OOM happened + print('saving allocated state during OOM') + snapshot = torch.cuda.memory._snapshot() + from pickle import dump + + dump( + snapshot, + open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'), + ) + + torch._C._cuda_attach_out_of_memory_observer(oom_observer) + + print_rank_0('building GPT model ...') + # Experimental loading arguments from yaml + config = None + if args.yaml_cfg is not None: + config = core_transformer_config_from_yaml(args, "language_model") + else: + para_ctx = get_parallel_context() + if para_ctx is not None: + config = para_ctx.get_transformer_config() + + if config is None: + config = core_transformer_config_from_args(args) + + if args.use_legacy_models: + model = megatron.legacy.model.GPTModel( + config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + ) + else: # using core models + if args.spec is not None: + transformer_layer_spec = import_module(args.spec) + else: + if args.num_experts: + # Define the decoder block spec + transformer_layer_spec = get_gpt_decoder_block_spec( + config, use_transformer_engine=use_te, normalization=args.normalization, qk_l2_norm=args.qk_l2_norm, vp_stage=vp_stage, is_dualpipev_first_chunk=is_dualpipev_first_chunk, + ) + elif args.heterogeneous_layers_config_path is not None: + transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te) + else: + # Define the decoder layer spec + if use_te: + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + args.num_experts, + args.moe_grouped_gemm, + args.qk_layernorm, + args.multi_latent_attention, + args.moe_use_legacy_grouped_gemm, + qk_l2_norm=args.qk_l2_norm, + use_kitchen=config.use_kitchen, + ) + else: + transformer_layer_spec = get_gpt_layer_local_spec( + args.num_experts, + args.moe_grouped_gemm, + args.qk_layernorm, + args.multi_latent_attention, + args.moe_use_legacy_grouped_gemm, + normalization=args.normalization, + use_kitchen=config.use_kitchen, + ) + mtp_block_spec = None + if args.mtp_num_layers is not None: + mtp_block_spec = get_gpt_mtp_block_spec( + config, transformer_layer_spec, use_transformer_engine=use_te, vp_stage=vp_stage + ) + model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, + rope_scaling=args.use_rope_scaling, + rope_scaling_factor=args.rope_scaling_factor, + mtp_block_spec=mtp_block_spec, + vp_stage=vp_stage, + ) + + return model + 
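+
+# Descriptive note: the helper below mirrors Megatron's get_batch_on_this_tp_rank
+# (megatron/training/utils.py). Tensor-parallel rank 0 pulls one sample dict
+# (tokens/labels/loss_mask/attention_mask/position_ids) from the external dataloader and
+# broadcasts it to the remaining TP ranks; the FlagScale blocks additionally broadcast
+# labels/loss_mask on the first pipeline stage when DualPipeV is active, presumably because
+# that schedule places a loss-computing model chunk on the first-stage ranks as well.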
+ +def get_batch_on_this_tp_rank(data_iterator): + + args = get_args() + + def _broadcast(item): + if item is not None: + torch.distributed.broadcast( + item, + mpu.get_tensor_model_parallel_src_rank(), + group=mpu.get_tensor_model_parallel_group(), + ) + + if mpu.get_tensor_model_parallel_rank() == 0: + + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + + batch = { + 'tokens': data["tokens"].cuda(non_blocking=True), + 'labels': data["labels"].cuda(non_blocking=True), + 'loss_mask': data["loss_mask"].cuda(non_blocking=True), + 'attention_mask': ( + None + if "attention_mask" not in data + else data["attention_mask"].cuda(non_blocking=True) + ), + 'position_ids': data["position_ids"].cuda(non_blocking=True), + } + + if args.pipeline_model_parallel_size == 1: + _broadcast(batch['tokens']) + _broadcast(batch['labels']) + _broadcast(batch['loss_mask']) + _broadcast(batch['attention_mask']) + _broadcast(batch['position_ids']) + + elif mpu.is_pipeline_first_stage(): + _broadcast(batch['tokens']) + _broadcast(batch['attention_mask']) + _broadcast(batch['position_ids']) + ######### FlagScale Begin ######## + if mpu.get_dualpipev_pipeline_model_parallel_world_size() is not None: + _broadcast(batch['loss_mask']) + _broadcast(batch['labels']) + ######### FlagScale End ######## + + elif mpu.is_pipeline_last_stage(): + # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding. + # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need + # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage. + if args.mtp_num_layers is not None: + _broadcast(batch['tokens']) + _broadcast(batch['position_ids']) + _broadcast(batch['labels']) + _broadcast(batch['loss_mask']) + _broadcast(batch['attention_mask']) + + else: + + tokens = torch.empty( + (args.micro_batch_size, args.seq_length), + dtype=torch.int64, + device=torch.cuda.current_device(), + ) + labels = torch.empty( + (args.micro_batch_size, args.seq_length), + dtype=torch.int64, + device=torch.cuda.current_device(), + ) + loss_mask = torch.empty( + (args.micro_batch_size, args.seq_length), + dtype=torch.int64, + device=torch.cuda.current_device(), + ) + if args.create_attention_mask_in_dataloader: + attention_mask = torch.empty( + (args.micro_batch_size, 1, args.seq_length, args.seq_length), + dtype=torch.bool, + device=torch.cuda.current_device(), + ) + else: + attention_mask = None + position_ids = torch.empty( + (args.micro_batch_size, args.seq_length), + dtype=torch.int64, + device=torch.cuda.current_device(), + ) + + if args.pipeline_model_parallel_size == 1: + _broadcast(tokens) + _broadcast(labels) + _broadcast(loss_mask) + _broadcast(attention_mask) + _broadcast(position_ids) + + elif mpu.is_pipeline_first_stage(): + _broadcast(tokens) + _broadcast(attention_mask) + _broadcast(position_ids) + ######### FlagScale Modify ######## + if mpu.get_dualpipev_pipeline_model_parallel_world_size() is not None: + _broadcast(loss_mask) + _broadcast(labels) + else: + labels = None + loss_mask = None + + elif mpu.is_pipeline_last_stage(): + # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding. + # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need + # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage. 
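+            # (Without MTP, the last stage does not need tokens/position_ids, so they are left as None below.)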
+            if args.mtp_num_layers is not None:
+                _broadcast(tokens)
+                _broadcast(position_ids)
+            else:
+                tokens = None
+                position_ids = None
+
+            _broadcast(labels)
+            _broadcast(loss_mask)
+            _broadcast(attention_mask)
+
+        batch = {
+            'tokens': tokens,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'attention_mask': attention_mask,
+            'position_ids': position_ids,
+        }
+
+    return batch
+
+
+def get_batch(data_iterator):
+    """Generate a batch."""
+
+    # TODO: this is pretty hacky, find a better way
+    if (not parallel_state.is_pipeline_first_stage(ignore_virtual=True)) and (
+        not parallel_state.is_pipeline_last_stage(ignore_virtual=True)
+    ):
+        return None, None, None, None, None
+
+    # get batches based on the TP rank you are on
+    # tokens, labels, loss_mask, attention_mask, position_ids
+    batch = get_batch_on_this_tp_rank(data_iterator)
+
+    # slice batch along sequence dimension for context parallelism
+    batch = get_batch_on_this_cp_rank(batch)
+
+    return batch.values()
+
+
+# define spiky loss as a loss that's 10x the max loss observed
+SPIKY_LOSS_FACTOR = 10
+
+
+def loss_func(
+    loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: Optional[GPTModel] = None
+):
+    """Loss function.
+
+    Args:
+        loss_mask (torch.Tensor): Used to mask out some portions of the loss
+        output_tensor (torch.Tensor): The tensor with the losses
+        model (GPTModel, optional): The model (can be wrapped)
+
+    Returns:
+        the loss scalar for this micro-batch and a dict containing the loss averaged
+        across the data parallel ranks for reporting ('lm loss')
+    """
+    args = get_args()
+    if has_nvidia_modelopt and modelopt_args_enabled(args):  # [ModelOpt]
+        return loss_func_modelopt(loss_mask, output_tensor, model=model)
+
+    cp_size = mpu.get_context_parallel_world_size()
+    losses = output_tensor.view(-1).float()
+    loss_mask = loss_mask.view(-1).float()
+
+    if cp_size > 1:
+        loss_all = torch.cat([losses, loss_mask])
+        torch.distributed.all_reduce(loss_all, group=mpu.get_context_parallel_group())
+        losses, loss_mask = torch.split(loss_all, losses.shape[0], dim=0)
+
+    num_valid_tokens = torch.sum(loss_mask).int()
+    loss = torch.sum(losses * loss_mask) / num_valid_tokens
+    loss = torch.where(num_valid_tokens == 0, torch.zeros_like(loss), loss)
+
+    # Check individual rank losses are not NaN prior to DP all-reduce.
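+    # The rerun state machine validates the local loss before the data-parallel all-reduce:
+    # a NaN or Inf loss is treated as fatal, while a spiky loss (more than SPIKY_LOSS_FACTOR
+    # times the largest loss observed so far) is only reported, not fatal.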
+    rerun_state_machine = get_rerun_state_machine()
+    if args.check_for_nan_in_loss_and_grad:
+        rerun_state_machine.validate_result(
+            result=loss,
+            rejection_func=torch.isnan,
+            message="found NaN in local forward loss calculation",
+            tolerance=0.0,  # forward pass calculations are deterministic
+            fatal=True,
+        )
+        rerun_state_machine.validate_result(
+            result=loss,
+            rejection_func=torch.isinf,
+            message="found Inf in local forward loss calculation",
+            tolerance=0.0,  # forward pass calculations are deterministic
+            fatal=True,
+        )
+    # Check for spiky loss
+    if args.check_for_spiky_loss:
+        rerun_state_machine.validate_result(
+            result=loss,
+            rejection_func=partial(
+                rerun_state_machine.is_unexpectedly_large,
+                threshold=SPIKY_LOSS_FACTOR,
+                context="loss",
+            ),
+            message="Spiky loss",
+            tolerance=0.0,  # forward pass calculations are deterministic
+            fatal=False,
+        )
+    reporting_loss = loss.clone().detach().view(1)
+    torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
+    reporting_loss = reporting_loss / torch.distributed.get_world_size(
+        group=mpu.get_data_parallel_group()
+    )
+    return (loss, {'lm loss': reporting_loss})
+
+
+def forward_step(data_iterator, model: GPTModel):
+    """Forward training step.
+
+    Args:
+        data_iterator : Input data iterator
+        model (GPTModel): The GPT Model
+    """
+    args = get_args()
+    timers = get_timers()
+
+    # Get the batch.
+    timers('batch-generator', log_level=2).start()
+    global stimer
+    with stimer(bdata=True):
+        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
+    timers('batch-generator').stop()
+
+    with stimer:
+        if args.use_legacy_models:
+            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
+        else:
+            output_tensor = model(
+                tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
+            )
+
+    # [ModelOpt]: model is needed to access ModelOpt distillation losses
+    return output_tensor, partial(loss_func, loss_mask, model=model)
+
+
+def is_dataset_built_on_rank():
+    return (
+        parallel_state.is_pipeline_first_stage(ignore_virtual=True)
+        or parallel_state.is_pipeline_last_stage(ignore_virtual=True)
+    ) and parallel_state.get_tensor_model_parallel_rank() == 0
+
+
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Build the train, validation, and test datasets.
+
+    Args:
+        train_val_test_num_samples : A list containing the number of samples in train, validation, and test.
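+
+    Returns:
+        On ranks that build datasets, iterators over the train and validation dataloaders
+        together with a None test dataset; (None, None, None) on all other ranks.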
+    """
+    args = get_args()
+
+    print_rank_0("> building train, validation, and test datasets for GPT ...")
+
+    train_ds, valid_ds, test_ds = None, None, None
+    if is_dataset_built_on_rank():
+        train_data_path = os.path.join(args.data_path[0], "train.npy")
+        validation_data_path = os.path.join(args.data_path[0], "validation.npy")
+        tokenizer = get_tokenizer()
+        tokenizer.eos_id = tokenizer.eod_id
+        kwargs = {'return_cu_seqlen': False}
+        train_ds = create_sft_dataset(
+            train_data_path,
+            tokenizer=tokenizer,
+            seq_length=args.seq_length,
+            memmap_workers=1,
+            seed=args.seed,
+            is_test=False,
+            pack_metadata_file_path=None,
+            pad_cu_seqlens=False,
+            max_num_samples=train_val_test_num_samples[0],
+            **kwargs,
+        )
+        train_batch_sampler = MegatronPretrainingSampler(
+            total_samples=len(train_ds),
+            consumed_samples=0,
+            micro_batch_size=args.micro_batch_size,
+            data_parallel_rank=parallel_state.get_data_parallel_rank(),
+            data_parallel_size=parallel_state.get_data_parallel_world_size(),
+            drop_last=True,
+        )
+        train_dataloader = torch.utils.data.DataLoader(
+            dataset=train_ds,
+            num_workers=args.num_workers,
+            pin_memory=True,
+            persistent_workers=True if args.num_workers > 0 else False,
+            collate_fn=train_ds.collate_fn,
+            batch_sampler=train_batch_sampler
+        )
+
+        valid_ds = create_sft_dataset(
+            validation_data_path,
+            tokenizer=tokenizer,
+            seq_length=args.seq_length,
+            memmap_workers=1,
+            seed=args.seed,
+            is_test=False,
+            pack_metadata_file_path=None,
+            pad_cu_seqlens=False,
+            max_num_samples=train_val_test_num_samples[1],
+            **kwargs,
+        )
+        valid_batch_sampler = MegatronPretrainingSampler(
+            total_samples=len(valid_ds),
+            consumed_samples=0,
+            micro_batch_size=args.micro_batch_size,
+            data_parallel_rank=parallel_state.get_data_parallel_rank(),
+            data_parallel_size=parallel_state.get_data_parallel_world_size(),
+            drop_last=True,
+        )
+        valid_dataloader = torch.utils.data.DataLoader(
+            dataset=valid_ds,
+            num_workers=args.num_workers,
+            pin_memory=True,
+            persistent_workers=True if args.num_workers > 0 else False,
+            collate_fn=valid_ds.collate_fn,
+            batch_sampler=valid_batch_sampler
+        )
+
+        print_rank_0("> finished creating GPT datasets ...")
+
+        return iter(train_dataloader), iter(valid_dataloader), test_ds
+    return train_ds, valid_ds, test_ds
+
+
+# python run.py --config-path examples/llama3/conf --config-name train_nemo_llama action=run
+if __name__ == "__main__":
+    torch.backends.cudnn.deterministic = True
+    # Temporary for transition to core datasets
+    train_valid_test_datasets_provider.is_distributed = True
+
+    # Optionally enable inprocess restart on pretrain
+    pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain)
+
+    extra_valid_datasets_provider.is_distributed = True  ######## FlagScale ########
+    pretrain(
+        train_valid_test_datasets_provider,
+        model_provider,
+        ModelType.encoder_or_decoder,
+        forward_step,
+        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
+        extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,
+        store=store,
+        extra_valid_dataset_provider=extra_valid_datasets_provider
+    )