Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
wenduwan committed Aug 22, 2023
1 parent 84371ec commit 6260a4e
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 59 deletions.
103 changes: 46 additions & 57 deletions ompi/mca/coll/base/coll_base_allreduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -1245,113 +1245,102 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
return err;
}

/* copied function (with appropriate renaming) ends here */

/*
 * Allreduce built from a linear all-gather plus a local reduction: post a
 * receive from and a send to every other rank, wait for all of them, then
 * fold the gathered contributions into rbuf with ompi_op_reduce().
 *
 * sbuf may be MPI_IN_PLACE, in which case rbuf is used as the send buffer.
 *
 * Returns MPI_SUCCESS, OMPI_ERR_OUT_OF_RESOURCE when the temporary buffer
 * cannot be allocated, or the error code produced by the PML calls / the
 * request wait.
 *
 * NOTE(review): this listing is a commit diff rendered without +/- markers,
 * so removed and added lines are interleaved (two signatures, repeated
 * declarations). Check it against the repository before relying on it.
 */
int ompi_coll_base_allreduce_intra_allgatherreduce(const void *sbuf, void *rbuf, int count,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf, int count,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
{
int err = MPI_SUCCESS;
char *send_buf = sbuf;
int comm_size = ompi_comm_size(comm);
int err = MPI_SUCCESS;
int rank = ompi_comm_rank(comm);
bool commutative = ompi_op_is_commute(op);
ompi_request_t **reqs;

char *send_buf = sbuf;
/* in place handling */
if (sbuf == MPI_IN_PLACE) {
send_buf = rbuf;
} else {
/* copy over our elements into the final accumulation buffer */
ompi_datatype_copy_content_same_ddt(dtype, count, (char *) rbuf, (char *) send_buf);
}

/* we will need a temporary receive buffer that is large enough to fit everything */
/* Allocate a large-enough buffer to receive from everyone else */
char *tmp_buf = NULL, *tmp_buf_raw = NULL, *tmp_recv = NULL;
ptrdiff_t lb, extent, dsize, gap = 0;
ompi_datatype_get_extent(dtype, &lb, &extent);
/* size the buffer for count elements from each of the comm_size ranks */
dsize = opal_datatype_span(&dtype->super, count * comm_size, &gap);

/* Temporary buffer for receiving messages */
char *tmp_buf = NULL;
char *tmp_buf_raw = (char *) malloc(dsize);
if (NULL == tmp_buf_raw)
tmp_buf_raw = (char *) malloc(dsize);
if (NULL == tmp_buf_raw) {
return OMPI_ERR_OUT_OF_RESOURCE;
}

/* commutative op: rbuf starts out holding this rank's own contribution */
if (commutative) {
ompi_datatype_copy_content_same_ddt(dtype, count, (char *) rbuf, (char *) send_buf);
}

/* offset the working pointer by the gap that opal_datatype_span() reported */
tmp_buf = tmp_buf_raw - gap;

/* Requests for send to AND receive from everyone else */
int reqs_needed = (comm_size - 1) * 2;
reqs = ompi_coll_base_comm_get_reqs(module->base_data, reqs_needed);

/* We need to send our data to all other peers in our comm */

/* byte stride between two ranks' slots in tmp_buf */
ptrdiff_t incr = extent * (ptrdiff_t) count;
char *tmp_recv = (char *) tmp_buf;
int line;
tmp_recv = (char *) tmp_buf;

int peer_rank;
int req_index = 0;
for (int index = 0; index < comm_size; index++) {
peer_rank = index;
/* don't send to myself */
for (int peer_rank = 0; peer_rank < comm_size; peer_rank++) {
/* slot reserved for peer_rank's contribution */
tmp_recv = tmp_buf + (peer_rank * incr);
if (peer_rank == rank) {
/* our own data is copied into our slot directly; no message is posted */
/* NOTE(review): a raw memcpy of incr bytes presumably assumes a contiguous datatype -- confirm */
memcpy(tmp_recv, send_buf, incr);
continue;
}

tmp_recv = tmp_buf + (peer_rank * incr);
err = MCA_PML_CALL(irecv(tmp_recv, count, dtype, peer_rank, MCA_COLL_BASE_TAG_ALLREDUCE,
comm, &reqs[req_index]));
comm, &reqs[req_index++]));
if (MPI_SUCCESS != err) {
line = __LINE__;
goto err_hndl;
}

req_index += 1;
}

for (int index = 0; index < comm_size; index++) {
/* start the send sequence at our own rank and wrap around the communicator */
peer_rank = (rank + index) % comm_size;

if (peer_rank == rank) {
continue;
}

err = MCA_PML_CALL(isend(send_buf, count, dtype, peer_rank, MCA_COLL_BASE_TAG_ALLREDUCE,
MCA_PML_BASE_SEND_STANDARD, comm, &reqs[req_index]));
MCA_PML_BASE_SEND_STANDARD, comm, &reqs[req_index++]));
if (MPI_SUCCESS != err) {
goto err_hndl;
}
req_index += 1;
}

/* block until every posted receive and send has completed */
err = ompi_request_wait_all(req_index, reqs, MPI_STATUSES_IGNORE);
/* NOTE(review): err is not tested here; the reduction loop below runs even if the wait failed */
char *inbuf;

char *inbuf;
/* local reduction: walk the gathered slots in rank order and fold them into rbuf */
for (int peer_rank = 0; peer_rank < comm_size; peer_rank++) {
if (peer_rank != rank) {
inbuf = tmp_buf + (peer_rank * incr);
ompi_op_reduce(op, (void *) inbuf, rbuf, count, dtype);
inbuf = tmp_buf + (peer_rank * incr);
if (0 == peer_rank && !commutative) {
/* Non-commutative op: nothing is sorted here; rank 0's slot is copied into
 * rbuf as the starting value and the remaining slots are reduced into it
 * in rank order */
memcpy(rbuf, inbuf, incr);
continue;
}
ompi_op_reduce(op, (void *) inbuf, rbuf, count, dtype);
}

/* common exit path: release the gather buffer; on failure also release the requests */
err_hndl:
if (NULL != tmp_buf_raw)
free(tmp_buf_raw);

if (MPI_SUCCESS != err) { /* Free the reqs */
/* first find the real error code */
/*
for( preq = reqs; preq < reqs+i; preq++ ) {
if (MPI_REQUEST_NULL == *preq) continue;
if (MPI_ERR_PENDING == (*preq)->req_status.MPI_ERROR) continue;
if ((*preq)->req_status.MPI_ERROR != MPI_SUCCESS) {
err = (*preq)->req_status.MPI_ERROR;
break;
if (NULL != reqs) {
if (MPI_ERR_IN_STATUS == err) {
for (int i = 0; i < reqs_needed; i++) {
if (MPI_REQUEST_NULL == reqs[i])
continue;
if (MPI_ERR_PENDING == reqs[i]->req_status.MPI_ERROR)
continue;
if (MPI_SUCCESS != reqs[i]->req_status.MPI_ERROR) {
err = reqs[i]->req_status.MPI_ERROR;
break;
}
}
}
}
*/
ompi_coll_base_free_reqs(reqs, reqs_needed);
}

/* All done */
return err;
}

/* copied function (with appropriate renaming) ends here */
2 changes: 1 addition & 1 deletion ompi/mca/coll/base/coll_base_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ int ompi_coll_base_allreduce_intra_ring(ALLREDUCE_ARGS);
int ompi_coll_base_allreduce_intra_ring_segmented(ALLREDUCE_ARGS, uint32_t segsize);
int ompi_coll_base_allreduce_intra_basic_linear(ALLREDUCE_ARGS);
int ompi_coll_base_allreduce_intra_redscat_allgather(ALLREDUCE_ARGS);
int ompi_coll_base_allreduce_intra_allgatherreduce(ALLREDUCE_ARGS);
int ompi_coll_base_allreduce_intra_allgather_reduce(ALLREDUCE_ARGS);

/* AlltoAll */
int ompi_coll_base_alltoall_intra_pairwise(ALLTOALL_ARGS);
Expand Down
2 changes: 1 addition & 1 deletion ompi/mca/coll/tuned/coll_tuned_allreduce_decision.c
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ int ompi_coll_tuned_allreduce_intra_do_this(const void *sbuf, void *rbuf, int co
case (6):
return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, comm, module);
case (7):
return ompi_coll_base_allreduce_intra_allgatherreduce(sbuf, rbuf, count, dtype, op, comm, module);
return ompi_coll_base_allreduce_intra_allgather_reduce(sbuf, rbuf, count, dtype, op, comm, module);
} /* switch */
OPAL_OUTPUT((ompi_coll_tuned_stream,"coll:tuned:allreduce_intra_do_this attempt to select algorithm %d when only 0-%d is valid?",
algorithm, ompi_coll_tuned_forced_max_algorithms[ALLREDUCE]));
Expand Down

0 comments on commit 6260a4e

Please sign in to comment.