
Commit 92e40de

[Opt] len(params) == 1 case in foreach_all_gather
1 parent b9744e6 commit 92e40de

File tree: 1 file changed, +39 -25 lines


xtuner/v1/ops/comm/foreach_allgather.py

Lines changed: 39 additions & 25 deletions
@@ -31,33 +31,47 @@ def foreach_all_gather(
         dist.all_gather(global_input_tensor_numels, input_tensor_numels_tensor, group=group)
     else:
         global_input_tensor_numels = [
-            torch.tensor([reduce(mul, shape) for shape in param_shapes], dtype=torch.int64, device="cpu")
+            torch.tensor([reduce(mul, shape, 1) for shape in param_shapes], dtype=torch.int64, device="cpu")
             for param_shapes in params_shapes_across_group  # each param_shapes represents all params shapes on one rank
         ]
 
-    flatten_copyin_tensor = torch.empty((sum(input_tensor_numels),), dtype=param0.dtype, device=param0.device)
-    splits_copyin_tensor = torch.split(flatten_copyin_tensor, input_tensor_numels)
-    torch._foreach_copy_(splits_copyin_tensor, [p.flatten() for p in params])
-
-    copyout_size = int(sum(sum(i) for i in global_input_tensor_numels))
-    flatten_copyout_tensor = torch.empty((copyout_size,), dtype=param0.dtype, device=param0.device)
-
-    dist.all_gather_into_tensor(flatten_copyout_tensor, flatten_copyin_tensor, group=group)
-    copyout_split_size: list[int] = sum([i.tolist() for i in global_input_tensor_numels], [])
-    splits_copyout_tensor = torch.split(flatten_copyout_tensor, copyout_split_size)
-
-    _global_input_tensor_shapes: list[None] | list[list[tuple]] = [None for _ in range(dist.get_world_size(group))]
-    dist.all_gather_object(_global_input_tensor_shapes, input_tensor_shapes, group=group)
-    _global_input_tensor_shapes = cast(list[list[tuple]], _global_input_tensor_shapes)
-    global_input_tensor_shapes: list[tuple] = sum(_global_input_tensor_shapes, [])
-
-    gathered_params: list[list[torch.Tensor]] = []
-    for i in range(len(params)):
-        single_gathered_params: list[torch.Tensor] = []
-        for rank in range(dist.get_world_size(group)):
-            offset = len(params) * rank
-            origin_shape: tuple = global_input_tensor_shapes[offset + i]
-            single_gathered_params.append(splits_copyout_tensor[offset + i].view(origin_shape))
-        gathered_params.append(single_gathered_params)
+    if len(params) == 1:
+        param0_shape_except_dim0 = list(param0.shape)[1:]
+        param0_numel_except_dim0 = param0[0].numel()
+        # Calculate the size of dimension 0 of the gathered tensor; this also handles uneven splits across ranks
+        split_dim0_sizes = [t.tolist()[0] // param0_numel_except_dim0 for t in global_input_tensor_numels]
+        gathered_tensor_dim0_size = sum(split_dim0_sizes)
+
+        # all_gather_into_tensor gathers each rank's data along dimension 0
+        gathered_tensor = torch.empty(
+            (gathered_tensor_dim0_size, *param0_shape_except_dim0), dtype=param0.dtype, device=param0.device
+        )
+        dist.all_gather_into_tensor(gathered_tensor, param0, group=group)
+        return [gathered_tensor.split(split_dim0_sizes, dim=0)]
+    else:
+        flatten_copyin_tensor = torch.empty((sum(input_tensor_numels),), dtype=param0.dtype, device=param0.device)
+        splits_copyin_tensor = torch.split(flatten_copyin_tensor, input_tensor_numels)
+        torch._foreach_copy_(splits_copyin_tensor, [p.flatten() for p in params])
+
+        copyout_size = int(sum(sum(i) for i in global_input_tensor_numels))
+        flatten_copyout_tensor = torch.empty((copyout_size,), dtype=param0.dtype, device=param0.device)
+
+        dist.all_gather_into_tensor(flatten_copyout_tensor, flatten_copyin_tensor, group=group)
+        copyout_split_size: list[int] = sum([i.tolist() for i in global_input_tensor_numels], [])
+        splits_copyout_tensor = torch.split(flatten_copyout_tensor, copyout_split_size)
+
+        _global_input_tensor_shapes: list[None] | list[list[tuple]] = [None for _ in range(dist.get_world_size(group))]
+        dist.all_gather_object(_global_input_tensor_shapes, input_tensor_shapes, group=group)
+        _global_input_tensor_shapes = cast(list[list[tuple]], _global_input_tensor_shapes)
+        global_input_tensor_shapes: list[tuple] = sum(_global_input_tensor_shapes, [])
+
+        gathered_params: list[list[torch.Tensor]] = []
+        for i in range(len(params)):
+            single_gathered_params: list[torch.Tensor] = []
+            for rank in range(dist.get_world_size(group)):
+                offset = len(params) * rank
+                origin_shape: tuple = global_input_tensor_shapes[offset + i]
+                single_gathered_params.append(splits_copyout_tensor[offset + i].view(origin_shape))
+            gathered_params.append(single_gathered_params)
 
     return gathered_params
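
For reference, here is a minimal standalone sketch of the idea behind the fast path this commit adds: when only one tensor is being gathered, the flatten/copy-in staging buffer is unnecessary, because dist.all_gather_into_tensor already concatenates each rank's contribution along dimension 0 and torch.split can hand back per-rank views of the result without extra copies. This is not the xtuner API; the helper name single_param_all_gather and the equal-sized-shard assumption are illustrative only (the committed code also derives possibly uneven dim-0 sizes from the gathered per-rank numel metadata).

    import torch
    import torch.distributed as dist

    def single_param_all_gather(local_shard: torch.Tensor, group=None) -> list[torch.Tensor]:
        """Gather one tensor from every rank and return per-rank views, with no staging copy.

        Assumes an initialized process group and equal dim-0 shard sizes on all ranks.
        """
        world_size = dist.get_world_size(group)
        # Output buffer sized so that all ranks' shards stack along dimension 0.
        gathered = torch.empty(
            (local_shard.shape[0] * world_size, *local_shard.shape[1:]),
            dtype=local_shard.dtype,
            device=local_shard.device,
        )
        # all_gather_into_tensor writes rank r's shard into rows [r * n, (r + 1) * n).
        dist.all_gather_into_tensor(gathered, local_shard.contiguous(), group=group)
        # split returns views into `gathered`, one per rank, without copying.
        return list(gathered.split(local_shard.shape[0], dim=0))

Compared with the generic multi-parameter path above, this skips the torch._foreach_copy_ into a flattened copy-in buffer and the per-element split of a flattened copy-out buffer, which is where the savings of the len(params) == 1 branch come from.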
