Skip to content

Commit

Permalink
[dist-mat] use row-gatherer
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Apr 19, 2024
1 parent 768ae07 commit 79de4c3
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 136 deletions.
143 changes: 25 additions & 118 deletions core/distributed/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


#include <ginkgo/core/base/precision_dispatch.hpp>
#include <ginkgo/core/distributed/neighborhood_communicator.hpp>
#include <ginkgo/core/distributed/vector.hpp>
#include <ginkgo/core/matrix/csr.hpp>

Expand Down Expand Up @@ -45,14 +46,10 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
: EnableDistributedLinOp<
Matrix<value_type, local_index_type, global_index_type>>{exec},
DistributedBase{comm},
send_offsets_(comm.size() + 1),
send_sizes_(comm.size()),
recv_offsets_(comm.size() + 1),
recv_sizes_(comm.size()),
gather_idxs_{exec},
one_scalar_{},
local_mtx_{local_matrix_template->clone(exec)},
non_local_mtx_{non_local_matrix_template->clone(exec)},
row_gatherer_{RowGatherer<LocalIndexType>::create(exec, comm)},
imap_{exec}
{
GKO_ASSERT(
Expand Down Expand Up @@ -106,11 +103,7 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(
result->get_communicator().size());
result->local_mtx_->copy_from(this->local_mtx_);
result->non_local_mtx_->copy_from(this->non_local_mtx_);
result->gather_idxs_ = this->gather_idxs_;
result->send_offsets_ = this->send_offsets_;
result->recv_offsets_ = this->recv_offsets_;
result->recv_sizes_ = this->recv_sizes_;
result->send_sizes_ = this->send_sizes_;
result->row_gatherer_->copy_from(this->row_gatherer_);
result->set_size(this->get_size());
}

Expand All @@ -124,11 +117,7 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
result->get_communicator().size());
result->local_mtx_->move_from(this->local_mtx_);
result->non_local_mtx_->move_from(this->non_local_mtx_);
result->gather_idxs_ = std::move(this->gather_idxs_);
result->send_offsets_ = std::move(this->send_offsets_);
result->recv_offsets_ = std::move(this->recv_offsets_);
result->recv_sizes_ = std::move(this->recv_sizes_);
result->send_sizes_ = std::move(this->send_sizes_);
result->row_gatherer_->move_from(this->row_gatherer_);
result->set_size(this->get_size());
this->set_size({});
}
Expand All @@ -152,7 +141,6 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
auto local_part = comm.rank();

// set up LinOp sizes
auto num_parts = static_cast<size_type>(row_partition->get_num_parts());
auto global_num_rows = row_partition->get_size();
auto global_num_cols = col_partition->get_size();
dim<2> global_dim{global_num_rows, global_num_cols};
Expand All @@ -176,11 +164,11 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
local_row_idxs, local_col_idxs, local_values, non_local_row_idxs,
global_non_local_col_idxs, non_local_values));

auto imap = index_map<local_index_type, global_index_type>(
imap_ = index_map<local_index_type, global_index_type>(
exec, col_partition, comm.rank(), global_non_local_col_idxs);

auto non_local_col_idxs =
imap.get_local(global_non_local_col_idxs, index_space::non_local);
imap_.get_local(global_non_local_col_idxs, index_space::non_local);

// read the local matrix data
const auto num_local_rows =
Expand All @@ -193,48 +181,19 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
device_matrix_data<value_type, local_index_type> non_local_data{
exec,
dim<2>{num_local_rows,
imap.get_remote_global_idxs().get_flat().get_size()},
imap_.get_remote_global_idxs().get_flat().get_size()},
std::move(non_local_row_idxs), std::move(non_local_col_idxs),
std::move(non_local_values)};
as<ReadableFromMatrixData<ValueType, LocalIndexType>>(this->local_mtx_)
->read(std::move(local_data));
as<ReadableFromMatrixData<ValueType, LocalIndexType>>(this->non_local_mtx_)
->read(std::move(non_local_data));

// exchange step 1: determine recv_sizes, send_sizes, send_offsets
auto host_recv_targets =
make_temporary_clone(exec->get_master(), &imap.get_remote_target_ids());
std::fill(recv_sizes_.begin(), recv_sizes_.end(), 0);
for (size_type i = 0; i < host_recv_targets->get_size(); ++i) {
recv_sizes_[host_recv_targets->get_const_data()[i]] =
imap.get_remote_global_idxs()[i].get_size();
}
std::partial_sum(recv_sizes_.begin(), recv_sizes_.end(),
recv_offsets_.begin() + 1);
comm.all_to_all(exec, recv_sizes_.data(), 1, send_sizes_.data(), 1);
std::partial_sum(send_sizes_.begin(), send_sizes_.end(),
send_offsets_.begin() + 1);
send_offsets_[0] = 0;
recv_offsets_[0] = 0;

// exchange step 2: exchange gather_idxs from receivers to senders
auto recv_gather_idxs = imap.get_remote_local_idxs().get_flat();
auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
if (use_host_buffer) {
recv_gather_idxs.set_executor(exec->get_master());
gather_idxs_.clear();
gather_idxs_.set_executor(exec->get_master());
}
gather_idxs_.resize_and_reset(send_offsets_.back());
comm.all_to_all_v(use_host_buffer ? exec->get_master() : exec,
recv_gather_idxs.get_const_data(), recv_sizes_.data(),
recv_offsets_.data(), gather_idxs_.get_data(),
send_sizes_.data(), send_offsets_.data());
if (use_host_buffer) {
gather_idxs_.set_executor(exec);
}
row_gatherer_ = RowGatherer<local_index_type>::create(
exec, std::make_shared<mpi::neighborhood_communicator>(comm, imap_),
imap_);

return imap;
return imap_;
}


Expand Down Expand Up @@ -279,50 +238,6 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
mpi::request Matrix<ValueType, LocalIndexType, GlobalIndexType>::communicate(
const local_vector_type* local_b) const
{
auto exec = this->get_executor();
const auto comm = this->get_communicator();
auto num_cols = local_b->get_size()[1];
auto send_size = send_offsets_.back();
auto recv_size = recv_offsets_.back();
auto send_dim = dim<2>{static_cast<size_type>(send_size), num_cols};
auto recv_dim = dim<2>{static_cast<size_type>(recv_size), num_cols};
recv_buffer_.init(exec, recv_dim);
send_buffer_.init(exec, send_dim);

local_b->row_gather(&gather_idxs_, send_buffer_.get());

auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
if (use_host_buffer) {
host_recv_buffer_.init(exec->get_master(), recv_dim);
host_send_buffer_.init(exec->get_master(), send_dim);
host_send_buffer_->copy_from(send_buffer_.get());
}

mpi::contiguous_type type(num_cols, mpi::type_impl<ValueType>::get_type());
auto send_ptr = use_host_buffer ? host_send_buffer_->get_const_values()
: send_buffer_->get_const_values();
auto recv_ptr = use_host_buffer ? host_recv_buffer_->get_values()
: recv_buffer_->get_values();
exec->synchronize();
#ifdef GINKGO_FORCE_SPMV_BLOCKING_COMM
comm.all_to_all_v(use_host_buffer ? exec->get_master() : exec, send_ptr,
send_sizes_.data(), send_offsets_.data(), type.get(),
recv_ptr, recv_sizes_.data(), recv_offsets_.data(),
type.get());
return {};
#else
return comm.i_all_to_all_v(
use_host_buffer ? exec->get_master() : exec, send_ptr,
send_sizes_.data(), send_offsets_.data(), type.get(), recv_ptr,
recv_sizes_.data(), recv_offsets_.data(), type.get());
#endif
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
const LinOp* b, LinOp* x) const
Expand All @@ -338,16 +253,16 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
dense_x->get_local_values()),
dense_x->get_local_vector()->get_stride());

auto exec = this->get_executor();
auto comm = this->get_communicator();
auto req = this->communicate(dense_b->get_local_vector());
auto recv_dim =
dim<2>{imap_.get_non_local_size(), dense_b->get_size()[1]};
recv_buffer_.init(exec, recv_dim);
auto req =
this->row_gatherer_->apply_async(dense_b, recv_buffer_.get());
local_mtx_->apply(dense_b->get_local_vector(), local_x);
req.wait();

auto exec = this->get_executor();
auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
if (use_host_buffer) {
recv_buffer_->copy_from(host_recv_buffer_.get());
}
non_local_mtx_->apply(one_scalar_.get(), recv_buffer_.get(),
one_scalar_.get(), local_x);
},
Expand All @@ -371,17 +286,17 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
dense_x->get_local_values()),
dense_x->get_local_vector()->get_stride());

auto exec = this->get_executor();
auto comm = this->get_communicator();
auto req = this->communicate(dense_b->get_local_vector());
auto recv_dim =
dim<2>{imap_.get_non_local_size(), dense_b->get_size()[1]};
recv_buffer_.init(exec, recv_dim);
auto req =
this->row_gatherer_->apply_async(dense_b, recv_buffer_.get());
local_mtx_->apply(local_alpha, dense_b->get_local_vector(),
local_beta, local_x);
req.wait();

auto exec = this->get_executor();
auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
if (use_host_buffer) {
recv_buffer_->copy_from(host_recv_buffer_.get());
}
non_local_mtx_->apply(local_alpha, recv_buffer_.get(),
one_scalar_.get(), local_x);
},
Expand Down Expand Up @@ -423,11 +338,7 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::operator=(
this->set_size(other.get_size());
local_mtx_->copy_from(other.local_mtx_);
non_local_mtx_->copy_from(other.non_local_mtx_);
gather_idxs_ = other.gather_idxs_;
send_offsets_ = other.send_offsets_;
recv_offsets_ = other.recv_offsets_;
send_sizes_ = other.send_sizes_;
recv_sizes_ = other.recv_sizes_;
row_gatherer_->copy_from(other.row_gatherer_);
one_scalar_.init(this->get_executor(), dim<2>{1, 1});
one_scalar_->fill(one<value_type>());
}
Expand All @@ -446,11 +357,7 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::operator=(Matrix&& other)
other.set_size({});
local_mtx_->move_from(other.local_mtx_);
non_local_mtx_->move_from(other.non_local_mtx_);
gather_idxs_ = std::move(other.gather_idxs_);
send_offsets_ = std::move(other.send_offsets_);
recv_offsets_ = std::move(other.recv_offsets_);
send_sizes_ = std::move(other.send_sizes_);
recv_sizes_ = std::move(other.recv_sizes_);
row_gatherer_->move_from(other.row_gatherer_);
one_scalar_.init(this->get_executor(), dim<2>{1, 1});
one_scalar_->fill(one<value_type>());
}
Expand Down
32 changes: 14 additions & 18 deletions include/ginkgo/core/distributed/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <ginkgo/core/distributed/base.hpp>
#include <ginkgo/core/distributed/index_map.hpp>
#include <ginkgo/core/distributed/lin_op.hpp>
#include <ginkgo/core/distributed/row_gatherer.hpp>


namespace gko {
Expand Down Expand Up @@ -358,6 +359,17 @@ class Matrix
return non_local_mtx_;
}

std::shared_ptr<const RowGatherer<local_index_type>> get_row_gatherer()
const
{
return row_gatherer_;
}

const index_map<local_index_type, global_index_type>& get_index_map() const
{
return imap_;
}

/**
* Copy constructs a Matrix.
*
Expand Down Expand Up @@ -530,31 +542,15 @@ class Matrix
ptr_param<const LinOp> local_matrix_template,
ptr_param<const LinOp> non_local_matrix_template);

/**
* Starts a non-blocking communication of the values of b that are shared
* with other processors.
*
* @param local_b The full local vector to be communicated. The subset of
* shared values is automatically extracted.
* @return MPI request for the non-blocking communication.
*/
mpi::request communicate(const local_vector_type* local_b) const;

void apply_impl(const LinOp* b, LinOp* x) const override;

void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
LinOp* x) const override;

private:
std::vector<comm_index_type> send_offsets_;
std::vector<comm_index_type> send_sizes_;
std::vector<comm_index_type> recv_offsets_;
std::vector<comm_index_type> recv_sizes_;
array<local_index_type> gather_idxs_;
std::shared_ptr<RowGatherer<LocalIndexType>> row_gatherer_;
index_map<local_index_type, global_index_type> imap_;
gko::detail::DenseCache<value_type> one_scalar_;
gko::detail::DenseCache<value_type> host_send_buffer_;
gko::detail::DenseCache<value_type> host_recv_buffer_;
gko::detail::DenseCache<value_type> send_buffer_;
gko::detail::DenseCache<value_type> recv_buffer_;
std::shared_ptr<LinOp> local_mtx_;
std::shared_ptr<LinOp> non_local_mtx_;
Expand Down

0 comments on commit 79de4c3

Please sign in to comment.