diff --git a/sparse/impl/KokkosSparse_sptrsv_solve_impl.hpp b/sparse/impl/KokkosSparse_sptrsv_solve_impl.hpp index 0f01c26f8e..9a83e719e1 100644 --- a/sparse/impl/KokkosSparse_sptrsv_solve_impl.hpp +++ b/sparse/impl/KokkosSparse_sptrsv_solve_impl.hpp @@ -220,13 +220,13 @@ struct SptrsvWrap { // copy-assignment KOKKOS_INLINE_FUNCTION - ArrayType& operator=(const ArrayType& rhs_) { - if (this != &rhs_) { - for (size_type i = 0; i < MAX_VEC_SIZE; ++i) { - m_data[i] = rhs_.m_data[i]; - } + ArrayType &operator=(const ArrayType &rhs_) { + if (this != &rhs_) { + for (size_type i = 0; i < MAX_VEC_SIZE; ++i) { + m_data[i] = rhs_.m_data[i]; } - return *this; + } + return *this; } KOKKOS_INLINE_FUNCTION @@ -2076,7 +2076,8 @@ struct SptrsvWrap { // -------------------------------- // Stream interfaces // -------------------------------- - template + template static void tri_solve_streams(const std::vector &execspace_v, const std::vector &thandle_v, const std::vector &row_map_v, const std::vector &entries_v, @@ -2085,8 +2086,8 @@ struct SptrsvWrap { // NOTE: Only support SEQLVLSCHD_RP and SEQLVLSCHD_TP1 at this moment using nodes_per_level_type = typename TriSolveHandle::hostspace_nnz_lno_view_t; using nodes_grouped_by_level_type = typename TriSolveHandle::nnz_lno_view_t; - using RPPointFunctor = FunctorTypeMacro(TriLvlSchedRPSolverFunctor, IsLower, false); - using TPPointFunctor = FunctorTypeMacro(TriLvlSchedTP1SolverFunctor, IsLower, false); + using RPFunctor = FunctorTypeMacro(TriLvlSchedRPSolverFunctor, IsLower, BlockEnabled); + using TPFunctor = FunctorTypeMacro(TriLvlSchedTP1SolverFunctor, IsLower, BlockEnabled); // Create vectors for handles' data in streams int nstreams = execspace_v.size(); @@ -2111,19 +2112,24 @@ struct SptrsvWrap { for (int i = 0; i < nstreams; i++) { // Only if stream i-th still has this level if (lvl < nlevels_v[i]) { - size_type lvl_nodes = hnodes_per_level_v[i](lvl); + const size_type lvl_nodes = hnodes_per_level_v[i](lvl); + const auto block_size = thandle_v[i]->get_block_size(); + const auto block_enabled = thandle_v[i]->is_block_enabled(); + KK_REQUIRE(block_enabled == BlockEnabled); if (lvl_nodes != 0) { if (thandle_v[i]->get_algorithm() == KokkosSparse::Experimental::SPTRSVAlgorithm::SEQLVLSCHD_RP) { Kokkos::parallel_for("parfor_fixed_lvl", range_policy(execspace_v[i], node_count_v[i], node_count_v[i] + lvl_nodes), - RPPointFunctor(row_map_v[i], entries_v[i], values_v[i], lhs_v[i], rhs_v[i], - nodes_grouped_by_level_v[i])); + RPFunctor(row_map_v[i], entries_v[i], values_v[i], lhs_v[i], rhs_v[i], + nodes_grouped_by_level_v[i], block_size)); } else if (thandle_v[i]->get_algorithm() == KokkosSparse::Experimental::SPTRSVAlgorithm::SEQLVLSCHD_TP1) { int team_size = thandle_v[i]->get_team_size(); auto tp = team_size == -1 ? team_policy(execspace_v[i], lvl_nodes, Kokkos::AUTO) : team_policy(execspace_v[i], lvl_nodes, team_size); - TPPointFunctor tstf(row_map_v[i], entries_v[i], values_v[i], lhs_v[i], rhs_v[i], - nodes_grouped_by_level_v[i], node_count_v[i]); + TPFunctor tstf(row_map_v[i], entries_v[i], values_v[i], lhs_v[i], rhs_v[i], nodes_grouped_by_level_v[i], + node_count_v[i], block_size); + const int scratch_size = TPFunctor::SBlock::shmem_size(block_size, block_size); + tp = tp.set_scratch_size(0, Kokkos::PerTeam(scratch_size)); Kokkos::parallel_for("parfor_l_team", tp, tstf); } node_count_v[i] += lvl_nodes; diff --git a/sparse/impl/KokkosSparse_sptrsv_solve_spec.hpp b/sparse/impl/KokkosSparse_sptrsv_solve_spec.hpp index de9f31dbd6..b8f8a4ab1c 100644 --- a/sparse/impl/KokkosSparse_sptrsv_solve_spec.hpp +++ b/sparse/impl/KokkosSparse_sptrsv_solve_spec.hpp @@ -169,20 +169,33 @@ struct SPTRSV_SOLVEis_lower_tri() ? "KokkosSparse_sptrsv[lower]" : "KokkosSparse_sptrsv[upper]"); + const auto block_enabled = sptrsv_handle_v[0]->is_block_enabled(); if (sptrsv_handle_v[0]->is_lower_tri()) { for (int i = 0; i < static_cast(execspace_v.size()); i++) { if (sptrsv_handle_v[i]->is_symbolic_complete() == false) { Experimental::lower_tri_symbolic(execspace_v[i], *(sptrsv_handle_v[i]), row_map_v[i], entries_v[i]); } } - Sptrsv::template tri_solve_streams(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v, b_v, x_v); + if (block_enabled) { + Sptrsv::template tri_solve_streams(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v, + b_v, x_v); + } else { + Sptrsv::template tri_solve_streams(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v, + b_v, x_v); + } } else { for (int i = 0; i < static_cast(execspace_v.size()); i++) { if (sptrsv_handle_v[i]->is_symbolic_complete() == false) { Experimental::upper_tri_symbolic(execspace_v[i], *(sptrsv_handle_v[i]), row_map_v[i], entries_v[i]); } } - Sptrsv::template tri_solve_streams(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v, b_v, x_v); + if (block_enabled) { + Sptrsv::template tri_solve_streams(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v, + b_v, x_v); + } else { + Sptrsv::template tri_solve_streams(execspace_v, sptrsv_handle_v, row_map_v, entries_v, values_v, + b_v, x_v); + } } Kokkos::Profiling::popRegion(); } diff --git a/sparse/unit_test/Test_Sparse_sptrsv.hpp b/sparse/unit_test/Test_Sparse_sptrsv.hpp index a3f3f4ab7a..8c9cbc3eac 100644 --- a/sparse/unit_test/Test_Sparse_sptrsv.hpp +++ b/sparse/unit_test/Test_Sparse_sptrsv.hpp @@ -158,6 +158,20 @@ struct SptrsvTest { return std::make_tuple(triMtx, lhs, rhs); } + template ::type * = nullptr> + static std::tuple create_crs_lhs_rhs_flex( + const std::vector> &fixture, const size_type) { + return create_crs_lhs_rhs(fixture); + } + + template ::type * = nullptr> + static std::tuple create_crs_lhs_rhs_flex( + const std::vector> &fixture, const size_type block_size) { + const auto [triMtx_crs, lhs, rhs] = create_crs_lhs_rhs(fixture); + Bsr triMtx(triMtx_crs, block_size); + return std::make_tuple(triMtx, lhs, rhs); + } + template static void basic_check(const SpMatrix &triMtx, const ValuesType &lhs, const ValuesType &rhs, const bool is_lower, const size_type block_size = 0) { @@ -496,10 +510,9 @@ struct SptrsvTest { } static void run_test_sptrsv_blocks_impl(const bool is_lower, const size_type block_size) { - auto fixture = is_lower ? get_6x6_lt_ones_fixture() : get_6x6_ut_ones_fixture(); - const auto [triMtx_crs, lhs, rhs] = create_crs_lhs_rhs(fixture); + auto fixture = is_lower ? get_6x6_lt_ones_fixture() : get_6x6_ut_ones_fixture(); + const auto [triMtx, lhs, rhs] = create_crs_lhs_rhs_flex(fixture, block_size); - Bsr triMtx(triMtx_crs, block_size); basic_check(triMtx, lhs, rhs, is_lower, block_size); } @@ -510,6 +523,7 @@ struct SptrsvTest { } } + template static void run_test_sptrsv_streams(SPTRSVAlgorithm test_algo, int nstreams, const bool is_lower) { // Workaround for OpenMP: skip tests if concurrency < nstreams because of // not enough resource to partition @@ -525,9 +539,6 @@ struct SptrsvTest { #endif if (!run_streams_test) return; - const size_type nrows = 5; - const size_type nnz = 10; - auto instances = Kokkos::Experimental::partition_space(execution_space(), std::vector(nstreams, 1)); std::vector kh_v(nstreams); @@ -538,42 +549,42 @@ struct SptrsvTest { std::vector rhs_v(nstreams); std::vector lhs_v(nstreams); - auto fixture = is_lower ? get_5x5_lt_ones_fixture() : get_5x5_ut_ones_fixture(); - const auto [triMtx, lhs, rhs] = create_crs_lhs_rhs(fixture); + const size_type block_size = BlockEnabled ? 2 : 0; + + auto fixture = is_lower ? get_6x6_lt_ones_fixture() : get_6x6_ut_ones_fixture(); + const auto [triMtx, lhs, rhs] = create_crs_lhs_rhs_flex(fixture, block_size); + + const size_type nrows = triMtx.numRows(); + const size_type nrows_point = triMtx.numPointRows(); + const size_type nnz = triMtx.nnz(); auto row_map = triMtx.graph.row_map; auto entries = triMtx.graph.entries; auto values = triMtx.values; + const size_type nvals = values.size(); + for (int i = 0; i < nstreams; i++) { // Allocate row_map_v[i] = RowMapType("row_map", nrows + 1); entries_v[i] = EntriesType("entries", nnz); - values_v[i] = ValuesType("values", nnz); + values_v[i] = ValuesType("values", nvals); // Copy Kokkos::deep_copy(row_map_v[i], row_map); Kokkos::deep_copy(entries_v[i], entries); Kokkos::deep_copy(values_v[i], values); - // Create known_lhs, generate rhs, then solve for lhs to compare to - // known_lhs - ValuesType known_lhs("known_lhs", nrows); - // Create known solution lhs set to all 1's - Kokkos::deep_copy(known_lhs, scalar_t(1)); - // Solution to find - lhs_v[i] = ValuesType("lhs", nrows); + lhs_v[i] = ValuesType("lhs", nrows_point); // A*known_lhs generates rhs: rhs is dense, use spmv - rhs_v[i] = ValuesType("rhs", nrows); - - KokkosSparse::spmv("N", scalar_t(1), triMtx, known_lhs, scalar_t(0), rhs_v[i]); - Kokkos::fence(); + rhs_v[i] = ValuesType("rhs", nrows_point); + Kokkos::deep_copy(rhs_v[i], rhs); // Create handle kh_v[i] = KernelHandle(); - kh_v[i].create_sptrsv_handle(test_algo, nrows, is_lower); + kh_v[i].create_sptrsv_handle(test_algo, nrows, is_lower, block_size); kh_ptr_v[i] = &kh_v[i]; // Symbolic phase @@ -615,8 +626,10 @@ void test_sptrsv_streams() { for (auto alg : algs) { for (int nstreams = 1; nstreams <= 4; ++nstreams) { - TestStruct::run_test_sptrsv_streams(alg, nstreams, true); - TestStruct::run_test_sptrsv_streams(alg, nstreams, false); + TestStruct::template run_test_sptrsv_streams(alg, nstreams, true); + TestStruct::template run_test_sptrsv_streams(alg, nstreams, false); + TestStruct::template run_test_sptrsv_streams(alg, nstreams, true); + TestStruct::template run_test_sptrsv_streams(alg, nstreams, false); } } }