Skip to content

Commit

Permalink
Sptrsv stream test fixes (#2444)
Browse files Browse the repository at this point in the history
SPTRSV_CUSPARSE algorithm is not supported for streams, so it
made no sense to add it to the list of tested algs inside
test_sptrsv_streams. The result was tri_solve_streams being
a no-op for this algorithm. Somehow, this was not caught until
the block algorithm was being used.

Also, some minor cleanup of alg enum handling in the sptrsv handle.
Use a switch statement with a default to catch unhandled enum vals.
print_algorithm should just use the alg string to avoid a duplicated
switch/ifelseif chain. StringToSPTRSVAlgorithm was not handling
several of the enum vals and also returning strings inconsistent
with the strings in return_algorithm_string. Grep revealed no one
using this function, so I removed it.

I added a check in tri_solve_streams to throw an error if an unsupported
alg is used.

Signed-off-by: James Foucar <[email protected]>
  • Loading branch information
jgfouca authored Dec 2, 2024
1 parent 372c88f commit 78f4efd
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 60 deletions.
5 changes: 4 additions & 1 deletion sparse/impl/KokkosSparse_sptrsv_solve_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2083,7 +2083,6 @@ struct SptrsvWrap {
const std::vector<RowMapType> &row_map_v, const std::vector<EntriesType> &entries_v,
const std::vector<ValuesType> &values_v, const std::vector<RHSType> &rhs_v,
std::vector<LHSType> &lhs_v) {
// NOTE: Only support SEQLVLSCHD_RP and SEQLVLSCHD_TP1 at this moment
using nodes_per_level_type = typename TriSolveHandle::hostspace_nnz_lno_view_t;
using nodes_grouped_by_level_type = typename TriSolveHandle::nnz_lno_view_t;
using RPFunctor = FunctorTypeMacro(TriLvlSchedRPSolverFunctor, IsLower, BlockEnabled);
Expand Down Expand Up @@ -2131,6 +2130,10 @@ struct SptrsvWrap {
const int scratch_size = TPFunctor::SBlock::shmem_size(block_size, block_size);
tp = tp.set_scratch_size(0, Kokkos::PerTeam(scratch_size));
Kokkos::parallel_for("parfor_l_team", tp, tstf);
} else {
// NOTE: Only support SEQLVLSCHD_RP and SEQLVLSCHD_TP1 at this moment
auto alg_name = thandle_v[i]->return_algorithm_string();
KK_REQUIRE_MSG(false, "Algorithm " << alg_name << " does not support streams");
}
node_count_v[i] += lvl_nodes;
} // end if (lvl_nodes != 0)
Expand Down
69 changes: 13 additions & 56 deletions sparse/src/KokkosSparse_sptrsv_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -959,69 +959,26 @@ class SPTRSVHandle {
int get_num_chain_entries() const { return this->num_chain_entries; }
void set_num_chain_entries(const int nce) { this->num_chain_entries = nce; }

void print_algorithm() {
if (algm == SPTRSVAlgorithm::SEQLVLSCHD_RP) std::cout << "SEQLVLSCHD_RP" << std::endl;
;

if (algm == SPTRSVAlgorithm::SEQLVLSCHD_TP1) std::cout << "SEQLVLSCHD_TP1" << std::endl;
;
/*
if ( algm == SPTRSVAlgorithm::SEQLVLSCHED_TP2 ) {
std::cout << "SEQLVLSCHED_TP2" << std::endl;;
std::cout << "WARNING: With CUDA this is currently only reliable with
int-int ordinal-offset pair" << std::endl;
}
*/
if (algm == SPTRSVAlgorithm::SEQLVLSCHD_TP1CHAIN) std::cout << "SEQLVLSCHD_TP1CHAIN" << std::endl;
;

if (algm == SPTRSVAlgorithm::SPTRSV_CUSPARSE) std::cout << "SPTRSV_CUSPARSE" << std::endl;
;

if (algm == SPTRSVAlgorithm::SUPERNODAL_NAIVE) std::cout << "SUPERNODAL_NAIVE" << std::endl;

if (algm == SPTRSVAlgorithm::SUPERNODAL_ETREE) std::cout << "SUPERNODAL_ETREE" << std::endl;

if (algm == SPTRSVAlgorithm::SUPERNODAL_DAG) std::cout << "SUPERNODAL_DAG" << std::endl;

if (algm == SPTRSVAlgorithm::SUPERNODAL_SPMV) std::cout << "SUPERNODAL_SPMV" << std::endl;

if (algm == SPTRSVAlgorithm::SUPERNODAL_SPMV_DAG) std::cout << "SUPERNODAL_SPMV_DAG" << std::endl;
}
inline void print_algorithm() { std::cout << return_algorithm_string() << std::endl; }

std::string return_algorithm_string() {
std::string ret_string;

if (algm == SPTRSVAlgorithm::SEQLVLSCHD_RP) ret_string = "SEQLVLSCHD_RP";

if (algm == SPTRSVAlgorithm::SEQLVLSCHD_TP1) ret_string = "SEQLVLSCHD_TP1";
/*
if ( algm == SPTRSVAlgorithm::SEQLVLSCHED_TP2 )
ret_string = "SEQLVLSCHED_TP2";
*/
if (algm == SPTRSVAlgorithm::SEQLVLSCHD_TP1CHAIN) ret_string = "SEQLVLSCHD_TP1CHAIN";

if (algm == SPTRSVAlgorithm::SPTRSV_CUSPARSE) ret_string = "SPTRSV_CUSPARSE";
switch (algm) {
case SPTRSVAlgorithm::SEQLVLSCHD_RP: ret_string = "SEQLVLSCHD_RP"; break;
case SPTRSVAlgorithm::SEQLVLSCHD_TP1: ret_string = "SEQLVLSCHD_TP1"; break;
case SPTRSVAlgorithm::SEQLVLSCHD_TP1CHAIN: ret_string = "SEQLVLSCHD_TP1CHAIN"; break;
case SPTRSVAlgorithm::SPTRSV_CUSPARSE: ret_string = "SPTRSV_CUSPARSE"; break;
case SPTRSVAlgorithm::SUPERNODAL_NAIVE: ret_string = "SUPERNODAL_NAIVE"; break;
case SPTRSVAlgorithm::SUPERNODAL_ETREE: ret_string = "SUPERNODAL_ETREE"; break;
case SPTRSVAlgorithm::SUPERNODAL_DAG: ret_string = "SUPERNODAL_DAG"; break;
case SPTRSVAlgorithm::SUPERNODAL_SPMV: ret_string = "SUPERNODAL_SPMV"; break;
case SPTRSVAlgorithm::SUPERNODAL_SPMV_DAG: ret_string = "SUPERNODAL_SPMV_DAG"; break;
default: KK_REQUIRE_MSG(false, "Unhandled sptrsv algorithm: " << static_cast<int>(algm));
}

return ret_string;
}

inline SPTRSVAlgorithm StringToSPTRSVAlgorithm(std::string &name) {
if (name == "SPTRSV_DEFAULT")
return SPTRSVAlgorithm::SEQLVLSCHD_RP;
else if (name == "SPTRSV_RANGEPOLICY")
return SPTRSVAlgorithm::SEQLVLSCHD_RP;
else if (name == "SPTRSV_TEAMPOLICY1")
return SPTRSVAlgorithm::SEQLVLSCHD_TP1;
/*else if(name=="SPTRSV_TEAMPOLICY2") return
* SPTRSVAlgorithm::SEQLVLSCHED_TP2;*/
else if (name == "SPTRSV_TEAMPOLICY1CHAIN")
return SPTRSVAlgorithm::SEQLVLSCHD_TP1CHAIN;
else if (name == "SPTRSV_CUSPARSE")
return SPTRSVAlgorithm::SPTRSV_CUSPARSE;
else
throw std::runtime_error("Invalid SPTRSVAlgorithm name");
}
};

} // namespace Experimental
Expand Down
3 changes: 0 additions & 3 deletions sparse/unit_test/Test_Sparse_sptrsv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,9 +620,6 @@ template <typename scalar_t, typename lno_t, typename size_type, typename device
void test_sptrsv_streams() {
using TestStruct = Test::SptrsvTest<scalar_t, lno_t, size_type, device>;
std::vector<SPTRSVAlgorithm> algs = {SPTRSVAlgorithm::SEQLVLSCHD_RP, SPTRSVAlgorithm::SEQLVLSCHD_TP1};
if (TestStruct::do_cusparse()) {
algs.push_back(SPTRSVAlgorithm::SPTRSV_CUSPARSE);
}

for (auto alg : algs) {
for (int nstreams = 1; nstreams <= 4; ++nstreams) {
Expand Down

0 comments on commit 78f4efd

Please sign in to comment.