Skip to content

Commit

Permalink
Enable block support for sptrsv streams (#2429)
Browse files Browse the repository at this point in the history
Signed-off-by: James Foucar <[email protected]>
  • Loading branch information
jgfouca authored Nov 25, 2024
1 parent b3a4bdf commit 3acfeba
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 39 deletions.
34 changes: 20 additions & 14 deletions sparse/impl/KokkosSparse_sptrsv_solve_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,13 @@ struct SptrsvWrap {

// copy-assignment
KOKKOS_INLINE_FUNCTION
ArrayType& operator=(const ArrayType& rhs_) {
if (this != &rhs_) {
for (size_type i = 0; i < MAX_VEC_SIZE; ++i) {
m_data[i] = rhs_.m_data[i];
}
ArrayType &operator=(const ArrayType &rhs_) {
if (this != &rhs_) {
for (size_type i = 0; i < MAX_VEC_SIZE; ++i) {
m_data[i] = rhs_.m_data[i];
}
return *this;
}
return *this;
}

KOKKOS_INLINE_FUNCTION
Expand Down Expand Up @@ -2076,7 +2076,8 @@ struct SptrsvWrap {
// --------------------------------
// Stream interfaces
// --------------------------------
template <bool IsLower, class RowMapType, class EntriesType, class ValuesType, class RHSType, class LHSType>
template <bool IsLower, bool BlockEnabled, class RowMapType, class EntriesType, class ValuesType, class RHSType,
class LHSType>
static void tri_solve_streams(const std::vector<execution_space> &execspace_v,
const std::vector<TriSolveHandle *> &thandle_v,
const std::vector<RowMapType> &row_map_v, const std::vector<EntriesType> &entries_v,
Expand All @@ -2085,8 +2086,8 @@ struct SptrsvWrap {
// NOTE: Only support SEQLVLSCHD_RP and SEQLVLSCHD_TP1 at this moment
using nodes_per_level_type = typename TriSolveHandle::hostspace_nnz_lno_view_t;
using nodes_grouped_by_level_type = typename TriSolveHandle::nnz_lno_view_t;
using RPPointFunctor = FunctorTypeMacro(TriLvlSchedRPSolverFunctor, IsLower, false);
using TPPointFunctor = FunctorTypeMacro(TriLvlSchedTP1SolverFunctor, IsLower, false);
using RPFunctor = FunctorTypeMacro(TriLvlSchedRPSolverFunctor, IsLower, BlockEnabled);
using TPFunctor = FunctorTypeMacro(TriLvlSchedTP1SolverFunctor, IsLower, BlockEnabled);

// Create vectors for handles' data in streams
int nstreams = execspace_v.size();
Expand All @@ -2111,19 +2112,24 @@ struct SptrsvWrap {
for (int i = 0; i < nstreams; i++) {
// Only if stream i-th still has this level
if (lvl < nlevels_v[i]) {
size_type lvl_nodes = hnodes_per_level_v[i](lvl);
const size_type lvl_nodes = hnodes_per_level_v[i](lvl);
const auto block_size = thandle_v[i]->get_block_size();
const auto block_enabled = thandle_v[i]->is_block_enabled();
KK_REQUIRE(block_enabled == BlockEnabled);
if (lvl_nodes != 0) {
if (thandle_v[i]->get_algorithm() == KokkosSparse::Experimental::SPTRSVAlgorithm::SEQLVLSCHD_RP) {
Kokkos::parallel_for("parfor_fixed_lvl",
range_policy(execspace_v[i], node_count_v[i], node_count_v[i] + lvl_nodes),
RPPointFunctor(row_map_v[i], entries_v[i], values_v[i], lhs_v[i], rhs_v[i],
nodes_grouped_by_level_v[i]));
RPFunctor(row_map_v[i], entries_v[i], values_v[i], lhs_v[i], rhs_v[i],
nodes_grouped_by_level_v[i], block_size));
} else if (thandle_v[i]->get_algorithm() == KokkosSparse::Experimental::SPTRSVAlgorithm::SEQLVLSCHD_TP1) {
int team_size = thandle_v[i]->get_team_size();
auto tp = team_size == -1 ? team_policy(execspace_v[i], lvl_nodes, Kokkos::AUTO)
: team_policy(execspace_v[i], lvl_nodes, team_size);
TPPointFunctor tstf(row_map_v[i], entries_v[i], values_v[i], lhs_v[i], rhs_v[i],
nodes_grouped_by_level_v[i], node_count_v[i]);
TPFunctor tstf(row_map_v[i], entries_v[i], values_v[i], lhs_v[i], rhs_v[i], nodes_grouped_by_level_v[i],
node_count_v[i], block_size);
const int scratch_size = TPFunctor::SBlock::shmem_size(block_size, block_size);
tp = tp.set_scratch_size(0, Kokkos::PerTeam(scratch_size));
Kokkos::parallel_for("parfor_l_team", tp, tstf);
}
node_count_v[i] += lvl_nodes;
Expand Down
17 changes: 15 additions & 2 deletions sparse/impl/KokkosSparse_sptrsv_solve_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,20 +169,33 @@ struct SPTRSV_SOLVE<ExecutionSpace, KernelHandle, RowMapType, EntriesType, Value
}
Kokkos::Profiling::pushRegion(sptrsv_handle_v[0]->is_lower_tri() ? "KokkosSparse_sptrsv[lower]"
: "KokkosSparse_sptrsv[upper]");
const auto block_enabled = sptrsv_handle_v[0]->is_block_enabled();
if (sptrsv_handle_v[0]->is_lower_tri()) {
for (int i = 0; i < static_cast<int>(execspace_v.size()); i++) {
if (sptrsv_handle_v[i]->is_symbolic_complete() == false) {
Experimental::lower_tri_symbolic(execspace_v[i], *(sptrsv_handle_v[i]), row_map_v[i], entries_v[i]);
}
}
Sptrsv::template tri_solve_streams<true>(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v, b_v, x_v);
if (block_enabled) {
Sptrsv::template tri_solve_streams<true, true>(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v,
b_v, x_v);
} else {
Sptrsv::template tri_solve_streams<true, false>(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v,
b_v, x_v);
}
} else {
for (int i = 0; i < static_cast<int>(execspace_v.size()); i++) {
if (sptrsv_handle_v[i]->is_symbolic_complete() == false) {
Experimental::upper_tri_symbolic(execspace_v[i], *(sptrsv_handle_v[i]), row_map_v[i], entries_v[i]);
}
}
Sptrsv::template tri_solve_streams<false>(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v, b_v, x_v);
if (block_enabled) {
Sptrsv::template tri_solve_streams<false, true>(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v,
b_v, x_v);
} else {
Sptrsv::template tri_solve_streams<false, false>(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v,
b_v, x_v);
}
}
Kokkos::Profiling::popRegion();
}
Expand Down
59 changes: 36 additions & 23 deletions sparse/unit_test/Test_Sparse_sptrsv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,20 @@ struct SptrsvTest {
return std::make_tuple(triMtx, lhs, rhs);
}

template <bool BlockEnabled, typename std::enable_if<!BlockEnabled>::type * = nullptr>
static std::tuple<Crs, ValuesType, ValuesType> create_crs_lhs_rhs_flex(
const std::vector<std::vector<scalar_t>> &fixture, const size_type) {
return create_crs_lhs_rhs(fixture);
}

template <bool BlockEnabled, typename std::enable_if<BlockEnabled>::type * = nullptr>
static std::tuple<Bsr, ValuesType, ValuesType> create_crs_lhs_rhs_flex(
const std::vector<std::vector<scalar_t>> &fixture, const size_type block_size) {
const auto [triMtx_crs, lhs, rhs] = create_crs_lhs_rhs(fixture);
Bsr triMtx(triMtx_crs, block_size);
return std::make_tuple(triMtx, lhs, rhs);
}

template <typename SpMatrix>
static void basic_check(const SpMatrix &triMtx, const ValuesType &lhs, const ValuesType &rhs, const bool is_lower,
const size_type block_size = 0) {
Expand Down Expand Up @@ -496,10 +510,9 @@ struct SptrsvTest {
}

static void run_test_sptrsv_blocks_impl(const bool is_lower, const size_type block_size) {
auto fixture = is_lower ? get_6x6_lt_ones_fixture() : get_6x6_ut_ones_fixture();
const auto [triMtx_crs, lhs, rhs] = create_crs_lhs_rhs(fixture);
auto fixture = is_lower ? get_6x6_lt_ones_fixture() : get_6x6_ut_ones_fixture();
const auto [triMtx, lhs, rhs] = create_crs_lhs_rhs_flex<true>(fixture, block_size);

Bsr triMtx(triMtx_crs, block_size);
basic_check(triMtx, lhs, rhs, is_lower, block_size);
}

Expand All @@ -510,6 +523,7 @@ struct SptrsvTest {
}
}

template <bool BlockEnabled>
static void run_test_sptrsv_streams(SPTRSVAlgorithm test_algo, int nstreams, const bool is_lower) {
// Workaround for OpenMP: skip tests if concurrency < nstreams because of
// not enough resource to partition
Expand All @@ -525,9 +539,6 @@ struct SptrsvTest {
#endif
if (!run_streams_test) return;

const size_type nrows = 5;
const size_type nnz = 10;

auto instances = Kokkos::Experimental::partition_space(execution_space(), std::vector<int>(nstreams, 1));

std::vector<KernelHandle> kh_v(nstreams);
Expand All @@ -538,42 +549,42 @@ struct SptrsvTest {
std::vector<ValuesType> rhs_v(nstreams);
std::vector<ValuesType> lhs_v(nstreams);

auto fixture = is_lower ? get_5x5_lt_ones_fixture() : get_5x5_ut_ones_fixture();
const auto [triMtx, lhs, rhs] = create_crs_lhs_rhs(fixture);
const size_type block_size = BlockEnabled ? 2 : 0;

auto fixture = is_lower ? get_6x6_lt_ones_fixture() : get_6x6_ut_ones_fixture();
const auto [triMtx, lhs, rhs] = create_crs_lhs_rhs_flex<BlockEnabled>(fixture, block_size);

const size_type nrows = triMtx.numRows();
const size_type nrows_point = triMtx.numPointRows();
const size_type nnz = triMtx.nnz();

auto row_map = triMtx.graph.row_map;
auto entries = triMtx.graph.entries;
auto values = triMtx.values;

const size_type nvals = values.size();

for (int i = 0; i < nstreams; i++) {
// Allocate
row_map_v[i] = RowMapType("row_map", nrows + 1);
entries_v[i] = EntriesType("entries", nnz);
values_v[i] = ValuesType("values", nnz);
values_v[i] = ValuesType("values", nvals);

// Copy
Kokkos::deep_copy(row_map_v[i], row_map);
Kokkos::deep_copy(entries_v[i], entries);
Kokkos::deep_copy(values_v[i], values);

// Create known_lhs, generate rhs, then solve for lhs to compare to
// known_lhs
ValuesType known_lhs("known_lhs", nrows);
// Create known solution lhs set to all 1's
Kokkos::deep_copy(known_lhs, scalar_t(1));

// Solution to find
lhs_v[i] = ValuesType("lhs", nrows);
lhs_v[i] = ValuesType("lhs", nrows_point);

// A*known_lhs generates rhs: rhs is dense, use spmv
rhs_v[i] = ValuesType("rhs", nrows);

KokkosSparse::spmv("N", scalar_t(1), triMtx, known_lhs, scalar_t(0), rhs_v[i]);
Kokkos::fence();
rhs_v[i] = ValuesType("rhs", nrows_point);
Kokkos::deep_copy(rhs_v[i], rhs);

// Create handle
kh_v[i] = KernelHandle();
kh_v[i].create_sptrsv_handle(test_algo, nrows, is_lower);
kh_v[i].create_sptrsv_handle(test_algo, nrows, is_lower, block_size);
kh_ptr_v[i] = &kh_v[i];

// Symbolic phase
Expand Down Expand Up @@ -615,8 +626,10 @@ void test_sptrsv_streams() {

for (auto alg : algs) {
for (int nstreams = 1; nstreams <= 4; ++nstreams) {
TestStruct::run_test_sptrsv_streams(alg, nstreams, true);
TestStruct::run_test_sptrsv_streams(alg, nstreams, false);
TestStruct::template run_test_sptrsv_streams<false>(alg, nstreams, true);
TestStruct::template run_test_sptrsv_streams<false>(alg, nstreams, false);
TestStruct::template run_test_sptrsv_streams<true>(alg, nstreams, true);
TestStruct::template run_test_sptrsv_streams<true>(alg, nstreams, false);
}
}
}
Expand Down

0 comments on commit 3acfeba

Please sign in to comment.