Skip to content

Commit d4b1160

Browse files
author
Lior Paz
committed
Replace inter node barrier
1 parent 69f4fbc commit d4b1160

File tree

4 files changed

+165
-123
lines changed

4 files changed

+165
-123
lines changed

src/team_lib/mhba/xccl_mhba_collective.c

Lines changed: 130 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ static xccl_status_t xccl_mhba_fanout_start(xccl_coll_task_t *task)
150150
/* start task if completion event received */
151151
task->state = XCCL_TASK_STATE_INPROGRESS;
152152

153-
/* Start fanin */
153+
/* Start fanout */
154154
if (XCCL_OK == xccl_mhba_node_fanout(team, request)) {
155155
task->state = XCCL_TASK_STATE_COMPLETED;
156156
xccl_mhba_debug("Algorithm completion");
@@ -176,12 +176,43 @@ static xccl_status_t xccl_mhba_fanout_progress(xccl_coll_task_t *task)
176176
return XCCL_OK;
177177
}
178178

179+
static inline xccl_status_t send_block_data(struct ibv_qp *qp,
180+
uint64_t src_addr,
181+
uint32_t msg_size, uint32_t lkey,
182+
uint64_t remote_addr, uint32_t rkey,
183+
int send_flags, int with_imm)
184+
{
185+
struct ibv_send_wr *bad_wr;
186+
struct ibv_sge list = {
187+
.addr = src_addr,
188+
.length = msg_size,
189+
.lkey = lkey,
190+
};
191+
192+
struct ibv_send_wr wr = {
193+
.wr_id = 1,
194+
.sg_list = &list,
195+
.num_sge = 1,
196+
.opcode = with_imm ? IBV_WR_RDMA_WRITE_WITH_IMM : IBV_WR_RDMA_WRITE,
197+
.send_flags = send_flags,
198+
.wr.rdma.remote_addr = remote_addr,
199+
.wr.rdma.rkey = rkey,
200+
};
201+
202+
if (ibv_post_send(qp, &wr, &bad_wr)) {
203+
xccl_mhba_error("failed to post send");
204+
return XCCL_ERR_NO_MESSAGE;
205+
}
206+
return XCCL_OK;
207+
}
208+
179209
static xccl_status_t xccl_mhba_asr_barrier_start(xccl_coll_task_t *task)
180210
{
181211
xccl_mhba_task_t *self = ucs_derived_of(task, xccl_mhba_task_t);
182212
xccl_mhba_coll_req_t *request = self->req;
183213
xccl_mhba_team_t *team = request->team;
184214
xccl_mhba_debug("asr barrier start");
215+
int i;
185216

186217
// Despite the while loop, this call is non-blocking: it uses an independent CQ, so it completes in finite time
187218
xccl_mhba_populate_send_recv_mkeys(team, request);
@@ -191,61 +222,28 @@ static xccl_status_t xccl_mhba_asr_barrier_start(xccl_coll_task_t *task)
191222
MHBA_CTRL_SIZE);
192223

193224
task->state = XCCL_TASK_STATE_INPROGRESS;
194-
xccl_coll_op_args_t coll = {
195-
.coll_type = XCCL_BARRIER,
196-
.alg.set_by_user = 0,
197-
};
198-
//todo create special barrier to support multiple parallel ops - with seq_id
199-
team->net.ucx_team->ctx->lib->collective_init(&coll, &request->barrier_req,
200-
team->net.ucx_team);
201-
team->net.ucx_team->ctx->lib->collective_post(request->barrier_req);
225+
226+
team->inter_node_barrier[team->net.sbgp->group_rank] = request->seq_num;
227+
for(i=0; i<team->net.net_size;i++){
228+
xccl_status_t status = send_block_data(team->net.qps[i], (uintptr_t)team->inter_node_barrier_mr->addr+team->net.sbgp->group_rank*sizeof(int) , sizeof(int),
229+
team->inter_node_barrier_mr->lkey,
230+
team->net.remote_ctrl[i].barrier_addr+sizeof(int)*team->net.sbgp->group_rank, team->net.remote_ctrl[i].barrier_rkey, 0, 0);
231+
if (status != XCCL_OK) {
232+
xccl_mhba_error("Failed sending block");
233+
return status;
234+
}
235+
}
202236
xccl_task_enqueue(task->schedule->tl_ctx->pq, task);
203237
return XCCL_OK;
204238
}
205239

206240
xccl_status_t xccl_mhba_asr_barrier_progress(xccl_coll_task_t *task)
207241
{
208-
xccl_mhba_task_t *self = ucs_derived_of(task, xccl_mhba_task_t);
209-
xccl_mhba_coll_req_t *request = self->req;
210-
xccl_mhba_team_t *team = request->team;
211-
212-
if (XCCL_OK ==
213-
team->net.ucx_team->ctx->lib->collective_test(request->barrier_req)) {
214-
team->net.ucx_team->ctx->lib->collective_finalize(request->barrier_req);
215-
task->state = XCCL_TASK_STATE_COMPLETED;
216-
}
242+
task->state = XCCL_TASK_STATE_COMPLETED;
217243
return XCCL_OK;
218244
}
219245

220-
static inline xccl_status_t send_block_data(struct ibv_qp *qp,
221-
uint64_t src_addr,
222-
uint32_t msg_size, uint32_t lkey,
223-
uint64_t remote_addr, uint32_t rkey,
224-
int send_flags, int with_imm)
225-
{
226-
struct ibv_send_wr *bad_wr;
227-
struct ibv_sge list = {
228-
.addr = src_addr,
229-
.length = msg_size,
230-
.lkey = lkey,
231-
};
232-
233-
struct ibv_send_wr wr = {
234-
.wr_id = 1,
235-
.sg_list = &list,
236-
.num_sge = 1,
237-
.opcode = with_imm ? IBV_WR_RDMA_WRITE_WITH_IMM : IBV_WR_RDMA_WRITE,
238-
.send_flags = send_flags,
239-
.wr.rdma.remote_addr = remote_addr,
240-
.wr.rdma.rkey = rkey,
241-
};
242246

243-
if (ibv_post_send(qp, &wr, &bad_wr)) {
244-
xccl_mhba_error("failed to post send");
245-
return XCCL_ERR_NO_MESSAGE;
246-
}
247-
return XCCL_OK;
248-
}
249247

250248
static inline xccl_status_t send_atomic(struct ibv_qp *qp, uint64_t remote_addr,
251249
uint32_t rkey, xccl_mhba_team_t *team,
@@ -332,70 +330,78 @@ xccl_mhba_send_blocks_start_with_transpose(xccl_coll_task_t *task)
332330
int i, j, k, dest_rank, rank, n_compl, ret;
333331
uint64_t src_addr, remote_addr;
334332
struct ibv_wc transpose_completion[1];
333+
int counter = 0;
335334
xccl_status_t status;
336335

337336
xccl_mhba_debug("send blocks start");
338337
task->state = XCCL_TASK_STATE_INPROGRESS;
339338
rank = team->net.rank_map[team->net.sbgp->group_rank];
340339

341-
for (i = 0; i < net_size; i++) {
342-
dest_rank = team->net.rank_map[i];
343-
//send all blocks from curr node to some ARR
344-
for (j = 0; j < xccl_round_up(node_size, block_size); j++) {
345-
for (k = 0; k < xccl_round_up(node_size, block_size); k++) {
346-
src_addr = (uintptr_t)(op_msgsize * index + node_msgsize * dest_rank +
347-
col_msgsize * j + block_msgsize * k);
348-
remote_addr = (uintptr_t)(op_msgsize * index + node_msgsize * rank +
349-
block_msgsize * j + col_msgsize * k);
350-
prepost_dummy_recv(team->node.umr_qp, 1);
351-
// SW Transpose
352-
status = send_block_data(
353-
team->node.umr_qp, src_addr, block_msgsize,
354-
team->node.team_send_mkey->lkey,
355-
(uintptr_t)request->transpose_buf_mr->addr,
356-
request->transpose_buf_mr->rkey, IBV_SEND_SIGNALED, 1);
357-
if (status != XCCL_OK) {
358-
xccl_mhba_error(
359-
"Failed sending block to transpose buffer[%d,%d,%d]", i, j, k);
360-
return status;
361-
}
362-
n_compl = 0;
363-
while (n_compl != 2) {
364-
ret = ibv_poll_cq(team->node.umr_cq, 1, transpose_completion);
365-
if (ret > 0) {
366-
if (transpose_completion[0].status != IBV_WC_SUCCESS) {
340+
while(counter < net_size) {
341+
for (i = 0; i < net_size; i++) {
342+
// printf("i %d is %d\n",i,team->inter_node_barrier[i]);
343+
if (team->inter_node_barrier[i] == request->seq_num && !team->inter_node_barrier_flag[i]) {
344+
// printf("req %d node %d\n",request->seq_num,i);
345+
team->inter_node_barrier_flag[i] = 1;
346+
dest_rank = team->net.rank_map[i];
347+
//send all blocks from curr node to some ARR
348+
for (j = 0; j < xccl_round_up(node_size, block_size); j++) {
349+
for (k = 0; k < xccl_round_up(node_size, block_size); k++) {
350+
src_addr = (uintptr_t)(op_msgsize * index + node_msgsize * dest_rank +
351+
col_msgsize * j + block_msgsize * k);
352+
remote_addr = (uintptr_t)(op_msgsize * index + node_msgsize * rank +
353+
block_msgsize * j + col_msgsize * k);
354+
prepost_dummy_recv(team->node.umr_qp, 1);
355+
// SW Transpose
356+
status = send_block_data(
357+
team->node.umr_qp, src_addr, block_msgsize,
358+
team->node.team_send_mkey->lkey,
359+
(uintptr_t) request->transpose_buf_mr->addr,
360+
request->transpose_buf_mr->rkey, IBV_SEND_SIGNALED, 1);
361+
if (status != XCCL_OK) {
367362
xccl_mhba_error(
368-
"local copy for transpose CQ returned "
369-
"completion with status %s (%d)",
370-
ibv_wc_status_str(transpose_completion[0].status),
371-
transpose_completion[0].status);
372-
return XCCL_ERR_NO_MESSAGE;
363+
"Failed sending block to transpose buffer[%d,%d,%d]", i, j, k);
364+
return status;
365+
}
366+
n_compl = 0;
367+
while (n_compl != 2) {
368+
ret = ibv_poll_cq(team->node.umr_cq, 1, transpose_completion);
369+
if (ret > 0) {
370+
if (transpose_completion[0].status != IBV_WC_SUCCESS) {
371+
xccl_mhba_error(
372+
"local copy for transpose CQ returned "
373+
"completion with status %s (%d)",
374+
ibv_wc_status_str(transpose_completion[0].status),
375+
transpose_completion[0].status);
376+
return XCCL_ERR_NO_MESSAGE;
377+
}
378+
n_compl++;
379+
}
373380
}
374-
n_compl++;
381+
transpose_square_mat(request->transpose_buf_mr->addr,
382+
block_size, request->args.buffer_info.len,
383+
request->tmp_transpose_buf);
384+
status = send_block_data(
385+
team->net.qps[i],
386+
(uintptr_t) request->transpose_buf_mr->addr, block_msgsize,
387+
request->transpose_buf_mr->lkey, remote_addr,
388+
team->net.rkeys[i], IBV_SEND_SIGNALED, 0);
389+
if (status != XCCL_OK) {
390+
xccl_mhba_error("Failed sending block [%d,%d,%d]", i, j, k);
391+
return status;
392+
}
393+
while (!ibv_poll_cq(team->net.cq, 1, transpose_completion)) {}
375394
}
376395
}
377-
transpose_square_mat(request->transpose_buf_mr->addr,
378-
block_size, request->args.buffer_info.len,
379-
request->tmp_transpose_buf);
380-
status = send_block_data(
381-
team->net.qps[i],
382-
(uintptr_t)request->transpose_buf_mr->addr, block_msgsize,
383-
request->transpose_buf_mr->lkey, remote_addr,
384-
team->net.rkeys[i], IBV_SEND_SIGNALED, 0);
385-
if (status != XCCL_OK) {
386-
xccl_mhba_error("Failed sending block [%d,%d,%d]", i, j, k);
387-
return status;
388-
}
389-
while (!ibv_poll_cq(team->net.cq, 1, transpose_completion)) {}
396+
counter += 1;
390397
}
391398
}
392399
}
393-
394400
for (i = 0; i < net_size; i++) {
395401
status = send_atomic(team->net.qps[i],
396-
(uintptr_t)team->net.remote_ctrl[i].addr +
402+
(uintptr_t)team->net.remote_ctrl[i].ctrl_addr +
397403
(index * MHBA_CTRL_SIZE),
398-
team->net.remote_ctrl[i].rkey, team, request);
404+
team->net.remote_ctrl[i].ctrl_rkey, team, request);
399405
if (status != XCCL_OK) {
400406
xccl_mhba_error("Failed sending atomic to node [%d]", i);
401407
return status;
@@ -421,40 +427,47 @@ static xccl_status_t xccl_mhba_send_blocks_start(xccl_coll_task_t *task)
421427
int col_msgsize = len * block_size * node_size;
422428
int block_msgsize = SQUARED(block_size) * len;
423429
int i, j, k, dest_rank, rank;
430+
int counter = 0;
424431
uint64_t src_addr, remote_addr;
425432
xccl_status_t status;
426433

427434
xccl_mhba_debug("send blocks start");
428435
task->state = XCCL_TASK_STATE_INPROGRESS;
429436
rank = team->net.rank_map[team->net.sbgp->group_rank];
430437

431-
for (i = 0; i < net_size; i++) {
432-
dest_rank = team->net.rank_map[i];
433-
//send all blocks from curr node to some ARR
434-
for (j = 0; j < xccl_round_up(node_size, block_size); j++) {
435-
for (k = 0; k < xccl_round_up(node_size, block_size); k++) {
436-
src_addr = (uintptr_t)(op_msgsize * index + node_msgsize * dest_rank +
437-
col_msgsize * j + block_msgsize * k);
438-
remote_addr = (uintptr_t)(op_msgsize * index + node_msgsize * rank +
439-
block_msgsize * j + col_msgsize * k);
440-
441-
status = send_block_data(team->net.qps[i], src_addr, block_msgsize,
442-
team->node.team_send_mkey->lkey,
443-
remote_addr, team->net.rkeys[i], 0, 0);
438+
while(counter < net_size) {
439+
for (i = 0; i < net_size; i++) {
440+
if (team->inter_node_barrier[i] == request->seq_num && !team->inter_node_barrier_flag[i]) {
441+
team->inter_node_barrier_flag[i] = 1;
442+
dest_rank = team->net.rank_map[i];
443+
//send all blocks from curr node to some ARR
444+
for (j = 0; j < xccl_round_up(node_size, block_size); j++) {
445+
for (k = 0; k < xccl_round_up(node_size, block_size); k++) {
446+
src_addr = (uintptr_t)(op_msgsize * index + node_msgsize * dest_rank +
447+
col_msgsize * j + block_msgsize * k);
448+
remote_addr = (uintptr_t)(op_msgsize * index + node_msgsize * rank +
449+
block_msgsize * j + col_msgsize * k);
450+
451+
status = send_block_data(team->net.qps[i], src_addr, block_msgsize,
452+
team->node.team_send_mkey->lkey,
453+
remote_addr, team->net.rkeys[i], 0, 0);
454+
if (status != XCCL_OK) {
455+
xccl_mhba_error("Failed sending block [%d,%d,%d]", i, j, k);
456+
return status;
457+
}
458+
}
459+
}
460+
status = send_atomic(team->net.qps[i],
461+
(uintptr_t) team->net.remote_ctrl[i].ctrl_addr +
462+
(index * MHBA_CTRL_SIZE),
463+
team->net.remote_ctrl[i].ctrl_rkey, team, request);
444464
if (status != XCCL_OK) {
445-
xccl_mhba_error("Failed sending block [%d,%d,%d]", i, j, k);
465+
xccl_mhba_error("Failed sending atomic to node [%d]", i);
446466
return status;
447467
}
468+
counter += 1;
448469
}
449470
}
450-
status = send_atomic(team->net.qps[i],
451-
(uintptr_t)team->net.remote_ctrl[i].addr +
452-
(index * MHBA_CTRL_SIZE),
453-
team->net.remote_ctrl[i].rkey, team, request);
454-
if (status != XCCL_OK) {
455-
xccl_mhba_error("Failed sending atomic to node [%d]", i);
456-
return status;
457-
}
458471
}
459472
xccl_task_enqueue(task->schedule->tl_ctx->pq, task);
460473
return XCCL_OK;
@@ -573,6 +586,7 @@ xccl_status_t xccl_mhba_alltoall_init(xccl_coll_op_args_t *coll_args,
573586
xccl_mhba_fanout_start;
574587
request->tasks[1].super.progress = xccl_mhba_fanout_progress;
575588
} else {
589+
memset(team->inter_node_barrier_flag,0,sizeof(int)*team->net.net_size);
576590
request->tasks[1].super.handlers[XCCL_EVENT_COMPLETED] =
577591
xccl_mhba_asr_barrier_start;
578592
request->tasks[1].super.progress = xccl_mhba_asr_barrier_progress;

src/team_lib/mhba/xccl_mhba_lib.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ static ucs_config_field_t xccl_tl_mhba_context_config_table[] = {
2424
ucs_offsetof(xccl_tl_mhba_context_config_t, devices),
2525
UCS_CONFIG_TYPE_STRING_ARRAY},
2626

27-
{"TRANSPOSE", "1", "Boolean - with transpose or not",
27+
{"TRANSPOSE", "0", "Boolean - with transpose or not",
2828
ucs_offsetof(xccl_tl_mhba_context_config_t, transpose),
2929
UCS_CONFIG_TYPE_UINT},
3030

31-
{"TRANSPOSE_HW_LIMITATIONS", "0",
31+
{"TRANSPOSE_HW_LIMITATIONS", "1",
3232
"Boolean - with transpose hw limitations or not",
3333
ucs_offsetof(xccl_tl_mhba_context_config_t, transpose_hw_limitations),
3434
UCS_CONFIG_TYPE_UINT}, //todo change to 1 in production

src/team_lib/mhba/xccl_mhba_lib.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,10 @@ typedef struct xccl_mhba_net {
129129
struct ibv_cq *cq;
130130
struct ibv_mr *ctrl_mr;
131131
struct {
132-
void *addr;
133-
uint32_t rkey;
132+
void *ctrl_addr;
133+
uint32_t ctrl_rkey;
134+
void *barrier_addr;
135+
uint32_t barrier_rkey;
134136
} * remote_ctrl;
135137
uint32_t *rkeys;
136138
xccl_tl_team_t *ucx_team;
@@ -143,6 +145,10 @@ typedef struct xccl_mhba_team {
143145
uint64_t max_msg_size;
144146
xccl_mhba_node_t node;
145147
xccl_mhba_net_t net;
148+
int* inter_node_barrier;
149+
int* inter_node_barrier_flag;
150+
struct ibv_mr *inter_node_barrier_mr;
151+
struct ibv_mr **net_barrier_mr;
146152
int sequence_number;
147153
int op_busy[MAX_OUTSTANDING_OPS];
148154
int cq_completions[MAX_OUTSTANDING_OPS];

0 commit comments

Comments
 (0)