Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion tensorrt_llm/_torch/distributed/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
mpi_disabled, mpi_isend, mpi_isend_object,
mpi_recv, mpi_recv_object, mpi_send,
mpi_send_object, torch_pybind11_abi)
mpi_send_object, mpi_world_size,
torch_pybind11_abi)
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
from tensorrt_llm.bindings.internal.process_group import init_pg
from tensorrt_llm.logger import logger
Expand Down Expand Up @@ -456,6 +457,19 @@ def __init__(self, mapping: Mapping):
self._tp_comm = None
self._pp_comm = None

def _validate_world_size(self):
    """Check that MPI provides enough ranks for the configured mapping.

    Called before sub-communicators are created; the intent is to fail
    with a readable error rather than crash later (the original author
    cites segfaults as the failure mode being prevented).

    Does nothing when ``ENABLE_MULTI_DEVICE`` is falsy.

    Raises:
        RuntimeError: If ``self.mapping.world_size`` is greater than the
            value returned by ``mpi_world_size()``.
    """
    # Nothing to validate without multi-device support.
    if not ENABLE_MULTI_DEVICE:
        return

    available_world_size = mpi_world_size()
    required_world_size = self.mapping.world_size

    if required_world_size > available_world_size:
        parallel_layout = (
            f"(tp_size={self.mapping.tp_size} * pp_size={self.mapping.pp_size} * cp_size={self.mapping.cp_size}), "
        )
        raise RuntimeError(
            f"Mapping requires world_size={required_world_size} "
            + parallel_layout
            + f"but MPI world size is only {available_world_size}. ")

def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
    """Broadcast ``obj`` over the communicator returned by ``mpi_comm()``.

    Thin wrapper that delegates to ``safe_broadcast``.

    Args:
        obj: Object to broadcast; passed through unchanged.
        root: Rank forwarded as the broadcast root (default ``0``).
        chunk_size: Forwarded to ``safe_broadcast``; defaults to
            ``4 * 1024 * 1024`` (presumably a size in bytes -- confirm
            against ``safe_broadcast``).

    Returns:
        Whatever ``safe_broadcast`` returns.
    """
    world_comm = mpi_comm()
    return safe_broadcast(
        world_comm,
        obj,
        root=root,
        chunk_size=chunk_size,
    )
Expand Down Expand Up @@ -493,6 +507,7 @@ def recv_object(self, src, tag=0):
@property
def tp_comm(self):
if self._tp_comm is None:
self._validate_world_size()
mapping = self.mapping
new_group = mpi_comm().group.Incl(mapping.tp_group)
self._tp_comm = mpi_comm().Create_group(new_group)
Expand All @@ -501,6 +516,7 @@ def tp_comm(self):
@property
def pp_comm(self):
if self._pp_comm is None:
self._validate_world_size()
mapping = self.mapping
new_group = mpi_comm().group.Incl(mapping.pp_group)
self._pp_comm = mpi_comm().Create_group(new_group)
Expand All @@ -509,6 +525,7 @@ def pp_comm(self):
@property
def cp_comm(self):
    """Return the CP sub-communicator, creating and caching it on first use.

    On first access the world size is validated, a group is built from
    the ranks in ``self.mapping.cp_group``, and a communicator is created
    from that group. The result is stored in ``self._cp_comm`` and reused
    on later accesses.
    """
    # Fast path: the communicator was already built by an earlier access.
    if self._cp_comm is not None:
        return self._cp_comm

    # Validate before touching MPI groups so a mismatched mapping raises
    # a RuntimeError instead of failing inside group creation.
    self._validate_world_size()
    cp_ranks_group = mpi_comm().group.Incl(self.mapping.cp_group)
    self._cp_comm = mpi_comm().Create_group(cp_ranks_group)
    return self._cp_comm
Expand Down