@@ -150,7 +150,7 @@ static xccl_status_t xccl_mhba_fanout_start(xccl_coll_task_t *task)
     /* start task if completion event received */
     task->state = XCCL_TASK_STATE_INPROGRESS;
 
-    /* Start fanin */
+    /* Start fanout */
     if (XCCL_OK == xccl_mhba_node_fanout(team, request)) {
         task->state = XCCL_TASK_STATE_COMPLETED;
         xccl_mhba_debug("Algorithm completion");
@@ -176,12 +176,43 @@ static xccl_status_t xccl_mhba_fanout_progress(xccl_coll_task_t *task)
     return XCCL_OK;
 }
 
+static inline xccl_status_t send_block_data(struct ibv_qp *qp,
+                                            uint64_t src_addr,
+                                            uint32_t msg_size, uint32_t lkey,
+                                            uint64_t remote_addr, uint32_t rkey,
+                                            int send_flags, int with_imm)
+{
+    struct ibv_send_wr *bad_wr;
+    struct ibv_sge      list = {
+        .addr   = src_addr,
+        .length = msg_size,
+        .lkey   = lkey,
+    };
+
+    struct ibv_send_wr wr = {
+        .wr_id      = 1,
+        .sg_list    = &list,
+        .num_sge    = 1,
+        .opcode     = with_imm ? IBV_WR_RDMA_WRITE_WITH_IMM : IBV_WR_RDMA_WRITE,
+        .send_flags = send_flags,
+        .wr.rdma.remote_addr = remote_addr,
+        .wr.rdma.rkey        = rkey,
+    };
+
+    if (ibv_post_send(qp, &wr, &bad_wr)) {
+        xccl_mhba_error("failed to post send");
+        return XCCL_ERR_NO_MESSAGE;
+    }
+    return XCCL_OK;
+}
+
 static xccl_status_t xccl_mhba_asr_barrier_start(xccl_coll_task_t *task)
 {
     xccl_mhba_task_t     *self    = ucs_derived_of(task, xccl_mhba_task_t);
     xccl_mhba_coll_req_t *request = self->req;
     xccl_mhba_team_t     *team    = request->team;
     xccl_mhba_debug("asr barrier start");
+    int i;
 
     // despite while statement, non blocking because have independent cq, will be finished in a finite time
     xccl_mhba_populate_send_recv_mkeys(team, request);
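
The relocated send_block_data() helper wraps a single one-sided ibv_post_send() of an RDMA WRITE (optionally with immediate data). For reference, a minimal usage sketch that is not part of this patch: the helper name write_block_and_wait and its qp/cq/mr arguments are illustrative assumptions, standing in for an already-connected QP, its CQ, a registered memory region, and a peer address/rkey exchanged elsewhere. It posts one signaled write through send_block_data() and busy-polls the CQ for the completion:

    /* Minimal usage sketch (illustrative, not part of this patch): post one
     * signaled RDMA write through send_block_data() and busy-poll its CQ until
     * the completion arrives.  qp/cq/mr and the peer's remote_addr/rkey are
     * assumed to have been connected, registered and exchanged elsewhere. */
    static xccl_status_t write_block_and_wait(struct ibv_qp *qp, struct ibv_cq *cq,
                                              struct ibv_mr *mr, uint32_t size,
                                              uint64_t remote_addr, uint32_t rkey)
    {
        struct ibv_wc wc;
        int           ret;
        xccl_status_t status;

        status = send_block_data(qp, (uintptr_t)mr->addr, size, mr->lkey,
                                 remote_addr, rkey, IBV_SEND_SIGNALED, 0);
        if (status != XCCL_OK) {
            return status;
        }
        do {                              /* spin until the signaled WQE completes */
            ret = ibv_poll_cq(cq, 1, &wc);
        } while (ret == 0);
        if (ret < 0 || wc.status != IBV_WC_SUCCESS) {
            return XCCL_ERR_NO_MESSAGE;
        }
        return XCCL_OK;
    }
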
@@ -191,61 +222,28 @@ static xccl_status_t xccl_mhba_asr_barrier_start(xccl_coll_task_t *task)
                                       MHBA_CTRL_SIZE);
 
     task->state = XCCL_TASK_STATE_INPROGRESS;
-    xccl_coll_op_args_t coll = {
-        .coll_type = XCCL_BARRIER,
-        .alg.set_by_user = 0,
-    };
-    //todo create special barrier to support multiple parallel ops - with seq_id
-    team->net.ucx_team->ctx->lib->collective_init(&coll, &request->barrier_req,
-                                                  team->net.ucx_team);
-    team->net.ucx_team->ctx->lib->collective_post(request->barrier_req);
+
+    team->inter_node_barrier[team->net.sbgp->group_rank] = request->seq_num;
+    for (i = 0; i < team->net.net_size; i++) {
+        xccl_status_t status = send_block_data(
+            team->net.qps[i],
+            (uintptr_t)team->inter_node_barrier_mr->addr +
+                team->net.sbgp->group_rank * sizeof(int),
+            sizeof(int), team->inter_node_barrier_mr->lkey,
+            team->net.remote_ctrl[i].barrier_addr +
+                sizeof(int) * team->net.sbgp->group_rank,
+            team->net.remote_ctrl[i].barrier_rkey, 0, 0);
+        if (status != XCCL_OK) {
+            xccl_mhba_error("Failed sending block");
+            return status;
+        }
+    }
     xccl_task_enqueue(task->schedule->tl_ctx->pq, task);
     return XCCL_OK;
 }
 
 xccl_status_t xccl_mhba_asr_barrier_progress(xccl_coll_task_t *task)
 {
-    xccl_mhba_task_t     *self    = ucs_derived_of(task, xccl_mhba_task_t);
-    xccl_mhba_coll_req_t *request = self->req;
-    xccl_mhba_team_t     *team    = request->team;
-
-    if (XCCL_OK ==
-        team->net.ucx_team->ctx->lib->collective_test(request->barrier_req)) {
-        team->net.ucx_team->ctx->lib->collective_finalize(request->barrier_req);
-        task->state = XCCL_TASK_STATE_COMPLETED;
-    }
+    task->state = XCCL_TASK_STATE_COMPLETED;
     return XCCL_OK;
 }
 
-static inline xccl_status_t send_block_data(struct ibv_qp *qp,
-                                            uint64_t src_addr,
-                                            uint32_t msg_size, uint32_t lkey,
-                                            uint64_t remote_addr, uint32_t rkey,
-                                            int send_flags, int with_imm)
-{
-    struct ibv_send_wr *bad_wr;
-    struct ibv_sge      list = {
-        .addr   = src_addr,
-        .length = msg_size,
-        .lkey   = lkey,
-    };
-
-    struct ibv_send_wr wr = {
-        .wr_id      = 1,
-        .sg_list    = &list,
-        .num_sge    = 1,
-        .opcode     = with_imm ? IBV_WR_RDMA_WRITE_WITH_IMM : IBV_WR_RDMA_WRITE,
-        .send_flags = send_flags,
-        .wr.rdma.remote_addr = remote_addr,
-        .wr.rdma.rkey        = rkey,
-    };
 
-    if (ibv_post_send(qp, &wr, &bad_wr)) {
-        xccl_mhba_error("failed to post send");
-        return XCCL_ERR_NO_MESSAGE;
-    }
-    return XCCL_OK;
-}
 
 static inline xccl_status_t send_atomic(struct ibv_qp *qp, uint64_t remote_addr,
                                         uint32_t rkey, xccl_mhba_team_t *team,
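
The UCX barrier collective is replaced here by a lighter flag exchange: each ASR RDMA-writes its current seq_num into its own slot of every peer's inter_node_barrier array, and the send-blocks tasks below treat "slot[i] == seq_num" as "rank i has arrived". Below is a conceptual sketch of the receiver-side test only, not code from the patch; the helper name is hypothetical, it assumes the slots are plain ints updated by remote RDMA WRITEs (hence the volatile reads so the compiler does not cache the loads), and it assumes ordinary RDMA WRITE visibility semantics:

    /* Conceptual sketch (not in the patch): receiver-side view of the new
     * flag barrier.  Peers publish their seq_num into inter_node_barrier[i]
     * via RDMA WRITE; a slot matching the current seq_num means that peer
     * has reached this operation. */
    static int asr_barrier_all_arrived(xccl_mhba_team_t *team, int seq_num)
    {
        volatile int *slots = (volatile int *)team->inter_node_barrier;
        int           i;

        for (i = 0; i < team->net.net_size; i++) {
            if (slots[i] != seq_num) {
                return 0; /* peer i has not published this seq_num yet */
            }
        }
        return 1;
    }

The send-blocks loops below do not actually wait for all peers at once; they start sending to peer i as soon as its slot matches, which overlaps the barrier wait with data transfer.
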
@@ -332,70 +330,78 @@ xccl_mhba_send_blocks_start_with_transpose(xccl_coll_task_t *task)
     int           i, j, k, dest_rank, rank, n_compl, ret;
     uint64_t      src_addr, remote_addr;
     struct ibv_wc transpose_completion[1];
+    int           counter = 0;
     xccl_status_t status;
 
     xccl_mhba_debug("send blocks start");
     task->state = XCCL_TASK_STATE_INPROGRESS;
     rank = team->net.rank_map[team->net.sbgp->group_rank];
 
-    for (i = 0; i < net_size; i++) {
-        dest_rank = team->net.rank_map[i];
-        //send all blocks from curr node to some ARR
-        for (j = 0; j < xccl_round_up(node_size, block_size); j++) {
-            for (k = 0; k < xccl_round_up(node_size, block_size); k++) {
-                src_addr = (uintptr_t)(op_msgsize * index + node_msgsize * dest_rank +
-                                       col_msgsize * j + block_msgsize * k);
-                remote_addr = (uintptr_t)(op_msgsize * index + node_msgsize * rank +
-                                          block_msgsize * j + col_msgsize * k);
-                prepost_dummy_recv(team->node.umr_qp, 1);
-                // SW Transpose
-                status = send_block_data(
-                    team->node.umr_qp, src_addr, block_msgsize,
-                    team->node.team_send_mkey->lkey,
-                    (uintptr_t)request->transpose_buf_mr->addr,
-                    request->transpose_buf_mr->rkey, IBV_SEND_SIGNALED, 1);
-                if (status != XCCL_OK) {
-                    xccl_mhba_error(
-                        "Failed sending block to transpose buffer[%d,%d,%d]", i, j, k);
-                    return status;
-                }
-                n_compl = 0;
-                while (n_compl != 2) {
-                    ret = ibv_poll_cq(team->node.umr_cq, 1, transpose_completion);
-                    if (ret > 0) {
-                        if (transpose_completion[0].status != IBV_WC_SUCCESS) {
+    while (counter < net_size) {
+        for (i = 0; i < net_size; i++) {
+            // printf("i %d is %d\n",i,team->inter_node_barrier[i]);
+            if (team->inter_node_barrier[i] == request->seq_num && !team->inter_node_barrier_flag[i]) {
+                // printf("req %d node %d\n",request->seq_num,i);
+                team->inter_node_barrier_flag[i] = 1;
+                dest_rank = team->net.rank_map[i];
+                //send all blocks from curr node to some ARR
+                for (j = 0; j < xccl_round_up(node_size, block_size); j++) {
+                    for (k = 0; k < xccl_round_up(node_size, block_size); k++) {
+                        src_addr = (uintptr_t)(op_msgsize * index + node_msgsize * dest_rank +
+                                               col_msgsize * j + block_msgsize * k);
+                        remote_addr = (uintptr_t)(op_msgsize * index + node_msgsize * rank +
+                                                  block_msgsize * j + col_msgsize * k);
+                        prepost_dummy_recv(team->node.umr_qp, 1);
+                        // SW Transpose
+                        status = send_block_data(
+                            team->node.umr_qp, src_addr, block_msgsize,
+                            team->node.team_send_mkey->lkey,
+                            (uintptr_t)request->transpose_buf_mr->addr,
+                            request->transpose_buf_mr->rkey, IBV_SEND_SIGNALED, 1);
+                        if (status != XCCL_OK) {
                             xccl_mhba_error(
-                                "local copy for transpose CQ returned "
-                                "completion with status %s (%d)",
-                                ibv_wc_status_str(transpose_completion[0].status),
-                                transpose_completion[0].status);
-                            return XCCL_ERR_NO_MESSAGE;
+                                "Failed sending block to transpose buffer[%d,%d,%d]", i, j, k);
+                            return status;
+                        }
+                        n_compl = 0;
+                        while (n_compl != 2) {
+                            ret = ibv_poll_cq(team->node.umr_cq, 1, transpose_completion);
+                            if (ret > 0) {
+                                if (transpose_completion[0].status != IBV_WC_SUCCESS) {
+                                    xccl_mhba_error(
+                                        "local copy for transpose CQ returned "
+                                        "completion with status %s (%d)",
+                                        ibv_wc_status_str(transpose_completion[0].status),
+                                        transpose_completion[0].status);
+                                    return XCCL_ERR_NO_MESSAGE;
+                                }
+                                n_compl++;
+                            }
                         }
-                        n_compl++;
+                        transpose_square_mat(request->transpose_buf_mr->addr,
+                                             block_size, request->args.buffer_info.len,
+                                             request->tmp_transpose_buf);
+                        status = send_block_data(
+                            team->net.qps[i],
+                            (uintptr_t)request->transpose_buf_mr->addr, block_msgsize,
+                            request->transpose_buf_mr->lkey, remote_addr,
+                            team->net.rkeys[i], IBV_SEND_SIGNALED, 0);
+                        if (status != XCCL_OK) {
+                            xccl_mhba_error("Failed sending block [%d,%d,%d]", i, j, k);
+                            return status;
+                        }
+                        while (!ibv_poll_cq(team->net.cq, 1, transpose_completion)) {}
                     }
                 }
-                transpose_square_mat(request->transpose_buf_mr->addr,
-                                     block_size, request->args.buffer_info.len,
-                                     request->tmp_transpose_buf);
-                status = send_block_data(
-                    team->net.qps[i],
-                    (uintptr_t)request->transpose_buf_mr->addr, block_msgsize,
-                    request->transpose_buf_mr->lkey, remote_addr,
-                    team->net.rkeys[i], IBV_SEND_SIGNALED, 0);
-                if (status != XCCL_OK) {
-                    xccl_mhba_error("Failed sending block [%d,%d,%d]", i, j, k);
-                    return status;
-                }
-                while (!ibv_poll_cq(team->net.cq, 1, transpose_completion)) {}
+                counter += 1;
             }
         }
     }
-
     for (i = 0; i < net_size; i++) {
         status = send_atomic(team->net.qps[i],
-                             (uintptr_t)team->net.remote_ctrl[i].addr +
+                             (uintptr_t)team->net.remote_ctrl[i].ctrl_addr +
                                  (index * MHBA_CTRL_SIZE),
-                             team->net.remote_ctrl[i].rkey, team, request);
+                             team->net.remote_ctrl[i].ctrl_rkey, team, request);
         if (status != XCCL_OK) {
             xccl_mhba_error("Failed sending atomic to node [%d]", i);
             return status;
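
transpose_square_mat() is called above but its body is outside this diff. The sketch below is only a plausible reading of its contract, assuming it transposes a block_size x block_size matrix of len-byte elements in place, staging through the caller-supplied scratch buffer (tmp_transpose_buf in the call above); the _sketch suffix marks it as hypothetical:

    #include <stddef.h>
    #include <string.h>

    /* Hypothetical reading of transpose_square_mat() (its body is outside this
     * diff): transpose a side x side matrix of elem_size-byte elements in place,
     * staging through a scratch buffer of at least side * side * elem_size bytes. */
    static void transpose_square_mat_sketch(void *mat, int side, size_t elem_size,
                                            void *tmp)
    {
        char *m = (char *)mat;
        char *t = (char *)tmp;
        int   r, c;

        for (r = 0; r < side; r++) {
            for (c = 0; c < side; c++) {
                /* element (r, c) of the source lands at (c, r) of the result */
                memcpy(t + ((size_t)c * side + r) * elem_size,
                       m + ((size_t)r * side + c) * elem_size, elem_size);
            }
        }
        memcpy(m, t, (size_t)side * side * elem_size);
    }
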
@@ -421,40 +427,47 @@ static xccl_status_t xccl_mhba_send_blocks_start(xccl_coll_task_t *task)
     int           col_msgsize   = len * block_size * node_size;
     int           block_msgsize = SQUARED(block_size) * len;
     int           i, j, k, dest_rank, rank;
+    int           counter = 0;
     uint64_t      src_addr, remote_addr;
     xccl_status_t status;
 
     xccl_mhba_debug("send blocks start");
     task->state = XCCL_TASK_STATE_INPROGRESS;
     rank = team->net.rank_map[team->net.sbgp->group_rank];
 
-    for (i = 0; i < net_size; i++) {
-        dest_rank = team->net.rank_map[i];
-        //send all blocks from curr node to some ARR
-        for (j = 0; j < xccl_round_up(node_size, block_size); j++) {
-            for (k = 0; k < xccl_round_up(node_size, block_size); k++) {
-                src_addr = (uintptr_t)(op_msgsize * index + node_msgsize * dest_rank +
-                                       col_msgsize * j + block_msgsize * k);
-                remote_addr = (uintptr_t)(op_msgsize * index + node_msgsize * rank +
-                                          block_msgsize * j + col_msgsize * k);
-
-                status = send_block_data(team->net.qps[i], src_addr, block_msgsize,
-                                         team->node.team_send_mkey->lkey,
-                                         remote_addr, team->net.rkeys[i], 0, 0);
+    while (counter < net_size) {
+        for (i = 0; i < net_size; i++) {
+            if (team->inter_node_barrier[i] == request->seq_num && !team->inter_node_barrier_flag[i]) {
+                team->inter_node_barrier_flag[i] = 1;
+                dest_rank = team->net.rank_map[i];
+                //send all blocks from curr node to some ARR
+                for (j = 0; j < xccl_round_up(node_size, block_size); j++) {
+                    for (k = 0; k < xccl_round_up(node_size, block_size); k++) {
+                        src_addr = (uintptr_t)(op_msgsize * index + node_msgsize * dest_rank +
+                                               col_msgsize * j + block_msgsize * k);
+                        remote_addr = (uintptr_t)(op_msgsize * index + node_msgsize * rank +
+                                                  block_msgsize * j + col_msgsize * k);
+
+                        status = send_block_data(team->net.qps[i], src_addr, block_msgsize,
+                                                 team->node.team_send_mkey->lkey,
+                                                 remote_addr, team->net.rkeys[i], 0, 0);
+                        if (status != XCCL_OK) {
+                            xccl_mhba_error("Failed sending block [%d,%d,%d]", i, j, k);
+                            return status;
+                        }
+                    }
+                }
+                status = send_atomic(team->net.qps[i],
+                                     (uintptr_t)team->net.remote_ctrl[i].ctrl_addr +
+                                         (index * MHBA_CTRL_SIZE),
+                                     team->net.remote_ctrl[i].ctrl_rkey, team, request);
                 if (status != XCCL_OK) {
-                    xccl_mhba_error("Failed sending block [%d,%d,%d]", i, j, k);
+                    xccl_mhba_error("Failed sending atomic to node [%d]", i);
                     return status;
                 }
+                counter += 1;
             }
         }
-        status = send_atomic(team->net.qps[i],
-                             (uintptr_t)team->net.remote_ctrl[i].addr +
-                                 (index * MHBA_CTRL_SIZE),
-                             team->net.remote_ctrl[i].rkey, team, request);
-        if (status != XCCL_OK) {
-            xccl_mhba_error("Failed sending atomic to node [%d]", i);
-            return status;
-        }
     }
     xccl_task_enqueue(task->schedule->tl_ctx->pq, task);
     return XCCL_OK;
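
send_atomic(), whose body also lies outside this diff, is still used to bump the remote control slot once all blocks for a destination have been posted; the patch only changes which remote_ctrl fields it is given (addr/rkey become ctrl_addr/ctrl_rkey, now that the struct also carries barrier_addr/barrier_rkey for the flag barrier). For orientation only, a generic libibverbs sketch of the kind of operation such a helper typically posts, a signaled 64-bit RDMA fetch-and-add; all names here are illustrative and this is not the file's actual implementation:

    #include <stdint.h>
    #include <infiniband/verbs.h>

    /* Generic libibverbs sketch (not this file's send_atomic()): post a signaled
     * 64-bit RDMA fetch-and-add of +1 to remote_addr/rkey.  The previous remote
     * value is returned into fetch_buf; remote_addr must be 8-byte aligned. */
    static int post_fetch_add_one(struct ibv_qp *qp, uint64_t *fetch_buf,
                                  uint32_t fetch_lkey, uint64_t remote_addr,
                                  uint32_t rkey)
    {
        struct ibv_send_wr *bad_wr;
        struct ibv_sge      sge = {
            .addr   = (uintptr_t)fetch_buf,
            .length = sizeof(uint64_t),      /* verbs atomics operate on 8 bytes */
            .lkey   = fetch_lkey,
        };
        struct ibv_send_wr wr = {
            .wr_id      = 2,
            .sg_list    = &sge,
            .num_sge    = 1,
            .opcode     = IBV_WR_ATOMIC_FETCH_AND_ADD,
            .send_flags = IBV_SEND_SIGNALED,
            .wr.atomic.remote_addr = remote_addr,
            .wr.atomic.rkey        = rkey,
            .wr.atomic.compare_add = 1,      /* value added at the target */
        };

        return ibv_post_send(qp, &wr, &bad_wr);
    }
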
@@ -573,6 +586,7 @@ xccl_status_t xccl_mhba_alltoall_init(xccl_coll_op_args_t *coll_args,
             xccl_mhba_fanout_start;
         request->tasks[1].super.progress = xccl_mhba_fanout_progress;
     } else {
+        memset(team->inter_node_barrier_flag, 0, sizeof(int) * team->net.net_size);
         request->tasks[1].super.handlers[XCCL_EVENT_COMPLETED] =
             xccl_mhba_asr_barrier_start;
         request->tasks[1].super.progress = xccl_mhba_asr_barrier_progress;