diff --git a/ompi/mca/coll/base/coll_base_allreduce.c b/ompi/mca/coll/base/coll_base_allreduce.c
index 95468a5c885..30ab0a4f869 100644
--- a/ompi/mca/coll/base/coll_base_allreduce.c
+++ b/ompi/mca/coll/base/coll_base_allreduce.c
@@ -18,6 +18,8 @@
  * Copyright (c) 2018      Siberian State University of Telecommunications
  *                         and Information Science. All rights reserved.
  * Copyright (c) 2022      Cisco Systems, Inc. All rights reserved.
+ * Copyright (c) Amazon.com, Inc. or its affiliates.
+ *                         All rights reserved.
  * $COPYRIGHT$
  *
  * Additional copyrights may follow
@@ -1245,4 +1247,116 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
     return err;
 }
 
+/**
+ * A greedy algorithm to exchange data among processes in the communicator via
+ * an allgather pattern, followed by a local reduction on each process. This
+ * avoids the round trip in a rooted communication pattern, e.g. reduce on the
+ * root and then broadcast to peers.
+ *
+ * This algorithm supports both commutative and non-commutative MPI operations.
+ * For non-commutative operations the reduction is applied to the data in the
+ * same rank order, e.g. rank 0, rank 1, ... rank N, on each process.
+ *
+ * This algorithm benefits inter-node allreduce over a high-latency network.
+ * Caution is needed on larger communicators (n) and data sizes (m), which will
+ * result in m*n^2 total traffic and potential network congestion.
+ */
+int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf, int count,
+                                                    struct ompi_datatype_t *dtype,
+                                                    struct ompi_op_t *op,
+                                                    struct ompi_communicator_t *comm,
+                                                    mca_coll_base_module_t *module)
+{
+    char *send_buf = (void *) sbuf;
+    int comm_size = ompi_comm_size(comm);
+    int err = MPI_SUCCESS;
+    int rank = ompi_comm_rank(comm);
+    bool commutative = ompi_op_is_commute(op);
+    ompi_request_t **reqs;
+
+    if (sbuf == MPI_IN_PLACE) {
+        send_buf = rbuf;
+    }
+
+    /* Allocate a large-enough buffer to receive from everyone else */
+    char *tmp_buf = NULL, *tmp_buf_raw = NULL, *tmp_recv = NULL;
+    ptrdiff_t lb, extent, dsize, gap = 0;
+    ompi_datatype_get_extent(dtype, &lb, &extent);
+    dsize = opal_datatype_span(&dtype->super, count * comm_size, &gap);
+    tmp_buf_raw = (char *) malloc(dsize);
+    if (NULL == tmp_buf_raw) {
+        return OMPI_ERR_OUT_OF_RESOURCE;
+    }
+
+    if (commutative) {
+        ompi_datatype_copy_content_same_ddt(dtype, count, (char *) rbuf, (char *) send_buf);
+    }
+
+    tmp_buf = tmp_buf_raw - gap;
+
+    /* Requests for send to AND receive from everyone else */
+    int reqs_needed = (comm_size - 1) * 2;
+    reqs = ompi_coll_base_comm_get_reqs(module->base_data, reqs_needed);
+
+    ptrdiff_t incr = extent * count;
+    tmp_recv = (char *) tmp_buf;
+
+    /* Exchange data with peer processes */
+    int req_index = 0, peer_rank = 0;
+    for (int i = 1; i < comm_size; ++i) {
+        peer_rank = (rank + i) % comm_size;
+        tmp_recv = tmp_buf + (peer_rank * incr);
+        err = MCA_PML_CALL(irecv(tmp_recv, count, dtype, peer_rank, MCA_COLL_BASE_TAG_ALLREDUCE,
+                                 comm, &reqs[req_index++]));
+        if (MPI_SUCCESS != err) {
+            goto err_hndl;
+        }
+
+        err = MCA_PML_CALL(isend(send_buf, count, dtype, peer_rank, MCA_COLL_BASE_TAG_ALLREDUCE,
+                                 MCA_PML_BASE_SEND_STANDARD, comm, &reqs[req_index++]));
+        if (MPI_SUCCESS != err) {
+            goto err_hndl;
+        }
+    }
+
+    err = ompi_request_wait_all(req_index, reqs, MPI_STATUSES_IGNORE);
+
+    /* Prepare for local reduction */
+    peer_rank = 0;
+    if (!commutative) {
+        /* For non-commutative operations, ensure the reduction always starts from Rank 0's data */
+        memcpy(rbuf, 0 == rank ? send_buf : tmp_buf, incr);
+        peer_rank = 1;
+    }
+
+    char *inbuf;
+    for (; peer_rank < comm_size; peer_rank++) {
+        inbuf = rank == peer_rank ? send_buf : tmp_buf + (peer_rank * incr);
+        ompi_op_reduce(op, (void *) inbuf, rbuf, count, dtype);
+    }
+
+err_hndl:
+    if (NULL != tmp_buf_raw)
+        free(tmp_buf_raw);
+
+    if (NULL != reqs) {
+        if (MPI_ERR_IN_STATUS == err) {
+            for (int i = 0; i < reqs_needed; i++) {
+                if (MPI_REQUEST_NULL == reqs[i])
+                    continue;
+                if (MPI_ERR_PENDING == reqs[i]->req_status.MPI_ERROR)
+                    continue;
+                if (MPI_SUCCESS != reqs[i]->req_status.MPI_ERROR) {
+                    err = reqs[i]->req_status.MPI_ERROR;
+                    break;
+                }
+            }
+        }
+        ompi_coll_base_free_reqs(reqs, reqs_needed);
+    }
+
+    /* All done */
+    return err;
+}
+
 /* copied function (with appropriate renaming) ends here */
diff --git a/ompi/mca/coll/base/coll_base_functions.h b/ompi/mca/coll/base/coll_base_functions.h
index 32714445904..1c73d01d37e 100644
--- a/ompi/mca/coll/base/coll_base_functions.h
+++ b/ompi/mca/coll/base/coll_base_functions.h
@@ -210,6 +210,7 @@ int ompi_coll_base_allreduce_intra_ring(ALLREDUCE_ARGS);
 int ompi_coll_base_allreduce_intra_ring_segmented(ALLREDUCE_ARGS, uint32_t segsize);
 int ompi_coll_base_allreduce_intra_basic_linear(ALLREDUCE_ARGS);
 int ompi_coll_base_allreduce_intra_redscat_allgather(ALLREDUCE_ARGS);
+int ompi_coll_base_allreduce_intra_allgather_reduce(ALLREDUCE_ARGS);
 
 /* AlltoAll */
 int ompi_coll_base_alltoall_intra_pairwise(ALLTOALL_ARGS);
diff --git a/ompi/mca/coll/tuned/coll_tuned_allreduce_decision.c b/ompi/mca/coll/tuned/coll_tuned_allreduce_decision.c
index eabe6f17378..3711cdb8eb1 100644
--- a/ompi/mca/coll/tuned/coll_tuned_allreduce_decision.c
+++ b/ompi/mca/coll/tuned/coll_tuned_allreduce_decision.c
@@ -42,6 +42,7 @@ static const mca_base_var_enum_value_t allreduce_algorithms[] = {
     {4, "ring"},
     {5, "segmented_ring"},
     {6, "rabenseifner"},
+    {7, "allgather_reduce"},
     {0, NULL}
 };
 
@@ -146,6 +147,8 @@ int ompi_coll_tuned_allreduce_intra_do_this(const void *sbuf, void *rbuf, int co
         return ompi_coll_base_allreduce_intra_ring_segmented(sbuf, rbuf, count, dtype, op, comm, module, segsize);
     case (6):
         return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, comm, module);
+    case (7):
+        return ompi_coll_base_allreduce_intra_allgather_reduce(sbuf, rbuf, count, dtype, op, comm, module);
     } /* switch */
     OPAL_OUTPUT((ompi_coll_tuned_stream,"coll:tuned:allreduce_intra_do_this attempt to select algorithm %d when only 0-%d is valid?",
                  algorithm, ompi_coll_tuned_forced_max_algorithms[ALLREDUCE]));
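
Testing note: the new algorithm can be selected at runtime through the tuned component's
existing dynamic-rules switches, coll_tuned_use_dynamic_rules and
coll_tuned_allreduce_algorithm (value 7 maps to "allgather_reduce" in the enum added
above). Below is a minimal sketch of a check program; the file name allreduce_check.c
and the expected-sum verification are illustrative only and are not part of this patch.

/* allreduce_check.c: minimal MPI_Allreduce smoke test (illustrative, not part
 * of this patch). Each rank contributes its rank number; the reduced sum on
 * every rank should equal size*(size-1)/2. */
#include <mpi.h>
#include <stdio.h>

int main(int argc, char **argv)
{
    int rank, size, in, out;

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    /* Reduce each rank's id with MPI_SUM; every rank gets the same result. */
    in = rank;
    MPI_Allreduce(&in, &out, 1, MPI_INT, MPI_SUM, MPI_COMM_WORLD);

    printf("rank %d: sum = %d (expected %d)\n", rank, out, size * (size - 1) / 2);

    MPI_Finalize();
    return 0;
}

Compiled with "mpicc allreduce_check.c -o allreduce_check", the new path can then be
forced with:

    mpirun -n 8 --mca coll_tuned_use_dynamic_rules 1 \
           --mca coll_tuned_allreduce_algorithm 7 ./allreduce_check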