@@ -31,33 +31,47 @@ def foreach_all_gather(
         dist.all_gather(global_input_tensor_numels, input_tensor_numels_tensor, group=group)
     else:
         global_input_tensor_numels = [
-            torch.tensor([reduce(mul, shape) for shape in param_shapes], dtype=torch.int64, device="cpu")
+            torch.tensor([reduce(mul, shape, 1) for shape in param_shapes], dtype=torch.int64, device="cpu")
             for param_shapes in params_shapes_across_group  # each param_shapes holds the shapes of all params on one rank
         ]
 
-    flatten_copyin_tensor = torch.empty((sum(input_tensor_numels),), dtype=param0.dtype, device=param0.device)
-    splits_copyin_tensor = torch.split(flatten_copyin_tensor, input_tensor_numels)
-    torch._foreach_copy_(splits_copyin_tensor, [p.flatten() for p in params])
-
-    copyout_size = int(sum(sum(i) for i in global_input_tensor_numels))
-    flatten_copyout_tensor = torch.empty((copyout_size,), dtype=param0.dtype, device=param0.device)
-
-    dist.all_gather_into_tensor(flatten_copyout_tensor, flatten_copyin_tensor, group=group)
-    copyout_split_size: list[int] = sum([i.tolist() for i in global_input_tensor_numels], [])
-    splits_copyout_tensor = torch.split(flatten_copyout_tensor, copyout_split_size)
-
-    _global_input_tensor_shapes: list[None] | list[list[tuple]] = [None for _ in range(dist.get_world_size(group))]
-    dist.all_gather_object(_global_input_tensor_shapes, input_tensor_shapes, group=group)
-    _global_input_tensor_shapes = cast(list[list[tuple]], _global_input_tensor_shapes)
-    global_input_tensor_shapes: list[tuple] = sum(_global_input_tensor_shapes, [])
-
-    gathered_params: list[list[torch.Tensor]] = []
-    for i in range(len(params)):
-        single_gathered_params: list[torch.Tensor] = []
-        for rank in range(dist.get_world_size(group)):
-            offset = len(params) * rank
-            origin_shape: tuple = global_input_tensor_shapes[offset + i]
-            single_gathered_params.append(splits_copyout_tensor[offset + i].view(origin_shape))
-        gathered_params.append(single_gathered_params)
+    if len(params) == 1:
+        param0_shape_except_dim0 = list(param0.shape)[1:]
+        param0_numel_except_dim0 = param0[0].numel()
+        # Compute the dim-0 size of the gathered tensor; this also covers uneven splits across ranks
+        split_dim0_sizes = [t.tolist()[0] // param0_numel_except_dim0 for t in global_input_tensor_numels]
+        gathered_tensor_dim0_size = sum(split_dim0_sizes)
+
+        # all_gather_into_tensor concatenates each rank's data along dimension 0
+        gathered_tensor = torch.empty(
+            (gathered_tensor_dim0_size, *param0_shape_except_dim0), dtype=param0.dtype, device=param0.device
+        )
+        dist.all_gather_into_tensor(gathered_tensor, param0, group=group)
+        return [gathered_tensor.split(split_dim0_sizes, dim=0)]
+    else:
+        flatten_copyin_tensor = torch.empty((sum(input_tensor_numels),), dtype=param0.dtype, device=param0.device)
+        splits_copyin_tensor = torch.split(flatten_copyin_tensor, input_tensor_numels)
+        torch._foreach_copy_(splits_copyin_tensor, [p.flatten() for p in params])
+
+        copyout_size = int(sum(sum(i) for i in global_input_tensor_numels))
+        flatten_copyout_tensor = torch.empty((copyout_size,), dtype=param0.dtype, device=param0.device)
+
+        dist.all_gather_into_tensor(flatten_copyout_tensor, flatten_copyin_tensor, group=group)
+        copyout_split_size: list[int] = sum([i.tolist() for i in global_input_tensor_numels], [])
+        splits_copyout_tensor = torch.split(flatten_copyout_tensor, copyout_split_size)
+
+        _global_input_tensor_shapes: list[None] | list[list[tuple]] = [None for _ in range(dist.get_world_size(group))]
+        dist.all_gather_object(_global_input_tensor_shapes, input_tensor_shapes, group=group)
+        _global_input_tensor_shapes = cast(list[list[tuple]], _global_input_tensor_shapes)
+        global_input_tensor_shapes: list[tuple] = sum(_global_input_tensor_shapes, [])
+
+        gathered_params: list[list[torch.Tensor]] = []
+        for i in range(len(params)):
+            single_gathered_params: list[torch.Tensor] = []
+            for rank in range(dist.get_world_size(group)):
+                offset = len(params) * rank
+                origin_shape: tuple = global_input_tensor_shapes[offset + i]
+                single_gathered_params.append(splits_copyout_tensor[offset + i].view(origin_shape))
+            gathered_params.append(single_gathered_params)
 
     return gathered_params
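
For context, a minimal self-contained sketch of what the new single-param fast path does: gather one tensor's dim-0 shards with `all_gather_into_tensor`, then split the result back into per-rank views. It assumes a 2-process gloo group with equal dim-0 shards per rank (the PR additionally derives `split_dim0_sizes` from the gathered numels to support uneven shards); the script and its names are illustrative, not part of the PR.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank holds a dim-0 shard of a (4, 3) parameter (equal shards here for simplicity).
    shard = torch.full((2, 3), float(rank))
    split_dim0_sizes = [2] * world_size  # per-rank dim-0 sizes, known up front

    # all_gather_into_tensor concatenates the shards along dimension 0 ...
    gathered = torch.empty((sum(split_dim0_sizes), 3))
    dist.all_gather_into_tensor(gathered, shard)

    # ... and split() recovers one view per rank, matching the fast path's return value.
    per_rank_views = gathered.split(split_dim0_sizes, dim=0)
    assert per_rank_views[rank].eq(rank).all()
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```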
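The multi-param branch can likewise be followed in a single-process sketch that simulates the copy-out buffer of a 2-rank group with 2 params each; all names here are illustrative. It shows the `_foreach_copy_` pack into one flat buffer, the flat split, and the `offset = len(params) * rank` indexing used to rebuild per-param, per-rank views.

```python
import torch

# Two params per "rank", with different shapes per rank (the uneven case).
params_rank0 = [torch.ones(2, 3), torch.full((4,), 2.0)]
params_rank1 = [torch.full((1, 3), 3.0), torch.full((2,), 4.0)]
all_params = params_rank0 + params_rank1  # rank-major order, as all-gather produces

# Pack: split one flat buffer into per-param views and fill them with _foreach_copy_.
numels = [p.numel() for p in all_params]
flat = torch.empty(sum(numels))
torch._foreach_copy_(torch.split(flat, numels), [p.flatten() for p in all_params])

# Unpack: split the flat buffer and view each piece with its original shape,
# indexing with offset = len(params) * rank exactly as in the diff.
shapes = [tuple(p.shape) for p in all_params]
pieces = torch.split(flat, numels)
num_params, world_size = 2, 2
gathered = [
    [pieces[num_params * rank + i].view(shapes[num_params * rank + i]) for rank in range(world_size)]
    for i in range(num_params)
]
assert torch.equal(gathered[0][1], params_rank1[0])
```

Note the pack and unpack steps are pure metadata plus one fused copy, which is what lets the real code issue a single `all_gather_into_tensor` for the whole parameter group.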