Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable block support for sptrsv streams #2429

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions sparse/impl/KokkosSparse_sptrsv_solve_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,13 @@ struct SptrsvWrap {

// copy-assignment
KOKKOS_INLINE_FUNCTION
ArrayType& operator=(const ArrayType& rhs_) {
if (this != &rhs_) {
for (size_type i = 0; i < MAX_VEC_SIZE; ++i) {
m_data[i] = rhs_.m_data[i];
}
ArrayType &operator=(const ArrayType &rhs_) {
if (this != &rhs_) {
for (size_type i = 0; i < MAX_VEC_SIZE; ++i) {
m_data[i] = rhs_.m_data[i];
}
return *this;
}
return *this;
}

KOKKOS_INLINE_FUNCTION
Expand Down Expand Up @@ -2076,7 +2076,8 @@ struct SptrsvWrap {
// --------------------------------
// Stream interfaces
// --------------------------------
template <bool IsLower, class RowMapType, class EntriesType, class ValuesType, class RHSType, class LHSType>
template <bool IsLower, bool BlockEnabled, class RowMapType, class EntriesType, class ValuesType, class RHSType,
class LHSType>
static void tri_solve_streams(const std::vector<execution_space> &execspace_v,
const std::vector<TriSolveHandle *> &thandle_v,
const std::vector<RowMapType> &row_map_v, const std::vector<EntriesType> &entries_v,
Expand All @@ -2085,8 +2086,8 @@ struct SptrsvWrap {
// NOTE: Only support SEQLVLSCHD_RP and SEQLVLSCHD_TP1 at this moment
using nodes_per_level_type = typename TriSolveHandle::hostspace_nnz_lno_view_t;
using nodes_grouped_by_level_type = typename TriSolveHandle::nnz_lno_view_t;
using RPPointFunctor = FunctorTypeMacro(TriLvlSchedRPSolverFunctor, IsLower, false);
using TPPointFunctor = FunctorTypeMacro(TriLvlSchedTP1SolverFunctor, IsLower, false);
using RPFunctor = FunctorTypeMacro(TriLvlSchedRPSolverFunctor, IsLower, BlockEnabled);
using TPFunctor = FunctorTypeMacro(TriLvlSchedTP1SolverFunctor, IsLower, BlockEnabled);

// Create vectors for handles' data in streams
int nstreams = execspace_v.size();
Expand All @@ -2111,19 +2112,24 @@ struct SptrsvWrap {
for (int i = 0; i < nstreams; i++) {
// Only if stream i-th still has this level
if (lvl < nlevels_v[i]) {
size_type lvl_nodes = hnodes_per_level_v[i](lvl);
const size_type lvl_nodes = hnodes_per_level_v[i](lvl);
const auto block_size = thandle_v[i]->get_block_size();
const auto block_enabled = thandle_v[i]->is_block_enabled();
KK_REQUIRE(block_enabled == BlockEnabled);
if (lvl_nodes != 0) {
if (thandle_v[i]->get_algorithm() == KokkosSparse::Experimental::SPTRSVAlgorithm::SEQLVLSCHD_RP) {
Kokkos::parallel_for("parfor_fixed_lvl",
range_policy(execspace_v[i], node_count_v[i], node_count_v[i] + lvl_nodes),
RPPointFunctor(row_map_v[i], entries_v[i], values_v[i], lhs_v[i], rhs_v[i],
nodes_grouped_by_level_v[i]));
RPFunctor(row_map_v[i], entries_v[i], values_v[i], lhs_v[i], rhs_v[i],
nodes_grouped_by_level_v[i], block_size));
} else if (thandle_v[i]->get_algorithm() == KokkosSparse::Experimental::SPTRSVAlgorithm::SEQLVLSCHD_TP1) {
int team_size = thandle_v[i]->get_team_size();
auto tp = team_size == -1 ? team_policy(execspace_v[i], lvl_nodes, Kokkos::AUTO)
: team_policy(execspace_v[i], lvl_nodes, team_size);
TPPointFunctor tstf(row_map_v[i], entries_v[i], values_v[i], lhs_v[i], rhs_v[i],
nodes_grouped_by_level_v[i], node_count_v[i]);
TPFunctor tstf(row_map_v[i], entries_v[i], values_v[i], lhs_v[i], rhs_v[i], nodes_grouped_by_level_v[i],
node_count_v[i], block_size);
const int scratch_size = TPFunctor::SBlock::shmem_size(block_size, block_size);
tp = tp.set_scratch_size(0, Kokkos::PerTeam(scratch_size));
vqd8a marked this conversation as resolved.
Show resolved Hide resolved
Kokkos::parallel_for("parfor_l_team", tp, tstf);
}
node_count_v[i] += lvl_nodes;
Expand Down
17 changes: 15 additions & 2 deletions sparse/impl/KokkosSparse_sptrsv_solve_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,20 +169,33 @@ struct SPTRSV_SOLVE<ExecutionSpace, KernelHandle, RowMapType, EntriesType, Value
}
Kokkos::Profiling::pushRegion(sptrsv_handle_v[0]->is_lower_tri() ? "KokkosSparse_sptrsv[lower]"
: "KokkosSparse_sptrsv[upper]");
const auto block_enabled = sptrsv_handle_v[0]->is_block_enabled();
if (sptrsv_handle_v[0]->is_lower_tri()) {
for (int i = 0; i < static_cast<int>(execspace_v.size()); i++) {
if (sptrsv_handle_v[i]->is_symbolic_complete() == false) {
Experimental::lower_tri_symbolic(execspace_v[i], *(sptrsv_handle_v[i]), row_map_v[i], entries_v[i]);
}
}
Sptrsv::template tri_solve_streams<true>(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v, b_v, x_v);
if (block_enabled) {
Sptrsv::template tri_solve_streams<true, true>(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v,
b_v, x_v);
} else {
Sptrsv::template tri_solve_streams<true, false>(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v,
b_v, x_v);
}
} else {
for (int i = 0; i < static_cast<int>(execspace_v.size()); i++) {
if (sptrsv_handle_v[i]->is_symbolic_complete() == false) {
Experimental::upper_tri_symbolic(execspace_v[i], *(sptrsv_handle_v[i]), row_map_v[i], entries_v[i]);
}
}
Sptrsv::template tri_solve_streams<false>(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v, b_v, x_v);
if (block_enabled) {
Sptrsv::template tri_solve_streams<false, true>(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v,
b_v, x_v);
} else {
Sptrsv::template tri_solve_streams<false, false>(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v,
b_v, x_v);
}
}
Kokkos::Profiling::popRegion();
}
Expand Down
59 changes: 36 additions & 23 deletions sparse/unit_test/Test_Sparse_sptrsv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,20 @@ struct SptrsvTest {
return std::make_tuple(triMtx, lhs, rhs);
}

template <bool BlockEnabled, typename std::enable_if<!BlockEnabled>::type * = nullptr>
static std::tuple<Crs, ValuesType, ValuesType> create_crs_lhs_rhs_flex(
const std::vector<std::vector<scalar_t>> &fixture, const size_type) {
return create_crs_lhs_rhs(fixture);
}

template <bool BlockEnabled, typename std::enable_if<BlockEnabled>::type * = nullptr>
static std::tuple<Bsr, ValuesType, ValuesType> create_crs_lhs_rhs_flex(
const std::vector<std::vector<scalar_t>> &fixture, const size_type block_size) {
const auto [triMtx_crs, lhs, rhs] = create_crs_lhs_rhs(fixture);
Bsr triMtx(triMtx_crs, block_size);
return std::make_tuple(triMtx, lhs, rhs);
}

template <typename SpMatrix>
static void basic_check(const SpMatrix &triMtx, const ValuesType &lhs, const ValuesType &rhs, const bool is_lower,
const size_type block_size = 0) {
Expand Down Expand Up @@ -496,10 +510,9 @@ struct SptrsvTest {
}

static void run_test_sptrsv_blocks_impl(const bool is_lower, const size_type block_size) {
auto fixture = is_lower ? get_6x6_lt_ones_fixture() : get_6x6_ut_ones_fixture();
const auto [triMtx_crs, lhs, rhs] = create_crs_lhs_rhs(fixture);
auto fixture = is_lower ? get_6x6_lt_ones_fixture() : get_6x6_ut_ones_fixture();
const auto [triMtx, lhs, rhs] = create_crs_lhs_rhs_flex<true>(fixture, block_size);

Bsr triMtx(triMtx_crs, block_size);
basic_check(triMtx, lhs, rhs, is_lower, block_size);
}

Expand All @@ -510,6 +523,7 @@ struct SptrsvTest {
}
}

template <bool BlockEnabled>
static void run_test_sptrsv_streams(SPTRSVAlgorithm test_algo, int nstreams, const bool is_lower) {
// Workaround for OpenMP: skip tests if concurrency < nstreams because of
// not enough resource to partition
Expand All @@ -525,9 +539,6 @@ struct SptrsvTest {
#endif
if (!run_streams_test) return;

const size_type nrows = 5;
const size_type nnz = 10;

auto instances = Kokkos::Experimental::partition_space(execution_space(), std::vector<int>(nstreams, 1));

std::vector<KernelHandle> kh_v(nstreams);
Expand All @@ -538,42 +549,42 @@ struct SptrsvTest {
std::vector<ValuesType> rhs_v(nstreams);
std::vector<ValuesType> lhs_v(nstreams);

auto fixture = is_lower ? get_5x5_lt_ones_fixture() : get_5x5_ut_ones_fixture();
const auto [triMtx, lhs, rhs] = create_crs_lhs_rhs(fixture);
const size_type block_size = BlockEnabled ? 2 : 0;

auto fixture = is_lower ? get_6x6_lt_ones_fixture() : get_6x6_ut_ones_fixture();
const auto [triMtx, lhs, rhs] = create_crs_lhs_rhs_flex<BlockEnabled>(fixture, block_size);

const size_type nrows = triMtx.numRows();
const size_type nrows_point = triMtx.numPointRows();
const size_type nnz = triMtx.nnz();

auto row_map = triMtx.graph.row_map;
auto entries = triMtx.graph.entries;
auto values = triMtx.values;

const size_type nvals = values.size();

for (int i = 0; i < nstreams; i++) {
// Allocate
row_map_v[i] = RowMapType("row_map", nrows + 1);
entries_v[i] = EntriesType("entries", nnz);
values_v[i] = ValuesType("values", nnz);
values_v[i] = ValuesType("values", nvals);

// Copy
Kokkos::deep_copy(row_map_v[i], row_map);
Kokkos::deep_copy(entries_v[i], entries);
Kokkos::deep_copy(values_v[i], values);

// Create known_lhs, generate rhs, then solve for lhs to compare to
// known_lhs
ValuesType known_lhs("known_lhs", nrows);
// Create known solution lhs set to all 1's
Kokkos::deep_copy(known_lhs, scalar_t(1));

// Solution to find
lhs_v[i] = ValuesType("lhs", nrows);
lhs_v[i] = ValuesType("lhs", nrows_point);

// A*known_lhs generates rhs: rhs is dense, use spmv
rhs_v[i] = ValuesType("rhs", nrows);

KokkosSparse::spmv("N", scalar_t(1), triMtx, known_lhs, scalar_t(0), rhs_v[i]);
Kokkos::fence();
rhs_v[i] = ValuesType("rhs", nrows_point);
Kokkos::deep_copy(rhs_v[i], rhs);

// Create handle
kh_v[i] = KernelHandle();
kh_v[i].create_sptrsv_handle(test_algo, nrows, is_lower);
kh_v[i].create_sptrsv_handle(test_algo, nrows, is_lower, block_size);
kh_ptr_v[i] = &kh_v[i];

// Symbolic phase
Expand Down Expand Up @@ -615,8 +626,10 @@ void test_sptrsv_streams() {

for (auto alg : algs) {
for (int nstreams = 1; nstreams <= 4; ++nstreams) {
TestStruct::run_test_sptrsv_streams(alg, nstreams, true);
TestStruct::run_test_sptrsv_streams(alg, nstreams, false);
TestStruct::template run_test_sptrsv_streams<false>(alg, nstreams, true);
TestStruct::template run_test_sptrsv_streams<false>(alg, nstreams, false);
TestStruct::template run_test_sptrsv_streams<true>(alg, nstreams, true);
TestStruct::template run_test_sptrsv_streams<true>(alg, nstreams, false);
}
}
}
Expand Down
Loading