Adds distributed row gatherer #1589

Open · wants to merge 8 commits into base: neighborhood-communicator
1 change: 1 addition & 0 deletions core/CMakeLists.txt
@@ -147,6 +147,7 @@ if(GINKGO_BUILD_MPI)
     distributed/matrix.cpp
     distributed/neighborhood_communicator.cpp
     distributed/partition_helpers.cpp
+    distributed/row_gatherer.cpp
     distributed/vector.cpp
     distributed/preconditioner/schwarz.cpp)
 endif()
234 changes: 66 additions & 168 deletions core/distributed/matrix.cpp

Large diffs are not rendered by default.

261 changes: 261 additions & 0 deletions core/distributed/row_gatherer.cpp
@@ -0,0 +1,261 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include "ginkgo/core/distributed/row_gatherer.hpp"

#include <ginkgo/core/base/dense_cache.hpp>
#include <ginkgo/core/base/precision_dispatch.hpp>
#include <ginkgo/core/distributed/dense_communicator.hpp>
#include <ginkgo/core/distributed/neighborhood_communicator.hpp>
#include <ginkgo/core/matrix/dense.hpp>

#include "core/base/dispatch_helper.hpp"

namespace gko {
namespace experimental {
namespace distributed {


#if GINKGO_HAVE_OPENMPI_PRE_4_1_X
using DefaultCollComm = mpi::DenseCommunicator;
#else
using DefaultCollComm = mpi::NeighborhoodCommunicator;
#endif


template <typename LocalIndexType>
void RowGatherer<LocalIndexType>::apply_impl(const LinOp* b, LinOp* x) const
{
    apply_async(b, x).wait();
}


template <typename LocalIndexType>
void RowGatherer<LocalIndexType>::apply_impl(const LinOp* alpha, const LinOp* b,
                                             const LinOp* beta, LinOp* x) const
    GKO_NOT_IMPLEMENTED;
Reviewer (Member): I think you can also implement the advanced apply by replacing b_local->row_gather(idxs, buffer) with b_local->row_gather(alpha, idxs, beta, buffer)?
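For reference, a hedged sketch of the two local gather calls being compared in this thread (variable names taken from the apply_async implementation below; whether the advanced variant alone is enough for the advanced apply is left open in the discussion):

// current PR: plain gather of the needed rows into the send buffer
b_local->row_gather(&send_idxs_, send_buffer);
// advanced variant suggested above: additionally scale with alpha and beta
b_local->row_gather(alpha, &send_idxs_, beta, send_buffer);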



template <typename LocalIndexType>
mpi::request RowGatherer<LocalIndexType>::apply_async(ptr_param<const LinOp> b,
                                                      ptr_param<LinOp> x) const
{
    int is_inactive;
    MPI_Status status;
    GKO_ASSERT_NO_MPI_ERRORS(
        MPI_Request_get_status(req_listener_, &is_inactive, &status));
Reviewer (Member): Can we maybe move this MPI function into mpi.hpp and create a wrapper for it?

Author (Member): That doesn't really work here, since this function would be a member function of request, but I'm using a bare MPI_Request (and can't use request, because it will try to free the request in the destructor), so it would not be applicable.

    // This is untestable. Some processes might complete the previous request
    // while others don't, so it's impossible to create a predictable behavior
    // for a test.
    GKO_THROW_IF_INVALID(is_inactive,
                         "Tried to call RowGatherer::apply_async while there "
                         "is already an active communication. Please use the "
                         "overload with a workspace to handle multiple "
                         "connections.");

    auto req = apply_async(b, x, send_workspace_);
    req_listener_ = *req.get();
    return req;
}


template <typename LocalIndexType>
mpi::request RowGatherer<LocalIndexType>::apply_async(
    ptr_param<const LinOp> b, ptr_param<LinOp> x, array<char>& workspace) const
{
    mpi::request req;

    // dispatch global vector
    run<Vector, double, float, std::complex<double>, std::complex<float>>(
        b.get(), [&](const auto* b_global) {
            using ValueType =
                typename std::decay_t<decltype(*b_global)>::value_type;
            // dispatch local vector with the same precision as the global
            // vector
            ::gko::precision_dispatch<ValueType>(
                [&](auto* x_local) {
                    auto exec = this->get_executor();

                    auto use_host_buffer = mpi::requires_host_buffer(
                        exec, coll_comm_->get_base_communicator());
                    auto mpi_exec = use_host_buffer ? exec->get_master() : exec;

                    GKO_THROW_IF_INVALID(
                        !use_host_buffer || mpi_exec->memory_accessible(
                                                x_local->get_executor()),
                        "The receive buffer uses device memory, but MPI "
                        "support of device memory is not available or host "
                        "buffers were explicitly requested. Please provide a "
                        "host buffer or enable MPI support for device memory.");

                    auto b_local = b_global->get_local_vector();

                    dim<2> send_size(coll_comm_->get_send_size(),
                                     b_local->get_size()[1]);
                    auto send_size_in_bytes =
                        sizeof(ValueType) * send_size[0] * send_size[1];
                    workspace.set_executor(mpi_exec);
                    if (send_size_in_bytes > workspace.get_size()) {
                        workspace.resize_and_reset(sizeof(ValueType) *
                                                   send_size[0] * send_size[1]);
                    }
Comment on lines +98 to +102

Reviewer (Member): combining them to assign the workspace directly?

Author (Member): Combine how? Do you mean like

workspace = array<char>(mpi_exec, sizeof(ValueType) * send_size[0] * send_size[1]);
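One possible reading of the trade-off (my sketch, not stated in the thread): the current code keeps and reuses an already-allocated workspace when it is large enough, whereas a direct assignment would allocate a fresh buffer on every call.

// as in the PR: only (re)allocate when the existing workspace is too small
if (send_size_in_bytes > workspace.get_size()) {
    workspace.resize_and_reset(send_size_in_bytes);
}

// suggested alternative: always construct a new array of exactly the needed size
workspace = array<char>(mpi_exec, send_size_in_bytes);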

                    auto send_buffer = matrix::Dense<ValueType>::create(
                        mpi_exec, send_size,
                        make_array_view(
                            mpi_exec, send_size[0] * send_size[1],
                            reinterpret_cast<ValueType*>(workspace.get_data())),
                        send_size[1]);
                    b_local->row_gather(&send_idxs_, send_buffer);

                    auto recv_ptr = x_local->get_values();
                    auto send_ptr = send_buffer->get_values();

                    b_local->get_executor()->synchronize();
                    mpi::contiguous_type type(
                        b_local->get_size()[1],
                        mpi::type_impl<ValueType>::get_type());
                    req = coll_comm_->i_all_to_all_v(
                        mpi_exec, send_ptr, type.get(), recv_ptr, type.get());
Comment on lines +118 to +119

Reviewer (Member): send_buffer might be on the host, but the recv_ptr (x_local) might be on the device.

Author (Member): I have a check above to ensure that the memory space of the recv buffer is accessible from the MPI executor. So if GPU-aware MPI is used, it should work (even if the send buffer is on the host and the recv buffer on the device, or vice versa). Otherwise an exception will be thrown.

                },
                x.get());
        });
    return req;
}
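To illustrate how the two apply_async overloads above are intended to be used, a hypothetical caller-side sketch (the row_gatherer, vectors, and executor are assumed to be set up elsewhere; this is not code from the PR):

// single outstanding communication, tracked by the internal request listener
auto req = row_gatherer->apply_async(b, x);
// ... overlap local work with the communication here ...
req.wait();

// several outstanding communications: the caller provides one workspace each
gko::array<char> workspace1(exec);
gko::array<char> workspace2(exec);
auto req1 = row_gatherer->apply_async(b1, x1, workspace1);
auto req2 = row_gatherer->apply_async(b2, x2, workspace2);
req1.wait();
req2.wait();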


template <typename LocalIndexType>
std::shared_ptr<const mpi::CollectiveCommunicator>
RowGatherer<LocalIndexType>::get_collective_communicator() const
{
    return coll_comm_;
}


template <typename LocalIndexType>
template <typename GlobalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(
    std::shared_ptr<const Executor> exec,
    std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm,
    const index_map<LocalIndexType, GlobalIndexType>& imap)
    : EnableLinOp<RowGatherer>(
          exec, dim<2>{imap.get_non_local_size(), imap.get_global_size()}),
      DistributedBase(coll_comm->get_base_communicator()),
      coll_comm_(std::move(coll_comm)),
      send_idxs_(exec),
      send_workspace_(exec),
      req_listener_(MPI_REQUEST_NULL)
{
    // check that the coll_comm_ and imap have the same recv size
    // the same check for the send size is not possible, since the
    // imap doesn't store send indices
    GKO_THROW_IF_INVALID(
        coll_comm_->get_recv_size() == imap.get_non_local_size(),
        "The collective communicator doesn't match the index map.");

    auto comm = coll_comm_->get_base_communicator();
    auto inverse_comm = coll_comm_->create_inverse();

    send_idxs_.resize_and_reset(coll_comm_->get_send_size());
    inverse_comm
        ->i_all_to_all_v(exec,
                         imap.get_remote_local_idxs().get_const_flat_data(),
                         send_idxs_.get_data())
        .wait();
}
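A small hypothetical example of what the inverse communication in this constructor achieves (my illustration of the code above, not from the PR):

// Two ranks, made-up numbers:
// rank 0 needs rows {3, 5} owned by rank 1, so on rank 0
//     imap.get_remote_local_idxs() == {3, 5}   // indices in rank 1's local numbering
// The inverse communicator swaps the send/recv pattern of coll_comm_, so the
// i_all_to_all_v above delivers {3, 5} to rank 1, where they end up as
//     send_idxs_ == {3, 5}
// i.e. exactly the rows rank 1 later gathers into its send buffer in apply_async.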


template <typename LocalIndexType>
const LocalIndexType* RowGatherer<LocalIndexType>::get_const_row_idxs() const
{
    return send_idxs_.get_const_data();
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(std::shared_ptr<const Executor> exec,
                                         mpi::communicator comm)
    : EnableLinOp<RowGatherer>(exec),
      DistributedBase(comm),
      coll_comm_(std::make_shared<DefaultCollComm>(comm)),
      send_idxs_(exec),
      send_workspace_(exec),
      req_listener_(MPI_REQUEST_NULL)
{}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(RowGatherer&& o) noexcept
    : EnableLinOp<RowGatherer>(o.get_executor()),
      DistributedBase(o.get_communicator()),
      send_idxs_(o.get_executor()),
      send_workspace_(o.get_executor()),
      req_listener_(MPI_REQUEST_NULL)
{
    *this = std::move(o);
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=(
    const RowGatherer& o)
{
    if (this != &o) {
        this->set_size(o.get_size());
        coll_comm_ = o.coll_comm_;
        send_idxs_ = o.send_idxs_;
    }
    return *this;
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=(
    RowGatherer&& o)
{
    if (this != &o) {
        this->set_size(o.get_size());
        o.set_size({});
        coll_comm_ = std::exchange(
            o.coll_comm_,
            std::make_shared<DefaultCollComm>(o.get_communicator()));
        send_idxs_ = std::move(o.send_idxs_);
        send_workspace_ = std::move(o.send_workspace_);
        req_listener_ = std::exchange(o.req_listener_, MPI_REQUEST_NULL);
    }
    return *this;
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(const RowGatherer& o)
    : EnableLinOp<RowGatherer>(o.get_executor()),
      DistributedBase(o.get_communicator()),
      send_idxs_(o.get_executor())
{
    *this = o;
}


#define GKO_DECLARE_ROW_GATHERER(_itype) class RowGatherer<_itype>

GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_ROW_GATHERER);

#undef GKO_DECLARE_ROW_GATHERER


#define GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR(_ltype, _gtype) \
    RowGatherer<_ltype>::RowGatherer(                                  \
        std::shared_ptr<const Executor> exec,                          \
        std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm,  \
        const index_map<_ltype, _gtype>& imap)

GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE(
    GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR);

#undef GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR


} // namespace distributed
} // namespace experimental
} // namespace gko
21 changes: 9 additions & 12 deletions core/multigrid/pgm.cpp
@@ -264,18 +264,15 @@ array<GlobalIndexType> Pgm<ValueType, IndexType>::communicate_non_local_agg(
 {
     auto exec = gko::as<LinOp>(matrix)->get_executor();
     const auto comm = matrix->get_communicator();
-    auto send_sizes = matrix->send_sizes_;
-    auto recv_sizes = matrix->recv_sizes_;
-    auto send_offsets = matrix->send_offsets_;
-    auto recv_offsets = matrix->recv_offsets_;
-    auto gather_idxs = matrix->gather_idxs_;
-    auto total_send_size = send_offsets.back();
-    auto total_recv_size = recv_offsets.back();
+    auto coll_comm = matrix->row_gatherer_->get_collective_communicator();
+    auto total_send_size = coll_comm->get_send_size();
+    auto total_recv_size = coll_comm->get_recv_size();
+    auto row_gatherer = matrix->row_gatherer_;

     array<IndexType> send_agg(exec, total_send_size);
     exec->run(pgm::make_gather_index(
         send_agg.get_size(), local_agg.get_const_data(),
-        gather_idxs.get_const_data(), send_agg.get_data()));
+        row_gatherer->get_const_row_idxs(), send_agg.get_data()));

     // temporary index map that contains no remote connections to map
     // local indices to global
@@ -296,16 +293,16 @@
                                       seng_global_agg.get_data(),
                                       host_send_buffer.get_data());
     }
-    auto type = experimental::mpi::type_impl<GlobalIndexType>::get_type();

     const auto send_ptr = use_host_buffer ? host_send_buffer.get_const_data()
                                           : seng_global_agg.get_const_data();
     auto recv_ptr = use_host_buffer ? host_recv_buffer.get_data()
                                     : non_local_agg.get_data();
     exec->synchronize();
-    comm.all_to_all_v(use_host_buffer ? exec->get_master() : exec, send_ptr,
-                      send_sizes.data(), send_offsets.data(), type, recv_ptr,
-                      recv_sizes.data(), recv_offsets.data(), type);
+    coll_comm
+        ->i_all_to_all_v(use_host_buffer ? exec->get_master() : exec, send_ptr,
Reviewer (Member): Any difference between using all_to_all_v vs i_all_to_all_v? I assume all_to_all_v would also use the updated interface.

Author (Member): all_to_all_v is a blocking call, while i_all_to_all_v is non-blocking. Right now the collective_communicator only provides the non-blocking interface, since it is more general.
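As a small sketch of the point made in the reply (using names from the surrounding diff; the blocking behavior of the old call is recovered by waiting on the request returned by the non-blocking one):

// old, blocking interface: sizes, offsets, and the MPI type are passed explicitly
comm.all_to_all_v(exec, send_ptr, send_sizes.data(), send_offsets.data(), type,
                  recv_ptr, recv_sizes.data(), recv_offsets.data(), type);

// new, non-blocking interface: the communication pattern lives in the
// collective communicator; waiting immediately makes it effectively blocking
coll_comm->i_all_to_all_v(exec, send_ptr, recv_ptr).wait();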

+                         recv_ptr)
+        .wait();
     if (use_host_buffer) {
         exec->copy_from(exec->get_master(), total_recv_size, recv_ptr,
                         non_local_agg.get_data());
8 changes: 7 additions & 1 deletion core/test/gtest/ginkgo_mpi_main.cpp
@@ -356,7 +356,13 @@ int main(int argc, char** argv)
 {
     ::testing::InitGoogleTest(&argc, argv);

-    MPI_Init(&argc, &argv);
+    int provided_thread_support;
+    MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE,
+                    &provided_thread_support);
+    if (provided_thread_support != MPI_THREAD_MULTIPLE) {
+        throw std::runtime_error(
+            "This test requires a thread-compliant MPI implementation.");
+    }
     MPI_Comm comm(MPI_COMM_WORLD);
     int rank;
     int size;
1 change: 1 addition & 0 deletions core/test/mpi/distributed/CMakeLists.txt
@@ -1,6 +1,7 @@
 ginkgo_create_test(helpers MPI_SIZE 1)
 ginkgo_create_test(matrix MPI_SIZE 1)
 ginkgo_create_test(collective_communicator MPI_SIZE 6)
+ginkgo_create_test(row_gatherer MPI_SIZE 6)
 ginkgo_create_test(vector_cache MPI_SIZE 3)

 add_subdirectory(preconditioner)