Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion xtuner/v1/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,7 +1217,11 @@ def _fsdp_foreach_allgather(
else:
origin_fsdp_size.append(load_spec.shape[self.FSDP_SHARD_DIM])

_fsdp_unsharded_tensor_list = foreach_all_gather(padded_tensor_list, self.fsdp_mesh.get_group())
_fsdp_unsharded_tensor_list = foreach_all_gather(
padded_tensor_list,
self.fsdp_mesh.get_group(),
[[tuple(t.size()) for t in padded_tensor_list]] * self.fsdp_mesh.size(),
)
fsdp_unsharded_tensor_list = []

# Concatenate the tensors along the FSDP shard dim
Expand Down
84 changes: 54 additions & 30 deletions xtuner/v1/ops/comm/foreach_allgather.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import reduce
from operator import mul
from typing import cast

import torch
Expand All @@ -9,6 +11,7 @@
def foreach_all_gather(
params: list[torch.Tensor],
group: dist.ProcessGroup | None,
params_shapes_across_group: list[list[tuple[int, ...]]] | None = None,
) -> list[list[torch.Tensor]]:
if group is None:
group = dist.group.WORLD
Expand All @@ -19,35 +22,56 @@ def foreach_all_gather(
input_tensor_numels = [param.numel() for param in params]
input_tensor_shapes = [param.shape for param in params]

flatten_copyin_tensor = torch.empty((sum(input_tensor_numels),), dtype=param0.dtype, device=param0.device)
splits_copyin_tensor = torch.split(flatten_copyin_tensor, input_tensor_numels)
torch._foreach_copy_(splits_copyin_tensor, [p.flatten() for p in params])

input_tensor_numels_tensor = torch.tensor(input_tensor_numels, dtype=torch.int64, device=param0.device)
global_input_tensor_numels = [
torch.zeros_like(input_tensor_numels_tensor) for _ in range(dist.get_world_size(group))
]

dist.all_gather(global_input_tensor_numels, input_tensor_numels_tensor, group=group)
copyout_size = int(sum(sum(i) for i in global_input_tensor_numels))
flatten_copyout_tensor = torch.empty((copyout_size,), dtype=param0.dtype, device=param0.device)

dist.all_gather_into_tensor(flatten_copyout_tensor, flatten_copyin_tensor, group=group)
copyout_split_size: list[int] = sum([i.tolist() for i in global_input_tensor_numels], [])
splits_copyout_tensor = torch.split(flatten_copyout_tensor, copyout_split_size)

_global_input_tensor_shapes: list[None] | list[list[tuple]] = [None for _ in range(dist.get_world_size(group))]
dist.all_gather_object(_global_input_tensor_shapes, input_tensor_shapes, group=group)
_global_input_tensor_shapes = cast(list[list[tuple]], _global_input_tensor_shapes)
global_input_tensor_shapes: list[tuple] = sum(_global_input_tensor_shapes, [])

gathered_params: list[list[torch.Tensor]] = []
for i in range(len(params)):
single_gathered_params: list[torch.Tensor] = []
for rank in range(dist.get_world_size(group)):
offset = len(params) * rank
origin_shape: tuple = global_input_tensor_shapes[offset + i]
single_gathered_params.append(splits_copyout_tensor[offset + i].view(origin_shape))
gathered_params.append(single_gathered_params)
global_input_tensor_numels: list[torch.Tensor]
if params_shapes_across_group is None:
input_tensor_numels_tensor = torch.tensor(input_tensor_numels, dtype=torch.int64, device=param0.device)
global_input_tensor_numels = [
torch.zeros_like(input_tensor_numels_tensor) for _ in range(dist.get_world_size(group))
]
dist.all_gather(global_input_tensor_numels, input_tensor_numels_tensor, group=group)
else:
global_input_tensor_numels = [
torch.tensor([reduce(mul, shape, 1) for shape in param_shapes], dtype=torch.int64, device="cpu")
for param_shapes in params_shapes_across_group # each param_shapes represents all params shapes on one rank
]

if len(params) == 1:
param0_shape_except_dim0 = list(param0.shape)[1:]
param0_numel_except_dim0 = param0[0].numel()
# Calculate the size of dimension 0 of the gathered tensor, it's compatible for the case of uneven split
split_dim0_sizes = [t.tolist()[0] // param0_numel_except_dim0 for t in global_input_tensor_numels]
gathered_tensor_dim0_size = sum(split_dim0_sizes)

# all_gather_into_tensor gather different ranks data along dimension 0
gathered_tensor = torch.empty(
(gathered_tensor_dim0_size, *param0_shape_except_dim0), dtype=param0.dtype, device=param0.device
)
dist.all_gather_into_tensor(gathered_tensor, param0, group=group)
return [gathered_tensor.split(split_dim0_sizes, dim=0)]
else:
flatten_copyin_tensor = torch.empty((sum(input_tensor_numels),), dtype=param0.dtype, device=param0.device)
splits_copyin_tensor = torch.split(flatten_copyin_tensor, input_tensor_numels)
torch._foreach_copy_(splits_copyin_tensor, [p.flatten() for p in params])

copyout_size = int(sum(sum(i) for i in global_input_tensor_numels))
flatten_copyout_tensor = torch.empty((copyout_size,), dtype=param0.dtype, device=param0.device)

dist.all_gather_into_tensor(flatten_copyout_tensor, flatten_copyin_tensor, group=group)
copyout_split_size: list[int] = sum([i.tolist() for i in global_input_tensor_numels], [])
splits_copyout_tensor = torch.split(flatten_copyout_tensor, copyout_split_size)

_global_input_tensor_shapes: list[None] | list[list[tuple]] = [None for _ in range(dist.get_world_size(group))]
dist.all_gather_object(_global_input_tensor_shapes, input_tensor_shapes, group=group)
_global_input_tensor_shapes = cast(list[list[tuple]], _global_input_tensor_shapes)
global_input_tensor_shapes: list[tuple] = sum(_global_input_tensor_shapes, [])

gathered_params: list[list[torch.Tensor]] = []
for i in range(len(params)):
single_gathered_params: list[torch.Tensor] = []
for rank in range(dist.get_world_size(group)):
offset = len(params) * rank
origin_shape: tuple = global_input_tensor_shapes[offset + i]
single_gathered_params.append(splits_copyout_tensor[offset + i].view(origin_shape))
gathered_params.append(single_gathered_params)

return gathered_params
Loading