2 changes: 2 additions & 0 deletions nemo/lightning/_strategy_lib.py
@@ -97,6 +97,7 @@ def init_parallel_ranks(
and getattr(parallel_config, "tp_comm_bootstrap_backend", None) == 'mpi',
use_te_rng_tracker=getattr(parallel_config, "use_te_rng_tracker", False),
use_sharp=getattr(parallel_config, "use_sharp", False),
create_all_gather_group=getattr(parallel_config, "create_all_gather_group", False),
use_tp_pp_dp_mapping=getattr(parallel_config, "use_tp_pp_dp_mapping", False),
num_distributed_optimizer_instances=getattr(parallel_config, "num_distributed_optimizer_instances", 1),
nccl_communicator_config_path=getattr(parallel_config, "nccl_communicator_config_path", None),
@@ -130,6 +131,7 @@ def init_model_parallel(model: Optional[nn.Module] = None) -> None:
expert_model_parallel_size=app_state.expert_model_parallel_size,
expert_tensor_parallel_size=app_state.expert_tensor_parallel_size,
use_sharp=app_state.use_sharp,
create_all_gather_group=app_state.create_all_gather_group,
order="tp-cp-ep-pp-dp" if app_state.use_tp_pp_dp_mapping else "tp-cp-ep-dp-pp",
num_distributed_optimizer_instances=app_state.num_distributed_optimizer_instances,
nccl_communicator_config_path=app_state.nccl_communicator_config_path,
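Note that the flag is read with `getattr(parallel_config, "create_all_gather_group", False)`, so a config object that predates the new field quietly falls back to the default. A minimal sketch of that fallback behaviour, using hypothetical stand-in config classes rather than the real ParallelismConfig:

from dataclasses import dataclass


# Hypothetical stand-in for an older parallel config that predates the new field.
@dataclass
class LegacyParallelConfig:
    use_sharp: bool = False


# Hypothetical newer config that carries the field explicitly.
@dataclass
class NewParallelConfig(LegacyParallelConfig):
    create_all_gather_group: bool = True


# Mirrors the lookup in init_parallel_ranks: a missing attribute falls back to
# False, while an explicit value is passed through unchanged.
assert getattr(LegacyParallelConfig(), "create_all_gather_group", False) is False
assert getattr(NewParallelConfig(), "create_all_gather_group", False) is True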
2 changes: 2 additions & 0 deletions nemo/lightning/megatron_init.py
@@ -101,6 +101,7 @@ def initialize_model_parallel_for_nemo(
num_distributed_optimizer_instances=1,
nccl_communicator_config_path=None,
use_sharp=False,
create_all_gather_group=False,
use_gloo_process_groups: bool = True,
):
"""Initialize model parallel groups in NeMo."""
@@ -130,6 +131,7 @@ def initialize_model_parallel_for_nemo(
app_state.pipeline_model_parallel_comm_backend = pipeline_model_parallel_comm_backend
app_state.use_fp8 = use_fp8
app_state.use_sharp = use_sharp
app_state.create_all_gather_group = create_all_gather_group
app_state.init_mpi_proc_group = init_mpi_proc_group
app_state.expert_tensor_parallel_size = expert_tensor_parallel_size
app_state.num_distributed_optimizer_instances = num_distributed_optimizer_instances
6 changes: 6 additions & 0 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -144,6 +144,7 @@ class ParallelismConfig:
num_distributed_optimizer_instances: int = 1
nccl_communicator_config_path: str = None
use_sharp: bool = False
create_all_gather_group: bool = False
pipeline_model_parallel_layout: Optional[Union[str, List[List[str]]]] = None
use_gloo_process_groups: bool = True

@@ -241,6 +242,8 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
nccl_communicator_config_path (Optional[str]): Path to the yaml file of NCCL communicator configurations.
`min_ctas`, `max_ctas`, and `cga_cluster_size` can be set for each communicator.
use_sharp (bool): Whether to use SHARP. Defaults to False.
create_all_gather_group (bool): Whether to create a separate process group for all-gather operations
so that reduce-scatter and all-gather can overlap. Defaults to False.
pipeline_model_parallel_layout (Optional[Union[str, List[List[str]]]]): The layout of all layers among
different PP and VP stages.
use_gloo_process_groups (bool): Whether to use Gloo process groups. Defaults to True.
@@ -285,6 +288,7 @@ def __init__(
pipeline_dtype: Optional[torch.dtype] = None,
use_te_rng_tracker: bool = False,
use_sharp: bool = False,
create_all_gather_group: bool = False,
save_ckpt_format: str = "torch_dist",
ckpt_async_save: bool = True,
ckpt_torch_dist_multiproc: int = None, ## TODO(ashors): put elsewhere?
@@ -349,6 +353,7 @@ def __init__(
self.distrib_optim_fully_reshardable_mem_efficient = distrib_optim_fully_reshardable_mem_efficient
self.use_te_rng_tracker = use_te_rng_tracker
self.use_sharp = use_sharp
self.create_all_gather_group = create_all_gather_group
self._pipeline_dtype = pipeline_dtype
self._setup_optimizers = setup_optimizers
self._init_model_parallel = init_model_parallel
@@ -1407,6 +1412,7 @@ def parallelism(self) -> ParallelismConfig:
num_distributed_optimizer_instances=self.num_distributed_optimizer_instances,
nccl_communicator_config_path=self.nccl_communicator_config_path,
use_sharp=self.use_sharp,
create_all_gather_group=self.create_all_gather_group,
pipeline_model_parallel_layout=self.pipeline_model_parallel_layout,
use_gloo_process_groups=self.use_gloo_process_groups,
)
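For reference, a minimal usage sketch of the new constructor argument (illustrative only; it mirrors the unit tests added below, which construct the strategy without a distributed environment):

from nemo.lightning.pytorch.strategies import MegatronStrategy

# Enable the separate all-gather process group when building the strategy.
strategy = MegatronStrategy(create_all_gather_group=True)
assert strategy.create_all_gather_group is True

# The flag defaults to False and is forwarded through the `parallelism` property
# (see the ParallelismConfig hunk above) into init_parallel_ranks at setup time.
assert MegatronStrategy().create_all_gather_group is False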
17 changes: 17 additions & 0 deletions nemo/utils/app_state.py
@@ -76,6 +76,7 @@ def __init__(self):
self._init_mpi_proc_gruop = False
self._nccl_communicator_config_path = None
self._use_sharp = False
self._create_all_gather_group = False
self._use_gloo_process_groups = True

self._random_seed = None
@@ -590,6 +591,22 @@ def use_sharp(self, use_sharp):
"""
self._use_sharp = use_sharp

@property
def create_all_gather_group(self):
"""Property returns whether to create a separate all-gather process group.
Returns:
Whether to create a separate all-gather process group.
"""
return self._create_all_gather_group

@create_all_gather_group.setter
def create_all_gather_group(self, create_all_gather_group):
"""Property sets whether to create a separate all-gather process group.
Args:
create_all_gather_group (bool): Whether to create a separate all-gather process group.
"""
self._create_all_gather_group = create_all_gather_group

@property
def use_gloo_process_groups(self):
"""Property returns whether to use Gloo process groups.
46 changes: 45 additions & 1 deletion tests/lightning/pytorch/strategies/test_megatron_strategy.py
@@ -15,7 +15,6 @@
from unittest.mock import MagicMock, patch

import pytest

from megatron.core.distributed import DistributedDataParallelConfig

from nemo.lightning.pytorch.strategies import MegatronStrategy
@@ -139,3 +138,48 @@ def test_update_step_kwargs(self):

with pytest.raises(AttributeError):
strategy._update_step_kwargs(1, kwargs={"data_step": None, "forward_step": None}, step_name="first")

def test_create_all_gather_group_default(self):
"""Test that create_all_gather_group defaults to False."""
strategy = MegatronStrategy()
assert strategy.create_all_gather_group == False

def test_create_all_gather_group_enabled(self):
"""Test that create_all_gather_group can be set to True."""
strategy = MegatronStrategy(create_all_gather_group=True)
assert strategy.create_all_gather_group == True

def test_create_all_gather_group_in_parallelism_config(self):
"""Test that create_all_gather_group can be configured via ParallelismConfig."""
import torch

from nemo.lightning.pytorch.strategies.megatron_strategy import ParallelismConfig

parallel_config = ParallelismConfig(
tensor_model_parallel_size=2,
pipeline_model_parallel_size=2,
virtual_pipeline_model_parallel_size=None,
microbatch_group_size_per_vp_stage=1,
context_parallel_size=1,
sequence_parallel=False,
expert_model_parallel_size=1,
moe_extended_tp=False,
pipeline_dtype=torch.float32,
create_all_gather_group=True,
)

assert parallel_config.create_all_gather_group == True

# Test default value
parallel_config_default = ParallelismConfig(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
virtual_pipeline_model_parallel_size=None,
microbatch_group_size_per_vp_stage=1,
context_parallel_size=1,
sequence_parallel=False,
expert_model_parallel_size=1,
moe_extended_tp=False,
pipeline_dtype=torch.float32,
)
assert parallel_config_default.create_all_gather_group == False
60 changes: 60 additions & 0 deletions tests/lightning/test_strategy_lib.py
@@ -201,6 +201,7 @@ def test_init_model_parallel(mock_mpu, *args):
expert_model_parallel_size=2,
expert_tensor_parallel_size=1,
use_sharp=False,
create_all_gather_group=False,
order="tp-cp-ep-dp-pp",
num_distributed_optimizer_instances=1,
nccl_communicator_config_path=None,
@@ -240,6 +241,7 @@ def test_init_model_parallel_with_tp_pp_dp(mock_mpu, *args):
expert_model_parallel_size=2,
expert_tensor_parallel_size=1,
use_sharp=False,
create_all_gather_group=False,
order="tp-cp-ep-pp-dp",
num_distributed_optimizer_instances=1,
nccl_communicator_config_path=None,
@@ -280,6 +282,64 @@ def test_grad_scaler(mock_mpu, *args):
pass


@patch('torch.distributed.is_initialized', return_value=True)
@patch('megatron.core.parallel_state')
def test_init_model_parallel_with_all_gather_group(mock_mpu, *args):
"""Test that create_all_gather_group parameter is properly passed to initialize_model_parallel."""
from nemo.utils import AppState

app_state = AppState()
app_state.model_parallel_size = 1
app_state.tensor_model_parallel_size = 2
app_state.pipeline_model_parallel_size = 1
app_state.pipeline_model_parallel_comm_backend = None
app_state.context_parallel_size = 2
app_state.expert_model_parallel_size = 2
app_state.expert_tensor_parallel_size = 1
app_state.expert_tensor_parallel_rank = 0
app_state.init_mpi_proc_group = False
app_state.tensor_model_parallel_rank = 2
app_state.pipeline_model_parallel_rank = 0
app_state.create_all_gather_group = True

_mpu_tp_2(mock_mpu)
_strategy_lib.init_model_parallel(nn.Identity())

mock_mpu.initialize_model_parallel.assert_called_once_with(
tensor_model_parallel_size=2,
pipeline_model_parallel_size=1,
virtual_pipeline_model_parallel_size=None,
pipeline_model_parallel_comm_backend=None,
context_parallel_size=2,
expert_model_parallel_size=2,
expert_tensor_parallel_size=1,
use_sharp=False,
create_all_gather_group=True,
order="tp-cp-ep-dp-pp",
num_distributed_optimizer_instances=1,
nccl_communicator_config_path=None,
create_gloo_process_groups=True,
)


def test_app_state_create_all_gather_group() -> None:
"""Test that AppState properly stores and retrieves create_all_gather_group value."""
from nemo.utils import AppState

app_state = AppState()

# Test default value
assert app_state.create_all_gather_group == False

# Test setter
app_state.create_all_gather_group = True
assert app_state.create_all_gather_group == True

# Test setter with False
app_state.create_all_gather_group = False
assert app_state.create_all_gather_group == False


# TODO @chcui uncomment after fabric API is merged
# @patch('nemo.lightning._strategy_lib.DataLoader', return_value=MagicMock())
# @patch('megatron.core.parallel_state')