[Feature] Optimizing the communication of FSDP through NVLink SHARP and IBGDA #1415

HIT-cwh wants to merge 2 commits into InternLM:main
Conversation
Pull request overview
This PR introduces an optimization feature for FSDP (Fully Sharded Data Parallel) training by integrating custom communication operations through NVLink SHARP and IBGDA. The implementation provides optimized all-gather and reduce-scatter collectives that can be enabled via environment variables.
Key Changes
- Added custom communication library integration with ib_wrapper for optimized FSDP operations
- Implemented buffer management system with n-buffering support for concurrent operations
- Created manager classes to handle all-gather and reduce-scatter operations with double buffering
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 28 comments.
| File | Description |
|---|---|
| xtuner/v1/train/trainer.py | Adds import and initialization of custom communication library with buffer setup |
| xtuner/v1/patch/torch_fsdp_comm.py | New file implementing custom FSDP collectives with buffer managers, symmetric memory allocation, and PyTorch function patching |
| xtuner/v1/patch/__init__.py | Exports the new patch function for FSDP communication optimization |
```python
raise ImportError("XTUNER_USE_CUSTOM_{AG,RS}_IN_FSDP is set but ib_wrapper is not available.")

if not (torch.__version__.startswith("2.6") or torch.__version__.startswith("2.8")):
    raise ImportError("XTUNER_USE_CUSTOM_{AG,RS}_IN_FSDP is only supported in PyTorch 2.6 and 2.8.")
```

The error messages use shell-style brace expansion syntax ('{AG,RS}'), which may be unclear to users. Consider a more explicit message like 'XTUNER_USE_CUSTOM_AG_IN_FSDP or XTUNER_USE_CUSTOM_RS_IN_FSDP is only supported in PyTorch 2.6 and 2.8.' to improve clarity.

Suggested change:

```python
raise ImportError(
    "XTUNER_USE_CUSTOM_AG_IN_FSDP or XTUNER_USE_CUSTOM_RS_IN_FSDP is set but ib_wrapper is not available."
)
if not (torch.__version__.startswith("2.6") or torch.__version__.startswith("2.8")):
    raise ImportError(
        "XTUNER_USE_CUSTOM_AG_IN_FSDP or XTUNER_USE_CUSTOM_RS_IN_FSDP is only supported in PyTorch 2.6 and 2.8."
    )
```
```python
    return torch.empty((size,), dtype=dtype, device=device)


lib = torch.library.Library("fsdp", "FRAGMENT")  # noqa: TOR901
```

The torch.library.Library call uses a deprecated API pattern. The comment 'noqa: TOR901' suggests this is a known issue, but the 'FRAGMENT' argument may not be the correct descriptor for the library definition. Verify that this is the intended usage, or consider updating to the current PyTorch library API if available.

Suggested change:

```python
lib = torch.library.Library("fsdp", "DEF")
```
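For reference, here is a minimal sketch of the two registration modes; the namespace and ops below are made up for illustration and are not part of the PR. "DEF" creates and owns a new operator namespace, while "FRAGMENT" adds definitions to a namespace that may already be registered elsewhere.

```python
# Illustrative only: the namespace "my_custom_ns" and its ops are hypothetical
# and unrelated to the PR.
import torch

# "DEF" creates a brand-new operator namespace and owns its schema definitions.
lib_def = torch.library.Library("my_custom_ns", "DEF")  # noqa: TOR901
lib_def.define("noop(Tensor x) -> Tensor")
lib_def.impl("noop", lambda x: x.clone(), "CompositeExplicitAutograd")

# "FRAGMENT" attaches additional definitions or implementations to a namespace
# that already exists instead of claiming ownership of it.
lib_frag = torch.library.Library("my_custom_ns", "FRAGMENT")  # noqa: TOR901
lib_frag.define("noop_twice(Tensor x) -> Tensor")
lib_frag.impl("noop_twice", lambda x: x.clone(), "CompositeExplicitAutograd")

print(torch.ops.my_custom_ns.noop(torch.ones(2)))
```

If the `fsdp` namespace is already registered by PyTorch itself in the targeted versions, constructing it again with "DEF" would raise at import time, so the suggested change is worth verifying before applying.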
```python
    group=reduce_scatter_group,
    from_process_group=allocate_memory_from_process_group,
)
# reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,))
```

The commented-out code should be removed rather than left in the codebase. This improves maintainability and reduces confusion about which implementation is active.

Suggested change: delete the commented-out line.
```python
# self.rdc_scale: dict[int, torch.Tensor] = {}
self.copy_event_prev: torch.Event | None = None
self.copy_event: torch.Event | None = None
self.select_sm = int(os.getenv("SELECT_COMM_SM_IN_FSDP", 0))
```

The variable name 'select_sm' is ambiguous and unclear. Consider renaming it to something more descriptive like 'select_streaming_multiprocessor' or 'use_sm_selection' to improve code readability.

Suggested change:

```python
self.select_streaming_multiprocessor = int(os.getenv("SELECT_COMM_SM_IN_FSDP", 0))
```
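The snippet above also shows the copy_event / copy_event_prev pair. As a point of reference, here is a minimal, hypothetical sketch of how such CUDA events are commonly used to pipeline staging copies on a side stream; the PR's manager classes may wire this differently.

```python
# Hypothetical sketch of event-based copy pipelining; not the PR's implementation.
import torch

copy_stream = torch.cuda.Stream()

def staged_copy(
    dst: torch.Tensor,
    src: torch.Tensor,
    prev_event: torch.cuda.Event | None,
) -> torch.cuda.Event:
    """Copy src into dst on a dedicated stream, ordered after prev_event."""
    done = torch.cuda.Event()
    with torch.cuda.stream(copy_stream):
        if prev_event is not None:
            # Do not start this copy until the previous user of the buffer is done.
            copy_stream.wait_event(prev_event)
        dst.copy_(src, non_blocking=True)
        done.record(copy_stream)
    return done
```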
```python
if (USE_CUSTOM_AG or USE_CUSTOM_RS) and world_size == dist.get_world_size():
    recv_bytes = all_gather_input_numel * world_size * all_gather_inputs[0].element_size()
    send_bytes = recv_bytes // world_size
    recv_bytes_aligned = (send_bytes + 127) // 128 * 128 * world_size
```

Magic number 128 is used for alignment calculations without explanation. Consider defining it as a named constant (e.g., MEMORY_ALIGNMENT = 128) at the module level to improve code maintainability and make the purpose clear.
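As a hedged illustration of the named-constant refactor the comment asks for (the constant and helper names are placeholders, not from the PR):

```python
# Illustrative refactor only; the constant name and helper are hypothetical.
MEMORY_ALIGNMENT = 128  # bytes; alignment granularity applied to the send buffers

def align_up(nbytes: int, alignment: int = MEMORY_ALIGNMENT) -> int:
    """Round nbytes up to the next multiple of alignment."""
    return (nbytes + alignment - 1) // alignment * alignment

# Equivalent to the original expression:
# recv_bytes_aligned = (send_bytes + 127) // 128 * 128 * world_size
# recv_bytes_aligned = align_up(send_bytes) * world_size
```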
| """Initialize the symmetric buffer manager with n buffering in | ||
| contiguous memory. | ||
|
|
||
| Args: | ||
| default_size (int): Default buffer size in bytes | ||
| alignment (int): Memory alignment requirement for the buffer | ||
| num_buffers (int): Number of buffers for n-buffering |
There was a problem hiding this comment.
The documentation states "Implements n buffering for concurrent operations with contiguous memory" but the actual parameter is named 'num_buffers' with a default of 3. However, the class is initialized with NUM_AG_BUFFERS (2 for AG) and NUM_RS_BUFFERS (1 for RS), which are not 3. The default value in the docstring should match the actual usage or be removed to avoid confusion.
| """Initialize the symmetric buffer manager with n buffering in | |
| contiguous memory. | |
| Args: | |
| default_size (int): Default buffer size in bytes | |
| alignment (int): Memory alignment requirement for the buffer | |
| num_buffers (int): Number of buffers for n-buffering | |
| """Initialize the symmetric buffer manager with n-buffering in | |
| contiguous memory. | |
| Args: | |
| default_size (int): Default buffer size in bytes. | |
| alignment (int): Memory alignment requirement for the buffer. | |
| num_buffers (int): Number of buffers for n-buffering. The actual | |
| value is provided by callers (for example, ``NUM_AG_BUFFERS`` | |
| or ``NUM_RS_BUFFERS``) and may vary depending on usage. |
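For readers unfamiliar with the pattern, here is a minimal sketch of the round-robin buffer reuse that n-buffering refers to. The class and method names are hypothetical; the PR's buffer manager additionally deals with symmetric memory and alignment.

```python
# Minimal n-buffering sketch; names are illustrative and not from the PR.
import torch

class NBufferPool:
    def __init__(self, num_buffers: int, numel: int, dtype: torch.dtype, device: str = "cuda"):
        # One flat allocation per buffer; a buffer is only reused after the
        # other num_buffers - 1 in-flight operations have cycled through.
        self.buffers = [torch.empty(numel, dtype=dtype, device=device) for _ in range(num_buffers)]
        self.index = 0

    def next_buffer(self) -> torch.Tensor:
        buf = self.buffers[self.index]
        self.index = (self.index + 1) % len(self.buffers)
        return buf

# e.g. two buffers for all-gather (NUM_AG_BUFFERS) and one for reduce-scatter
# (NUM_RS_BUFFERS), matching the counts mentioned in the comment above.
```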
```python
)


def patch_fsdp_agrs() -> None:
```

The function name 'patch_fsdp_agrs' contains a typo. It should be 'patch_fsdp_args' (with 'args' instead of 'agrs').

Suggested change:

```python
def patch_fsdp_args() -> None:
```
```python
patch_fsdp_agrs()
```

The patch_fsdp_agrs function is called unconditionally during Trainer initialization, but it should only be called when custom communication is actually enabled. Currently, the function checks the environment variables internally, but calling it unconditionally may cause unnecessary module imports and overhead. Consider moving the call inside the conditional block that checks XTUNER_ENABLE_CUSTOM_COMMUNICATION at line 638.

Suggested change:

```python
if os.getenv("XTUNER_ENABLE_CUSTOM_COMMUNICATION"):
    patch_fsdp_agrs()
```
```python
    reduce_scatter_input_aligned = reduce_scatter_input
else:
    reduce_scatter_input = torch.empty((reduce_scatter_input_numel,), dtype=reduce_dtype, device=device)
    reduce_scatter_input_aligned = reduce_scatter_input
```

Variable reduce_scatter_input_aligned is not used.

Suggested change (removing both assignments to reduce_scatter_input_aligned):

```python
else:
    reduce_scatter_input = torch.empty((reduce_scatter_input_numel,), dtype=reduce_dtype, device=device)
```
```python
# reduce_scatter_input = torch.empty(
#     (reduce_scatter_input_numel,), dtype=reduce_dtype, device=device
# )
reduce_scatter_input_aligned = reduce_scatter_input
```

Variable reduce_scatter_input_aligned is not used.
Step 1: Install nvshmem

One can install nvshmem==3.4.5 by following the linked instructions.
Step 2: Install the Optimized Communication Operator Library
Step 3: Train with Optimized FSDP All-Gather and Reduce-Scatter
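The PR does not spell out the launch recipe here, so the following is only a hedged sketch: the environment-variable names come from the snippets reviewed above, but the values, the import path, and the assumption that setting the variables before constructing the Trainer is sufficient (the Trainer calls the patch during initialization) are unverified.

```python
# Hypothetical usage sketch; variable values and import paths are assumptions.
import os

# Opt in to the custom all-gather / reduce-scatter collectives.
os.environ["XTUNER_USE_CUSTOM_AG_IN_FSDP"] = "1"
os.environ["XTUNER_USE_CUSTOM_RS_IN_FSDP"] = "1"
# Optionally bound the number of SMs the communication kernels may use
# (the value 16 is only an example).
os.environ["SELECT_COMM_SM_IN_FSDP"] = "16"

# The Trainer patches the FSDP collectives during initialization, so setting the
# variables before constructing it should be enough in normal use.
# from xtuner.v1.train.trainer import Trainer  # module path inferred from the file list above
# trainer = Trainer(...)                        # build and run training as usual
```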
The timeline will look like this (see the profiler screenshot in the PR):