Commit b9744e6

[Refactor] add params_shapes_across_group param in foreach_all_gather
1 parent 3c4543a commit b9744e6

File tree

xtuner/v1/model/base.py
xtuner/v1/ops/comm/foreach_allgather.py

2 files changed: +21 -7 lines changed

xtuner/v1/model/base.py

Lines changed: 5 additions & 1 deletion
@@ -1217,7 +1217,11 @@ def _fsdp_foreach_allgather(
             else:
                 origin_fsdp_size.append(load_spec.shape[self.FSDP_SHARD_DIM])

-        _fsdp_unsharded_tensor_list = foreach_all_gather(padded_tensor_list, self.fsdp_mesh.get_group())
+        _fsdp_unsharded_tensor_list = foreach_all_gather(
+            padded_tensor_list,
+            self.fsdp_mesh.get_group(),
+            [[tuple(t.size()) for t in padded_tensor_list]] * self.fsdp_mesh.size(),
+        )
         fsdp_unsharded_tensor_list = []

         # Concatenate the tensors along the FSDP shard dim
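The new third argument replicates the caller's local shard shapes once per rank in the FSDP mesh, which is only safe because FSDP pads every rank's shards to identical shapes. A minimal sketch of that construction, using stand-in tensors and a stand-in mesh size rather than the real `fsdp_mesh`:

```python
# Sketch with stand-in values: `padded_tensor_list` and `fsdp_world_size`
# below are placeholders for the real padded shards and mesh size.
from functools import reduce
from operator import mul

import torch

padded_tensor_list = [torch.empty(4, 8), torch.empty(2, 8)]  # placeholder shards
fsdp_world_size = 4                                          # placeholder mesh size

# Every rank reuses its own shape list, one copy per rank in the group.
params_shapes_across_group = [
    [tuple(t.size()) for t in padded_tensor_list]
] * fsdp_world_size

# The numels derived from these shapes match what the removed
# dist.all_gather of numels would have exchanged.
numels_per_rank = [
    [reduce(mul, shape) for shape in rank_shapes]
    for rank_shapes in params_shapes_across_group
]
assert all(row == [32, 16] for row in numels_per_rank)
```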

xtuner/v1/ops/comm/foreach_allgather.py

Lines changed: 16 additions & 6 deletions
@@ -1,3 +1,5 @@
+from functools import reduce
+from operator import mul
 from typing import cast

 import torch
@@ -9,6 +11,7 @@
 def foreach_all_gather(
     params: list[torch.Tensor],
     group: dist.ProcessGroup | None,
+    params_shapes_across_group: list[list[tuple[int, ...]]] | None = None,
 ) -> list[list[torch.Tensor]]:
     if group is None:
         group = dist.group.WORLD
@@ -19,16 +22,23 @@ def foreach_all_gather(
     input_tensor_numels = [param.numel() for param in params]
     input_tensor_shapes = [param.shape for param in params]

+    global_input_tensor_numels: list[torch.Tensor]
+    if params_shapes_across_group is None:
+        input_tensor_numels_tensor = torch.tensor(input_tensor_numels, dtype=torch.int64, device=param0.device)
+        global_input_tensor_numels = [
+            torch.zeros_like(input_tensor_numels_tensor) for _ in range(dist.get_world_size(group))
+        ]
+        dist.all_gather(global_input_tensor_numels, input_tensor_numels_tensor, group=group)
+    else:
+        global_input_tensor_numels = [
+            torch.tensor([reduce(mul, shape) for shape in param_shapes], dtype=torch.int64, device="cpu")
+            for param_shapes in params_shapes_across_group  # each param_shapes represents all param shapes on one rank
+        ]
+
     flatten_copyin_tensor = torch.empty((sum(input_tensor_numels),), dtype=param0.dtype, device=param0.device)
     splits_copyin_tensor = torch.split(flatten_copyin_tensor, input_tensor_numels)
     torch._foreach_copy_(splits_copyin_tensor, [p.flatten() for p in params])

-    input_tensor_numels_tensor = torch.tensor(input_tensor_numels, dtype=torch.int64, device=param0.device)
-    global_input_tensor_numels = [
-        torch.zeros_like(input_tensor_numels_tensor) for _ in range(dist.get_world_size(group))
-    ]
-
-    dist.all_gather(global_input_tensor_numels, input_tensor_numels_tensor, group=group)
     copyout_size = int(sum(sum(i) for i in global_input_tensor_numels))
     flatten_copyout_tensor = torch.empty((copyout_size,), dtype=param0.dtype, device=param0.device)
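A hedged usage sketch of the updated signature, exercised on a single-process gloo group. The import path is inferred from the file path above, and the process-group settings are illustrative only:

```python
import os

import torch
import torch.distributed as dist

from xtuner.v1.ops.comm.foreach_allgather import foreach_all_gather  # path inferred from the diff

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

params = [torch.randn(4, 8), torch.randn(2, 8)]
group = dist.group.WORLD

# Previous behaviour: per-rank numels are discovered with an extra all_gather.
gathered = foreach_all_gather(params, group)

# New behaviour: shapes for every rank are supplied up front, so that
# bookkeeping collective is skipped entirely.
shapes_per_rank = [[tuple(p.size()) for p in params]] * dist.get_world_size(group)
gathered_precomputed = foreach_all_gather(params, group, shapes_per_rank)

dist.destroy_process_group()
```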

0 commit comments
