Adds distributed row gatherer #1589
base: neighborhood-communicator
Changes from all commits: 32a35a4, 88629fc, 84dd51b, 16830d8, 39539a2, d87725f, 3048f5e, 08c1f4e
@@ -0,0 +1,261 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include "ginkgo/core/distributed/row_gatherer.hpp"

#include <ginkgo/core/base/dense_cache.hpp>
#include <ginkgo/core/base/precision_dispatch.hpp>
#include <ginkgo/core/distributed/dense_communicator.hpp>
#include <ginkgo/core/distributed/neighborhood_communicator.hpp>
#include <ginkgo/core/matrix/dense.hpp>

#include "core/base/dispatch_helper.hpp"

namespace gko {
namespace experimental {
namespace distributed {


#if GINKGO_HAVE_OPENMPI_PRE_4_1_X
using DefaultCollComm = mpi::DenseCommunicator;
#else
using DefaultCollComm = mpi::NeighborhoodCommunicator;
#endif


template <typename LocalIndexType>
void RowGatherer<LocalIndexType>::apply_impl(const LinOp* b, LinOp* x) const
{
    apply_async(b, x).wait();
}


template <typename LocalIndexType>
void RowGatherer<LocalIndexType>::apply_impl(const LinOp* alpha, const LinOp* b,
                                             const LinOp* beta, LinOp* x) const
    GKO_NOT_IMPLEMENTED;


template <typename LocalIndexType>
mpi::request RowGatherer<LocalIndexType>::apply_async(ptr_param<const LinOp> b,
                                                      ptr_param<LinOp> x) const
{
    int is_inactive;
    MPI_Status status;
    GKO_ASSERT_NO_MPI_ERRORS(
        MPI_Request_get_status(req_listener_, &is_inactive, &status));

Review comment: Can we maybe move this MPI function into […]
Reply: That doesn't really work here, since this function would be a member function of […]
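As an aside to this exchange, a minimal sketch of what factoring the status query out into a standalone helper could look like; the helper name is hypothetical and not part of the PR, and it assumes the surrounding file's MPI headers and error-checking macro are available:

// Hypothetical helper (illustrative only): wraps MPI_Request_get_status and
// returns true if `req` has no incomplete operation attached, i.e. it is
// MPI_REQUEST_NULL or the associated communication has already completed.
inline bool is_request_inactive(MPI_Request req)
{
    int completed;
    MPI_Status status;
    GKO_ASSERT_NO_MPI_ERRORS(MPI_Request_get_status(req, &completed, &status));
    return completed != 0;
}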
    // This is untestable. Some processes might complete the previous request
    // while others don't, so it's impossible to create a predictable behavior
    // for a test.
    GKO_THROW_IF_INVALID(is_inactive,
                         "Tried to call RowGatherer::apply_async while there "
                         "is already an active communication. Please use the "
                         "overload with a workspace to handle multiple "
                         "connections.");

    auto req = apply_async(b, x, send_workspace_);
    req_listener_ = *req.get();
    return req;
}


template <typename LocalIndexType>
mpi::request RowGatherer<LocalIndexType>::apply_async(
    ptr_param<const LinOp> b, ptr_param<LinOp> x, array<char>& workspace) const
{
    mpi::request req;

    // dispatch global vector
    run<Vector, double, float, std::complex<double>, std::complex<float>>(
        b.get(), [&](const auto* b_global) {
            using ValueType =
                typename std::decay_t<decltype(*b_global)>::value_type;
            // dispatch local vector with the same precision as the global
            // vector
            ::gko::precision_dispatch<ValueType>(
                [&](auto* x_local) {
                    auto exec = this->get_executor();

                    auto use_host_buffer = mpi::requires_host_buffer(
                        exec, coll_comm_->get_base_communicator());
                    auto mpi_exec = use_host_buffer ? exec->get_master() : exec;

                    GKO_THROW_IF_INVALID(
                        !use_host_buffer || mpi_exec->memory_accessible(
                                                x_local->get_executor()),
                        "The receive buffer uses device memory, but MPI "
                        "support of device memory is not available or host "
                        "buffers were explicitly requested. Please provide a "
                        "host buffer or enable MPI support for device memory.");

                    auto b_local = b_global->get_local_vector();

                    dim<2> send_size(coll_comm_->get_send_size(),
                                     b_local->get_size()[1]);
                    auto send_size_in_bytes =
                        sizeof(ValueType) * send_size[0] * send_size[1];
                    workspace.set_executor(mpi_exec);
                    if (send_size_in_bytes > workspace.get_size()) {
                        workspace.resize_and_reset(sizeof(ValueType) *
                                                   send_size[0] * send_size[1]);
                    }

Comment on lines +98 to +102
Review comment: combining them to assign the workspace directly?
Reply: Combine how? Do you mean like […]
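For illustration, a hedged sketch of what assigning the workspace directly might look like (not the PR's code); unlike the conditional resize above, this reallocates even when the existing workspace is already large enough, which is presumably why the PR keeps the size check:

// Illustrative alternative: construct an array of exactly the required size
// on the MPI executor and move-assign it into the caller-provided workspace.
workspace = array<char>(mpi_exec, send_size_in_bytes);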
                    auto send_buffer = matrix::Dense<ValueType>::create(
                        mpi_exec, send_size,
                        make_array_view(
                            mpi_exec, send_size[0] * send_size[1],
                            reinterpret_cast<ValueType*>(workspace.get_data())),
                        send_size[1]);
                    b_local->row_gather(&send_idxs_, send_buffer);

                    auto recv_ptr = x_local->get_values();
                    auto send_ptr = send_buffer->get_values();

                    b_local->get_executor()->synchronize();
                    mpi::contiguous_type type(
                        b_local->get_size()[1],
                        mpi::type_impl<ValueType>::get_type());
                    req = coll_comm_->i_all_to_all_v(
                        mpi_exec, send_ptr, type.get(), recv_ptr, type.get());

Comment on lines +118 to +119
Review comment: send_buffer might be on the host, but the recv_ptr (x_local) might be on the device.
Reply: I have a check above to ensure that the memory space of the recv buffer is accessible from the MPI executor. So if GPU-aware MPI is used, it should work (even if the send buffer is on the host and the recv buffer is on the device, or vice versa). Otherwise an exception will be thrown.
                },
                x.get());
        });
    return req;
}


template <typename LocalIndexType>
std::shared_ptr<const mpi::CollectiveCommunicator>
RowGatherer<LocalIndexType>::get_collective_communicator() const
{
    return coll_comm_;
}


template <typename LocalIndexType>
template <typename GlobalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(
    std::shared_ptr<const Executor> exec,
    std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm,
    const index_map<LocalIndexType, GlobalIndexType>& imap)
    : EnableLinOp<RowGatherer>(
          exec, dim<2>{imap.get_non_local_size(), imap.get_global_size()}),
      DistributedBase(coll_comm->get_base_communicator()),
      coll_comm_(std::move(coll_comm)),
      send_idxs_(exec),
      send_workspace_(exec),
      req_listener_(MPI_REQUEST_NULL)
{
    // check that the coll_comm_ and imap have the same recv size
    // the same check for the send size is not possible, since the
    // imap doesn't store send indices
    GKO_THROW_IF_INVALID(
        coll_comm_->get_recv_size() == imap.get_non_local_size(),
        "The collective communicator doesn't match the index map.");

    auto comm = coll_comm_->get_base_communicator();
    auto inverse_comm = coll_comm_->create_inverse();

    send_idxs_.resize_and_reset(coll_comm_->get_send_size());
    inverse_comm
        ->i_all_to_all_v(exec,
                         imap.get_remote_local_idxs().get_const_flat_data(),
                         send_idxs_.get_data())
        .wait();
}


template <typename LocalIndexType>
const LocalIndexType* RowGatherer<LocalIndexType>::get_const_row_idxs() const
{
    return send_idxs_.get_const_data();
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(std::shared_ptr<const Executor> exec,
                                         mpi::communicator comm)
    : EnableLinOp<RowGatherer>(exec),
      DistributedBase(comm),
      coll_comm_(std::make_shared<DefaultCollComm>(comm)),
      send_idxs_(exec),
      send_workspace_(exec),
      req_listener_(MPI_REQUEST_NULL)
{}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(RowGatherer&& o) noexcept
    : EnableLinOp<RowGatherer>(o.get_executor()),
      DistributedBase(o.get_communicator()),
      send_idxs_(o.get_executor()),
      send_workspace_(o.get_executor()),
      req_listener_(MPI_REQUEST_NULL)
{
    *this = std::move(o);
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=(
    const RowGatherer& o)
{
    if (this != &o) {
        this->set_size(o.get_size());
        coll_comm_ = o.coll_comm_;
        send_idxs_ = o.send_idxs_;
    }
    return *this;
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=(
    RowGatherer&& o)
{
    if (this != &o) {
        this->set_size(o.get_size());
        o.set_size({});
        coll_comm_ = std::exchange(
            o.coll_comm_,
            std::make_shared<DefaultCollComm>(o.get_communicator()));
        send_idxs_ = std::move(o.send_idxs_);
        send_workspace_ = std::move(o.send_workspace_);
        req_listener_ = std::exchange(o.req_listener_, MPI_REQUEST_NULL);
    }
    return *this;
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(const RowGatherer& o)
    : EnableLinOp<RowGatherer>(o.get_executor()),
      DistributedBase(o.get_communicator()),
      send_idxs_(o.get_executor())
{
    *this = o;
}


#define GKO_DECLARE_ROW_GATHERER(_itype) class RowGatherer<_itype>

GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_ROW_GATHERER);

#undef GKO_DECLARE_ROW_GATHERER


#define GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR(_ltype, _gtype)           \
    RowGatherer<_ltype>::RowGatherer(                                  \
        std::shared_ptr<const Executor> exec,                          \
        std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm,  \
        const index_map<_ltype, _gtype>& imap)

GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE(
    GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR);

#undef GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR


}  // namespace distributed
}  // namespace experimental
}  // namespace gko
@@ -264,18 +264,15 @@ array<GlobalIndexType> Pgm<ValueType, IndexType>::communicate_non_local_agg(
 {
     auto exec = gko::as<LinOp>(matrix)->get_executor();
     const auto comm = matrix->get_communicator();
-    auto send_sizes = matrix->send_sizes_;
-    auto recv_sizes = matrix->recv_sizes_;
-    auto send_offsets = matrix->send_offsets_;
-    auto recv_offsets = matrix->recv_offsets_;
-    auto gather_idxs = matrix->gather_idxs_;
-    auto total_send_size = send_offsets.back();
-    auto total_recv_size = recv_offsets.back();
+    auto coll_comm = matrix->row_gatherer_->get_collective_communicator();
+    auto total_send_size = coll_comm->get_send_size();
+    auto total_recv_size = coll_comm->get_recv_size();
+    auto row_gatherer = matrix->row_gatherer_;

     array<IndexType> send_agg(exec, total_send_size);
     exec->run(pgm::make_gather_index(
         send_agg.get_size(), local_agg.get_const_data(),
-        gather_idxs.get_const_data(), send_agg.get_data()));
+        row_gatherer->get_const_row_idxs(), send_agg.get_data()));

     // temporary index map that contains no remote connections to map
     // local indices to global

@@ -296,16 +293,16 @@ array<GlobalIndexType> Pgm<ValueType, IndexType>::communicate_non_local_agg(
                                       seng_global_agg.get_data(),
                                       host_send_buffer.get_data());
     }
-    auto type = experimental::mpi::type_impl<GlobalIndexType>::get_type();

     const auto send_ptr = use_host_buffer ? host_send_buffer.get_const_data()
                                           : seng_global_agg.get_const_data();
     auto recv_ptr = use_host_buffer ? host_recv_buffer.get_data()
                                     : non_local_agg.get_data();
     exec->synchronize();
-    comm.all_to_all_v(use_host_buffer ? exec->get_master() : exec, send_ptr,
-                      send_sizes.data(), send_offsets.data(), type, recv_ptr,
-                      recv_sizes.data(), recv_offsets.data(), type);
+    coll_comm
+        ->i_all_to_all_v(use_host_buffer ? exec->get_master() : exec, send_ptr,
+                         recv_ptr)
+        .wait();

Review comment (on the new i_all_to_all_v call): any difference between using all_to_all_v vs i_all_to_all_v? I assume all_to_all_v would also require updating the interface.
Reply: […]

     if (use_host_buffer) {
         exec->copy_from(exec->get_master(), total_recv_size, recv_ptr,
                         non_local_agg.get_data());
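To make the question above concrete, here are the two call shapes from this diff side by side (a sketch reusing the variables from communicate_non_local_agg): the old variant passes explicit sizes, offsets, and an MPI datatype to the plain communicator, while the collective communicator stores that layout internally; its nonblocking call followed by wait() behaves like a blocking exchange.

// Old interface: explicit counts, offsets, and datatype, blocking call.
comm.all_to_all_v(use_host_buffer ? exec->get_master() : exec, send_ptr,
                  send_sizes.data(), send_offsets.data(), type, recv_ptr,
                  recv_sizes.data(), recv_offsets.data(), type);

// New interface: the collective communicator already knows the layout; the
// nonblocking call plus wait() is equivalent to a blocking exchange.
coll_comm
    ->i_all_to_all_v(use_host_buffer ? exec->get_master() : exec, send_ptr,
                     recv_ptr)
    .wait();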
Review comment: I think you can also implement the advanced apply by replacing b_local->row_gather(idxs, buffer) with b_local->row_gather(alpha, idxs, beta, buffer)?
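For reference, a hedged sketch of what such an advanced apply (x = alpha * gather(b) + beta * x) might look like. It dispatches x as a local Dense, reuses the existing async apply to fill a temporary buffer, and combines the result afterwards with scale/add_scaled; the temporary buffer and this combination strategy are illustrative assumptions, not the PR's implementation.

// Sketch only: one possible advanced apply built on top of apply_async.
template <typename LocalIndexType>
void RowGatherer<LocalIndexType>::apply_impl(const LinOp* alpha, const LinOp* b,
                                             const LinOp* beta, LinOp* x) const
{
    run<matrix::Dense, double, float, std::complex<double>,
        std::complex<float>>(x, [&](auto* x_dense) {
        using ValueType =
            typename std::decay_t<decltype(*x_dense)>::value_type;
        // gather b into a temporary dense buffer with the same shape as x
        auto gathered = matrix::Dense<ValueType>::create(
            x_dense->get_executor(), x_dense->get_size());
        this->apply_async(b, gathered.get()).wait();
        // x = beta * x + alpha * gathered
        x_dense->scale(beta);
        x_dense->add_scaled(alpha, gathered.get());
    });
}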