1+ from functools import reduce
2+ from operator import mul
13from typing import cast
24
35import torch
911def foreach_all_gather (
1012 params : list [torch .Tensor ],
1113 group : dist .ProcessGroup | None ,
14+ params_shapes_across_group : list [list [tuple [int , ...]]] | None = None ,
1215) -> list [list [torch .Tensor ]]:
1316 if group is None :
1417 group = dist .group .WORLD
@@ -19,16 +22,23 @@ def foreach_all_gather(
1922 input_tensor_numels = [param .numel () for param in params ]
2023 input_tensor_shapes = [param .shape for param in params ]
2124
25+ global_input_tensor_numels : list [torch .Tensor ]
26+ if params_shapes_across_group is None :
27+ input_tensor_numels_tensor = torch .tensor (input_tensor_numels , dtype = torch .int64 , device = param0 .device )
28+ global_input_tensor_numels = [
29+ torch .zeros_like (input_tensor_numels_tensor ) for _ in range (dist .get_world_size (group ))
30+ ]
31+ dist .all_gather (global_input_tensor_numels , input_tensor_numels_tensor , group = group )
32+ else :
33+ global_input_tensor_numels = [
34+ torch .tensor ([reduce (mul , shape ) for shape in param_shapes ], dtype = torch .int64 , device = "cpu" )
35+ for param_shapes in params_shapes_across_group # each param_shapes represents all params shapes on one rank
36+ ]
37+
2238 flatten_copyin_tensor = torch .empty ((sum (input_tensor_numels ),), dtype = param0 .dtype , device = param0 .device )
2339 splits_copyin_tensor = torch .split (flatten_copyin_tensor , input_tensor_numels )
2440 torch ._foreach_copy_ (splits_copyin_tensor , [p .flatten () for p in params ])
2541
26- input_tensor_numels_tensor = torch .tensor (input_tensor_numels , dtype = torch .int64 , device = param0 .device )
27- global_input_tensor_numels = [
28- torch .zeros_like (input_tensor_numels_tensor ) for _ in range (dist .get_world_size (group ))
29- ]
30-
31- dist .all_gather (global_input_tensor_numels , input_tensor_numels_tensor , group = group )
3242 copyout_size = int (sum (sum (i ) for i in global_input_tensor_numels ))
3343 flatten_copyout_tensor = torch .empty ((copyout_size ,), dtype = param0 .dtype , device = param0 .device )
3444
0 commit comments